// tnc/tensornetwork/tensor.rs

1use std::iter::zip;
2use std::num::TryFromIntError;
3use std::ops::{BitAnd, BitOr, BitXor, BitXorAssign, Sub};
4
5use approx::AbsDiffEq;
6use rustc_hash::FxHashMap;
7use serde::{Deserialize, Serialize};
8
9use crate::tensornetwork::tensordata::TensorData;
10use crate::utils::datastructures::UnionFind;
11
/// Identifier of a leg (edge). Legs sharing the same id are joined.
pub type EdgeIndex = usize;

/// Position of a tensor inside the tensor list of a tensor network.
pub type TensorIndex = usize;
17
18/// Abstract representation of a tensor. Can be a *composite* tensor (that is, a
19/// tensor network) or a *leaf* tensor.
20#[derive(Default, Debug, Clone, Serialize, Deserialize)]
21pub struct Tensor {
22    /// The inner tensors that make up this tensor. If non-empty, this tensor is
23    /// called a *composite* tensor.
24    pub(crate) tensors: Vec<Tensor>,
25
26    /// The legs of the tensor. Each leg should have a unique id. Connected tensors
27    /// are recognized by having at least one leg id in common.
28    pub(crate) legs: Vec<EdgeIndex>,
29
30    /// The bond dimensions of the legs, same length and order as `legs`. It is
31    /// assumed (but not checked!) that the bond dimensions of different tensors that
32    /// connect to the same leg match.
33    pub(crate) bond_dims: Vec<u64>,
34
35    /// The data of the tensor.
36    pub(crate) tensordata: TensorData,
37}
38
39impl Tensor {
40    /// Constructs a Tensor object with the given `legs` (edge ids) and corresponding
41    /// `bond_dims`. The tensor doesn't have underlying data.
42    #[inline]
43    pub(crate) fn new(legs: Vec<EdgeIndex>, bond_dims: Vec<u64>) -> Self {
44        assert_eq!(legs.len(), bond_dims.len());
45        Self {
46            legs,
47            bond_dims,
48            ..Default::default()
49        }
50    }
51
52    /// Constructs a Tensor using with the given edge ids and a mapping of edge ids
53    /// to corresponding bond dimension.
54    ///
55    /// # Examples
56    /// ```
57    /// # use tnc::tensornetwork::tensor::Tensor;
58    /// # use rustc_hash::FxHashMap;
59    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
60    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
61    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
62    /// assert_eq!(tensor.bond_dims(), &[2, 4, 6]);
63    /// ```
64    #[inline]
65    pub fn new_from_map(legs: Vec<EdgeIndex>, bond_dims_map: &FxHashMap<EdgeIndex, u64>) -> Self {
66        let bond_dims = legs.iter().map(|l| bond_dims_map[l]).collect();
67        Self::new(legs, bond_dims)
68    }
69
70    /// Constructs a Tensor with the given edge ids and the same bond dimension for
71    /// all edges.
72    ///
73    /// # Examples
74    /// ```
75    /// # use tnc::tensornetwork::tensor::Tensor;
76    /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 2);
77    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
78    /// assert_eq!(tensor.bond_dims(), &[2, 2, 2]);
79    /// ```
80    #[inline]
81    pub fn new_from_const(legs: Vec<EdgeIndex>, bond_dim: u64) -> Self {
82        let bond_dims = vec![bond_dim; legs.len()];
83        Self::new(legs, bond_dims)
84    }
85
86    /// Creates a new composite tensor with the given nested tensors.
87    #[inline]
88    pub fn new_composite(tensors: Vec<Self>) -> Self {
89        Self {
90            tensors,
91            ..Default::default()
92        }
93    }
94
95    /// Returns edge ids of Tensor object.
96    ///
97    /// # Examples
98    /// ```
99    /// # use tnc::tensornetwork::tensor::Tensor;
100    /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 3);
101    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
102    /// ```
103    #[inline]
104    pub fn legs(&self) -> &Vec<EdgeIndex> {
105        &self.legs
106    }
107
108    /// Internal method to set legs. Needs pub(crate) for contraction order finding for hierarchies.
109    #[inline]
110    pub(crate) fn set_legs(&mut self, legs: Vec<EdgeIndex>) {
111        self.legs = legs;
112    }
113
114    /// Returns an iterator of tuples of leg ids and their corresponding bond size.
115    #[inline]
116    pub fn edges(&self) -> impl Iterator<Item = (&EdgeIndex, &u64)> + '_ {
117        std::iter::zip(&self.legs, &self.bond_dims)
118    }
119
120    /// Returns the nested tensors of a composite tensor.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// # use tnc::tensornetwork::tensor::Tensor;
126    /// # use rustc_hash::FxHashMap;
127    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
128    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
129    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
130    /// let tn = Tensor::new_composite(vec![v1.clone(), v2.clone()]);
131    /// for (tensor, ref_tensor) in std::iter::zip(tn.tensors(), vec![v1, v2]){
132    ///    assert_eq!(tensor.legs(), ref_tensor.legs());
133    /// }
134    /// ```
135    #[inline]
136    pub fn tensors(&self) -> &Vec<Self> {
137        &self.tensors
138    }
139
140    /// Gets a nested `Tensor` based on the `nested_indices` which specify the index
141    /// of the tensor at each level of the hierarchy.
142    ///
143    /// # Examples
144    /// ```
145    /// # use tnc::tensornetwork::tensor::Tensor;
146    /// # use rustc_hash::FxHashMap;
147    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 2), (4, 1)]);
148    /// let mut v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
149    /// let mut v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
150    /// let mut v3 = Tensor::new_from_map(vec![2, 3], &bond_dims);
151    /// let mut v4 = Tensor::new_from_map(vec![3, 4], &bond_dims);
152    /// let tn1 = Tensor::new_composite(vec![v1, v2]);
153    /// let tn2 = Tensor::new_composite(vec![v3.clone(), v4]);
154    /// let nested_tn = Tensor::new_composite(vec![tn1, tn2]);
155    ///
156    /// assert_eq!(nested_tn.nested_tensor(&[1, 0]).legs(), v3.legs());
157    /// ```
158    pub fn nested_tensor(&self, nested_indices: &[usize]) -> &Tensor {
159        let mut tensor = self;
160        for index in nested_indices {
161            tensor = tensor.tensor(*index);
162        }
163        tensor
164    }
165
166    /// Returns the total number of leaf tensors in the hierarchy.
167    pub fn total_num_tensors(&self) -> usize {
168        if self.is_composite() {
169            self.tensors.iter().map(Self::total_num_tensors).sum()
170        } else {
171            1
172        }
173    }
174
175    /// Get ith Tensor.
176    ///
177    /// # Examples
178    ///
179    /// ```
180    /// # use tnc::tensornetwork::tensor::Tensor;
181    /// # use rustc_hash::FxHashMap;
182    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
183    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
184    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
185    /// let tn = Tensor::new_composite(vec![v1.clone(), v2]);
186    /// assert_eq!(tn.tensor(0).legs(), v1.legs());
187    /// ```
188    #[inline]
189    pub fn tensor(&self, i: TensorIndex) -> &Self {
190        &self.tensors[i]
191    }
192
193    /// Getter for bond dimensions.
194    #[inline]
195    pub fn bond_dims(&self) -> &Vec<u64> {
196        assert!(self.is_leaf());
197        &self.bond_dims
198    }
199
200    /// Returns the shape of tensor. This is the same as the bond dimensions, but as
201    /// `usize`. The conversion can fail, hence a [`Result`] is returned.
202    pub fn shape(&self) -> Result<Vec<usize>, TryFromIntError> {
203        self.bond_dims.iter().map(|&dim| dim.try_into()).collect()
204    }
205
206    /// Returns the number of dimensions.
207    ///
208    /// # Examples
209    /// ```
210    /// # use tnc::tensornetwork::tensor::Tensor;
211    /// # use rustc_hash::FxHashMap;
212    /// let bond_dims = FxHashMap::from_iter([(1, 4), (2, 6), (3, 2)]);
213    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
214    /// assert_eq!(tensor.dims(), 3);
215    /// ```
216    #[inline]
217    pub fn dims(&self) -> usize {
218        self.legs.len()
219    }
220
221    /// Returns the number of elements. This is a f64 to avoid overflow in large
222    /// tensors.
223    ///
224    /// # Examples
225    /// ```
226    /// # use tnc::tensornetwork::tensor::Tensor;
227    /// # use rustc_hash::FxHashMap;
228    /// let bond_dims = FxHashMap::from_iter([(1, 5), (2, 15), (3, 8)]);
229    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
230    /// assert_eq!(tensor.size(), 600.0);
231    /// ```
232    #[inline]
233    pub fn size(&self) -> f64 {
234        self.bond_dims.iter().map(|v| *v as f64).product()
235    }
236
237    /// Returns true if Tensor is a leaf tensor, without any nested tensors.
238    ///
239    /// # Examples
240    /// ```
241    /// # use tnc::tensornetwork::tensor::Tensor;
242    /// # use rustc_hash::FxHashMap;
243    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
244    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
245    /// assert_eq!(tensor.is_leaf(), true);
246    /// let comp = Tensor::new_composite(vec![tensor]);
247    /// assert_eq!(comp.is_leaf(), false);
248    /// ```
249    #[inline]
250    pub fn is_leaf(&self) -> bool {
251        self.tensors.is_empty()
252    }
253
254    /// Returns true if Tensor is composite.
255    ///
256    /// # Examples
257    /// ```
258    /// # use tnc::tensornetwork::tensor::Tensor;
259    /// # use rustc_hash::FxHashMap;
260    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
261    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
262    /// assert_eq!(tensor.is_composite(), false);
263    /// let comp = Tensor::new_composite(vec![tensor]);
264    /// assert_eq!(comp.is_composite(), true);
265    /// ```
266    #[inline]
267    pub fn is_composite(&self) -> bool {
268        !self.tensors.is_empty()
269    }
270
271    /// Returns true if Tensor is empty. This means, it doesn't have any subtensors,
272    /// has no legs and is doesn't have any data (e.g., is not a scalar).
273    ///
274    /// # Examples
275    /// ```
276    /// # use tnc::tensornetwork::tensor::Tensor;
277    /// let tensor = Tensor::default();
278    /// assert_eq!(tensor.is_empty(), true);
279    /// ```
280    #[inline]
281    pub fn is_empty(&self) -> bool {
282        self.tensors.is_empty()
283            && self.legs.is_empty()
284            && matches!(*self.tensor_data(), TensorData::Uncontracted)
285    }
286
287    /// Consumes the `Tensor` and returns its `TensorData`.
288    ///
289    /// # Examples
290    /// ```
291    /// # use tnc::tensornetwork::tensor::Tensor;
292    /// # use tnc::tensornetwork::tensordata::TensorData;
293    /// let tensor = Tensor::new_from_const(vec![1, 2], 2);
294    /// let tensordata = tensor.into_tensor_data();
295    /// assert_eq!(tensordata, TensorData::Uncontracted);
296    /// ```
297    #[inline]
298    pub fn into_tensor_data(self) -> TensorData {
299        self.tensordata
300    }
301
302    /// Pushes additional `tensor` into this tensor, which must be a composite tensor.
303    #[inline]
304    pub fn push_tensor(&mut self, tensor: Self) {
305        assert!(
306            self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
307            "Cannot push tensors into a leaf tensor"
308        );
309        self.tensors.push(tensor);
310    }
311
312    /// Pushes additional `tensors` into this tensor, which must be a composite tensor.
313    #[inline]
314    pub fn push_tensors(&mut self, mut tensors: Vec<Self>) {
315        assert!(
316            self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
317            "Cannot push tensors into a leaf tensor"
318        );
319        self.tensors.append(&mut tensors);
320    }
321
322    /// Getter for tensor data.
323    #[inline]
324    pub fn tensor_data(&self) -> &TensorData {
325        &self.tensordata
326    }
327
328    /// Setter for tensor data.
329    ///
330    /// # Examples
331    ///
332    /// ```
333    /// # use tnc::tensornetwork::tensor::Tensor;
334    /// # use tnc::tensornetwork::tensordata::TensorData;
335    /// let mut tensor = Tensor::new_from_const(vec![0, 1], 2);
336    /// let tensordata = TensorData::Gate((String::from("x"), vec![], false));
337    /// tensor.set_tensor_data(tensordata);
338    /// ```
339    #[inline]
340    pub fn set_tensor_data(&mut self, tensordata: TensorData) {
341        assert!(
342            self.is_leaf() || matches!(tensordata, TensorData::Uncontracted),
343            "Cannot add data to composite tensor"
344        );
345        self.tensordata = tensordata;
346    }
347
348    /// Returns whether all tensors inside this tensor are connected.
349    /// This only checks the top-level, not recursing into composite tensors.
350    ///
351    /// # Examples
352    /// ```
353    /// # use tnc::tensornetwork::tensor::Tensor;
354    /// # use rustc_hash::FxHashMap;
355    /// // Create a tensor network with two connected tensors
356    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 5)]);
357    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
358    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
359    /// let mut tn = Tensor::new_composite(vec![v1, v2]);
360    /// assert!(tn.is_connected());
361    ///
362    /// // Introduce a new tensor that is not connected
363    /// let v3 = Tensor::new_from_map(vec![3], &bond_dims);
364    /// tn.push_tensor(v3);
365    /// assert!(!tn.is_connected());
366    /// ```
367    pub fn is_connected(&self) -> bool {
368        let num_tensors = self.tensors.len();
369        let mut uf = UnionFind::new(num_tensors);
370
371        for t1_id in 0..num_tensors {
372            for t2_id in (t1_id + 1)..num_tensors {
373                let t1 = &self.tensors[t1_id];
374                let t2 = &self.tensors[t2_id];
375                if !(t1 & t2).legs.is_empty() {
376                    uf.union(t1_id, t2_id);
377                }
378            }
379        }
380
381        uf.count_sets() == 1
382    }
383
384    /// Returns `Tensor` with legs in `self` that are not in `other`.
385    ///
386    /// # Examples
387    /// ```
388    /// # use tnc::tensornetwork::tensor::Tensor;
389    /// # use rustc_hash::FxHashMap;
390    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
391    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
392    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
393    /// let diff_tensor = &tensor1 - &tensor2;
394    /// assert_eq!(diff_tensor.legs(), &[1, 3]);
395    /// assert_eq!(diff_tensor.bond_dims(), &[2, 6]);
396    /// ```
397    #[must_use]
398    pub fn difference(&self, other: &Self) -> Self {
399        let mut new_legs = Vec::with_capacity(self.legs.len());
400        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
401        for (leg, dim) in self.edges() {
402            if !other.legs.contains(leg) {
403                new_legs.push(*leg);
404                new_bond_dims.push(*dim);
405            }
406        }
407        Self::new(new_legs, new_bond_dims)
408    }
409
410    /// Returns `Tensor` with union of legs in both `self` and `other`.
411    ///
412    /// # Examples
413    /// ```
414    /// # use tnc::tensornetwork::tensor::Tensor;
415    /// # use rustc_hash::FxHashMap;
416    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
417    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
418    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
419    /// let union_tensor = &tensor1 | &tensor2;
420    /// assert_eq!(union_tensor.legs(), &[1, 2, 3, 4, 5]);
421    /// assert_eq!(union_tensor.bond_dims(), &[2, 4, 6, 3, 9]);
422    /// ```
423    #[must_use]
424    pub fn union(&self, other: &Self) -> Self {
425        let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
426        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
427        new_legs.extend_from_slice(&self.legs);
428        new_bond_dims.extend_from_slice(&self.bond_dims);
429        for (leg, dim) in other.edges() {
430            if !self.legs.contains(leg) {
431                new_legs.push(*leg);
432                new_bond_dims.push(*dim);
433            }
434        }
435        Self::new(new_legs, new_bond_dims)
436    }
437
438    /// Returns `Tensor` with intersection of legs in `self` and `other`.
439    ///
440    /// # Examples
441    /// ```
442    /// # use tnc::tensornetwork::tensor::Tensor;
443    /// # use rustc_hash::FxHashMap;
444    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
445    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
446    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
447    /// let intersection_tensor = &tensor1 & &tensor2;
448    /// assert_eq!(intersection_tensor.legs(), &[2]);
449    /// assert_eq!(intersection_tensor.bond_dims(), &[4]);
450    /// ```
451    #[must_use]
452    pub fn intersection(&self, other: &Self) -> Self {
453        let mut new_legs = Vec::with_capacity(self.legs.len().min(other.legs.len()));
454        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
455        for (leg, dim) in self.edges() {
456            if other.legs.contains(leg) {
457                new_legs.push(*leg);
458                new_bond_dims.push(*dim);
459            }
460        }
461        Self::new(new_legs, new_bond_dims)
462    }
463
464    /// Returns `Tensor` with symmetrical difference of legs in `self` and `other`.
465    ///
466    /// # Examples
467    /// ```
468    /// # use tnc::tensornetwork::tensor::Tensor;
469    /// # use rustc_hash::FxHashMap;
470    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
471    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
472    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
473    /// let sym_dif_tensor = &tensor1 ^ &tensor2;
474    /// assert_eq!(sym_dif_tensor.legs(), &[1, 3, 4, 5]);
475    /// assert_eq!(sym_dif_tensor.bond_dims(), &[2, 6, 3, 9]);
476    /// ```
477    #[must_use]
478    pub fn symmetric_difference(&self, other: &Self) -> Self {
479        let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
480        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
481        for (leg, dim) in self.edges() {
482            if !other.legs.contains(leg) {
483                new_legs.push(*leg);
484                new_bond_dims.push(*dim);
485            }
486        }
487        for (leg, dim) in other.edges() {
488            if !self.legs.contains(leg) {
489                new_legs.push(*leg);
490                new_bond_dims.push(*dim);
491            }
492        }
493        Self::new(new_legs, new_bond_dims)
494    }
495
496    /// Get output legs after tensor contraction
497    pub fn external_tensor(&self) -> Tensor {
498        if self.is_leaf() {
499            return self.clone();
500        }
501
502        let mut ext_tensor = Self::default();
503        for tensor in &self.tensors {
504            let new_tensor = if tensor.is_composite() {
505                &tensor.external_tensor()
506            } else {
507                tensor
508            };
509            ext_tensor = &ext_tensor ^ new_tensor;
510        }
511
512        ext_tensor
513    }
514}
515
516impl PartialEq for Tensor {
517    fn eq(&self, other: &Self) -> bool {
518        self.legs == other.legs
519            && self.bond_dims == other.bond_dims
520            && self.tensors == other.tensors
521            && self.tensordata == other.tensordata
522    }
523}
524
525impl AbsDiffEq for Tensor {
526    type Epsilon = f64;
527
528    fn default_epsilon() -> Self::Epsilon {
529        f64::default_epsilon()
530    }
531
532    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
533        if self.legs != other.legs {
534            return false;
535        }
536        if self.bond_dims != other.bond_dims {
537            return false;
538        }
539        if self.tensors.len() != other.tensors.len() {
540            return false;
541        }
542        for (tensor, other_tensor) in zip(&self.tensors, &other.tensors) {
543            if !Tensor::abs_diff_eq(tensor, other_tensor, epsilon) {
544                return false;
545            }
546        }
547
548        TensorData::abs_diff_eq(&self.tensordata, &other.tensordata, epsilon)
549    }
550}
551
552impl BitOr for &Tensor {
553    type Output = Tensor;
554    #[inline]
555    fn bitor(self, rhs: &Tensor) -> Tensor {
556        self.union(rhs)
557    }
558}
559
560impl BitAnd for &Tensor {
561    type Output = Tensor;
562    #[inline]
563    fn bitand(self, rhs: &Tensor) -> Tensor {
564        self.intersection(rhs)
565    }
566}
567
568impl BitXor for &Tensor {
569    type Output = Tensor;
570    #[inline]
571    fn bitxor(self, rhs: &Tensor) -> Tensor {
572        self.symmetric_difference(rhs)
573    }
574}
575
576impl Sub for &Tensor {
577    type Output = Tensor;
578    #[inline]
579    fn sub(self, rhs: &Tensor) -> Tensor {
580        self.difference(rhs)
581    }
582}
583
584impl BitXorAssign<&Tensor> for Tensor {
585    #[inline]
586    fn bitxor_assign(&mut self, rhs: &Tensor) {
587        *self = self.symmetric_difference(rhs);
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use approx::assert_abs_diff_eq;
    use rustc_hash::FxHashMap;

    use crate::tensornetwork::tensordata::TensorData;

    /// Panics with a descriptive message unless `$left` matches `$pattern`.
    macro_rules! assert_matches {
        ($left:expr, $pattern:pat) => {
            match $left {
                $pattern => (),
                _ => panic!(
                    "Expected pattern {} but got {:?}",
                    stringify!($pattern),
                    $left
                ),
            }
        };
    }

    /// Asserts that `actual` holds exactly the tensors in `expected`, in order.
    ///
    /// The length is checked first because `zip` stops at the shorter input;
    /// an element-wise comparison alone would pass vacuously if tensors were
    /// lost or never added.
    fn assert_tensors_eq(actual: &[Tensor], expected: &[&Tensor]) {
        assert_eq!(actual.len(), expected.len());
        for (sub_tensor, ref_tensor) in zip(actual, expected) {
            assert_abs_diff_eq!(sub_tensor, *ref_tensor);
        }
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = Tensor::default();
        assert!(tensor.tensors.is_empty());
        assert!(tensor.legs.is_empty());
        assert!(tensor.bond_dims.is_empty());
        assert!(tensor.is_empty());
    }

    #[test]
    fn test_new() {
        let tensor = Tensor::new(vec![2, 4, 5], vec![4, 2, 6]);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_map() {
        let bond_dims = FxHashMap::from_iter([(1, 1), (2, 4), (3, 7), (4, 2), (5, 6)]);
        let tensor = Tensor::new_from_map(vec![2, 4, 5], &bond_dims);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_const() {
        let tensor = Tensor::new_from_const(vec![9, 2, 5, 1], 3);
        assert_eq!(tensor.legs(), &[9, 2, 5, 1]);
        assert_eq!(tensor.bond_dims(), &[3, 3, 3, 3]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_external_tensor() {
        let bond_dims = FxHashMap::from_iter([
            (2, 2),
            (3, 4),
            (4, 6),
            (5, 8),
            (6, 10),
            (7, 12),
            (8, 14),
            (9, 16),
        ]);
        let tensor_1 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![2, 3, 5], &bond_dims);
        let tensor_12 = Tensor::new_composite(vec![tensor_1, tensor_2]);

        let tensor_3 = Tensor::new_from_map(vec![6, 7, 8], &bond_dims);
        let tensor_4 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let tensor_34 = Tensor::new_composite(vec![tensor_3, tensor_4]);

        let tensor_1234 = Tensor::new_composite(vec![tensor_12, tensor_34]);

        let external = tensor_1234.external_tensor();

        let ref_tensor = Tensor::new_from_map(vec![4, 5, 7, 9], &bond_dims);
        assert_abs_diff_eq!(&external, &ref_tensor);
    }

    #[test]
    fn test_is_connected_transitive() {
        // [0, 1] and [2, 3] share no leg; they only become connected through
        // the last tensor [1, 2].
        let tensor_a = Tensor::new_from_const(vec![0, 1], 2);
        let tensor_b = Tensor::new_from_const(vec![2, 3], 2);
        let tensor_c = Tensor::new_from_const(vec![1, 2], 2);

        let disconnected = Tensor::new_composite(vec![tensor_a.clone(), tensor_b.clone()]);
        assert!(!disconnected.is_connected());

        let connected = Tensor::new_composite(vec![tensor_a, tensor_b, tensor_c]);
        assert!(connected.is_connected());
    }

    #[test]
    fn test_push_tensor() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let mut tensor = Tensor::default();

        // Push tensor 1
        let tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        tensor.push_tensor(tensor_1);
        assert_tensors_eq(tensor.tensors(), &[&ref_tensor_1]);

        // Push tensor 2
        let tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        tensor.push_tensor(tensor_2);
        assert_tensors_eq(tensor.tensors(), &[&ref_tensor_1, &ref_tensor_2]);

        // Test that other fields are unchanged
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
        assert!(tensor.legs().is_empty());
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensor_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        leaf_tensor.push_tensor(pushed_tensor);
    }

    #[test]
    fn test_push_tensors() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        let mut tensor = Tensor::default();
        tensor.push_tensors(vec![tensor_1, tensor_2, tensor_3]);

        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
        assert!(tensor.legs().is_empty());
        assert_tensors_eq(
            tensor.tensors(),
            &[&ref_tensor_1, &ref_tensor_2, &ref_tensor_3],
        );
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensors_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let pushed_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        leaf_tensor.push_tensors(vec![pushed_tensor_1, pushed_tensor_2]);
    }
}