tnc/tensornetwork/
tensor.rs

1use std::hash::{Hash, Hasher};
2use std::iter::zip;
3use std::num::TryFromIntError;
4use std::ops::{BitAnd, BitOr, BitXor, BitXorAssign, Sub};
5
6use float_cmp::{ApproxEq, F64Margin};
7use rustc_hash::FxHashMap;
8use serde::{Deserialize, Serialize};
9
10use crate::tensornetwork::tensordata::TensorData;
11use crate::utils::datastructures::UnionFind;
12
13/// Unique index of a leg.
14pub type EdgeIndex = usize;
15
16/// Index of a tensor in a tensor network.
17pub type TensorIndex = usize;
18
19/// Abstract representation of a tensor. Can be a *composite* tensor (that is, a
20/// tensor network) or a *leaf* tensor.
21#[derive(Default, Debug, Clone, Serialize, Deserialize)]
22pub struct Tensor {
23    /// The inner tensors that make up this tensor. If non-empty, this tensor is
24    /// called a *composite* tensor.
25    pub(crate) tensors: Vec<Tensor>,
26
27    /// The legs of the tensor. Each leg should have a unique id. Connected tensors
28    /// are recognized by having at least one leg id in common.
29    pub(crate) legs: Vec<EdgeIndex>,
30
31    /// The bond dimensions of the legs, same length and order as `legs`. It is
32    /// assumed (but not checked!) that the bond dimensions of different tensors that
33    /// connect to the same leg match.
34    pub(crate) bond_dims: Vec<u64>,
35
36    /// The data of the tensor.
37    pub(crate) tensordata: TensorData,
38}
39
40impl Hash for Tensor {
41    fn hash<H: Hasher>(&self, state: &mut H) {
42        self.legs.hash(state);
43    }
44}
45
46impl Tensor {
47    /// Constructs a Tensor object with the given `legs` (edge ids) and corresponding
48    /// `bond_dims`. The tensor doesn't have underlying data.
49    #[inline]
50    pub(crate) fn new(legs: Vec<EdgeIndex>, bond_dims: Vec<u64>) -> Self {
51        assert_eq!(legs.len(), bond_dims.len());
52        Self {
53            legs,
54            bond_dims,
55            ..Default::default()
56        }
57    }
58
59    /// Constructs a Tensor using with the given edge ids and a mapping of edge ids
60    /// to corresponding bond dimension.
61    ///
62    /// # Examples
63    /// ```
64    /// # use tnc::tensornetwork::tensor::Tensor;
65    /// # use rustc_hash::FxHashMap;
66    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
67    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
68    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
69    /// assert_eq!(tensor.bond_dims(), &[2, 4, 6]);
70    /// ```
71    #[inline]
72    pub fn new_from_map(legs: Vec<EdgeIndex>, bond_dims_map: &FxHashMap<EdgeIndex, u64>) -> Self {
73        let bond_dims = legs.iter().map(|l| bond_dims_map[l]).collect();
74        Self::new(legs, bond_dims)
75    }
76
77    /// Constructs a Tensor with the given edge ids and the same bond dimension for
78    /// all edges.
79    ///
80    /// # Examples
81    /// ```
82    /// # use tnc::tensornetwork::tensor::Tensor;
83    /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 2);
84    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
85    /// assert_eq!(tensor.bond_dims(), &[2, 2, 2]);
86    /// ```
87    #[inline]
88    pub fn new_from_const(legs: Vec<EdgeIndex>, bond_dim: u64) -> Self {
89        let bond_dims = vec![bond_dim; legs.len()];
90        Self::new(legs, bond_dims)
91    }
92
93    /// Creates a new composite tensor with the given nested tensors.
94    #[inline]
95    pub fn new_composite(tensors: Vec<Self>) -> Self {
96        Self {
97            tensors,
98            ..Default::default()
99        }
100    }
101
102    /// Returns edge ids of Tensor object.
103    ///
104    /// # Examples
105    /// ```
106    /// # use tnc::tensornetwork::tensor::Tensor;
107    /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 3);
108    /// assert_eq!(tensor.legs(), &[1, 2, 3]);
109    /// ```
110    #[inline]
111    pub fn legs(&self) -> &Vec<EdgeIndex> {
112        &self.legs
113    }
114
115    /// Internal method to set legs. Needs pub(crate) for contraction order finding for hierarchies.
116    #[inline]
117    pub(crate) fn set_legs(&mut self, legs: Vec<EdgeIndex>) {
118        self.legs = legs;
119    }
120
121    /// Returns an iterator of tuples of leg ids and their corresponding bond size.
122    #[inline]
123    pub fn edges(&self) -> impl Iterator<Item = (&EdgeIndex, &u64)> + '_ {
124        std::iter::zip(&self.legs, &self.bond_dims)
125    }
126
127    /// Returns the nested tensors of a composite tensor.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// # use tnc::tensornetwork::tensor::Tensor;
133    /// # use rustc_hash::FxHashMap;
134    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
135    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
136    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
137    /// let tn = Tensor::new_composite(vec![v1.clone(), v2.clone()]);
138    /// for (tensor, ref_tensor) in std::iter::zip(tn.tensors(), vec![v1, v2]){
139    ///    assert_eq!(tensor.legs(), ref_tensor.legs());
140    /// }
141    /// ```
142    #[inline]
143    pub fn tensors(&self) -> &Vec<Self> {
144        &self.tensors
145    }
146
147    /// Gets a nested `Tensor` based on the `nested_indices` which specify the index
148    /// of the tensor at each level of the hierarchy.
149    ///
150    /// # Examples
151    /// ```
152    /// # use tnc::tensornetwork::tensor::Tensor;
153    /// # use rustc_hash::FxHashMap;
154    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 2), (4, 1)]);
155    /// let mut v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
156    /// let mut v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
157    /// let mut v3 = Tensor::new_from_map(vec![2, 3], &bond_dims);
158    /// let mut v4 = Tensor::new_from_map(vec![3, 4], &bond_dims);
159    /// let tn1 = Tensor::new_composite(vec![v1, v2]);
160    /// let tn2 = Tensor::new_composite(vec![v3.clone(), v4]);
161    /// let nested_tn = Tensor::new_composite(vec![tn1, tn2]);
162    ///
163    /// assert_eq!(nested_tn.nested_tensor(&[1, 0]).legs(), v3.legs());
164    /// ```
165    pub fn nested_tensor(&self, nested_indices: &[usize]) -> &Tensor {
166        let mut tensor = self;
167        for index in nested_indices {
168            tensor = tensor.tensor(*index);
169        }
170        tensor
171    }
172
173    /// Returns the total number of leaf tensors in the hierarchy.
174    pub fn total_num_tensors(&self) -> usize {
175        if self.is_composite() {
176            self.tensors.iter().map(Self::total_num_tensors).sum()
177        } else {
178            1
179        }
180    }
181
182    /// Get ith Tensor.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// # use tnc::tensornetwork::tensor::Tensor;
188    /// # use rustc_hash::FxHashMap;
189    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
190    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
191    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
192    /// let tn = Tensor::new_composite(vec![v1.clone(), v2]);
193    /// assert_eq!(tn.tensor(0).legs(), v1.legs());
194    /// ```
195    #[inline]
196    pub fn tensor(&self, i: TensorIndex) -> &Self {
197        &self.tensors[i]
198    }
199
200    /// Getter for bond dimensions.
201    #[inline]
202    pub fn bond_dims(&self) -> &Vec<u64> {
203        assert!(self.is_leaf());
204        &self.bond_dims
205    }
206
207    /// Returns the shape of tensor. This is the same as the bond dimensions, but as
208    /// `usize`. The conversion can fail, hence a [`Result`] is returned.
209    pub fn shape(&self) -> Result<Vec<usize>, TryFromIntError> {
210        self.bond_dims.iter().map(|&dim| dim.try_into()).collect()
211    }
212
213    /// Returns the number of dimensions.
214    ///
215    /// # Examples
216    /// ```
217    /// # use tnc::tensornetwork::tensor::Tensor;
218    /// # use rustc_hash::FxHashMap;
219    /// let bond_dims = FxHashMap::from_iter([(1, 4), (2, 6), (3, 2)]);
220    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
221    /// assert_eq!(tensor.dims(), 3);
222    /// ```
223    #[inline]
224    pub fn dims(&self) -> usize {
225        self.legs.len()
226    }
227
228    /// Returns the number of elements. This is a f64 to avoid overflow in large
229    /// tensors.
230    ///
231    /// # Examples
232    /// ```
233    /// # use tnc::tensornetwork::tensor::Tensor;
234    /// # use rustc_hash::FxHashMap;
235    /// let bond_dims = FxHashMap::from_iter([(1, 5), (2, 15), (3, 8)]);
236    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
237    /// assert_eq!(tensor.size(), 600.0);
238    /// ```
239    #[inline]
240    pub fn size(&self) -> f64 {
241        self.bond_dims.iter().map(|v| *v as f64).product()
242    }
243
244    /// Returns true if Tensor is a leaf tensor, without any nested tensors.
245    ///
246    /// # Examples
247    /// ```
248    /// # use tnc::tensornetwork::tensor::Tensor;
249    /// # use rustc_hash::FxHashMap;
250    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
251    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
252    /// assert_eq!(tensor.is_leaf(), true);
253    /// let comp = Tensor::new_composite(vec![tensor]);
254    /// assert_eq!(comp.is_leaf(), false);
255    /// ```
256    #[inline]
257    pub fn is_leaf(&self) -> bool {
258        self.tensors.is_empty()
259    }
260
261    /// Returns true if Tensor is composite.
262    ///
263    /// # Examples
264    /// ```
265    /// # use tnc::tensornetwork::tensor::Tensor;
266    /// # use rustc_hash::FxHashMap;
267    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
268    /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
269    /// assert_eq!(tensor.is_composite(), false);
270    /// let comp = Tensor::new_composite(vec![tensor]);
271    /// assert_eq!(comp.is_composite(), true);
272    /// ```
273    #[inline]
274    pub fn is_composite(&self) -> bool {
275        !self.tensors.is_empty()
276    }
277
278    /// Returns true if Tensor is empty. This means, it doesn't have any subtensors,
279    /// has no legs and is doesn't have any data (e.g., is not a scalar).
280    ///
281    /// # Examples
282    /// ```
283    /// # use tnc::tensornetwork::tensor::Tensor;
284    /// let tensor = Tensor::default();
285    /// assert_eq!(tensor.is_empty(), true);
286    /// ```
287    #[inline]
288    pub fn is_empty(&self) -> bool {
289        self.tensors.is_empty()
290            && self.legs.is_empty()
291            && matches!(*self.tensor_data(), TensorData::Uncontracted)
292    }
293
294    /// Pushes additional `tensor` into this tensor, which must be a composite tensor.
295    #[inline]
296    pub fn push_tensor(&mut self, tensor: Self) {
297        assert!(
298            self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
299            "Cannot push tensors into a leaf tensor"
300        );
301        self.tensors.push(tensor);
302    }
303
304    /// Pushes additional `tensors` into this tensor, which must be a composite tensor.
305    #[inline]
306    pub fn push_tensors(&mut self, mut tensors: Vec<Self>) {
307        assert!(
308            self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
309            "Cannot push tensors into a leaf tensor"
310        );
311        self.tensors.append(&mut tensors);
312    }
313
314    /// Getter for tensor data.
315    #[inline]
316    pub fn tensor_data(&self) -> &TensorData {
317        &self.tensordata
318    }
319
320    /// Setter for tensor data.
321    ///
322    /// # Examples
323    ///
324    /// ```
325    /// # use tnc::tensornetwork::tensor::Tensor;
326    /// # use tnc::tensornetwork::tensordata::TensorData;
327    /// let mut tensor = Tensor::new_from_const(vec![0, 1], 2);
328    /// let tensordata = TensorData::Gate((String::from("x"), vec![], false));
329    /// tensor.set_tensor_data(tensordata);
330    /// ```
331    #[inline]
332    pub fn set_tensor_data(&mut self, tensordata: TensorData) {
333        assert!(
334            self.is_leaf() || matches!(tensordata, TensorData::Uncontracted),
335            "Cannot add data to composite tensor"
336        );
337        self.tensordata = tensordata;
338    }
339
340    /// Returns whether all tensors inside this tensor are connected.
341    /// This only checks the top-level, not recursing into composite tensors.
342    ///
343    /// # Examples
344    /// ```
345    /// # use tnc::tensornetwork::tensor::Tensor;
346    /// # use rustc_hash::FxHashMap;
347    /// // Create a tensor network with two connected tensors
348    /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 5)]);
349    /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
350    /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
351    /// let mut tn = Tensor::new_composite(vec![v1, v2]);
352    /// assert!(tn.is_connected());
353    ///
354    /// // Introduce a new tensor that is not connected
355    /// let v3 = Tensor::new_from_map(vec![3], &bond_dims);
356    /// tn.push_tensor(v3);
357    /// assert!(!tn.is_connected());
358    /// ```
359    pub fn is_connected(&self) -> bool {
360        let num_tensors = self.tensors.len();
361        let mut uf = UnionFind::new(num_tensors);
362
363        for t1_id in 0..num_tensors - 1 {
364            for t2_id in (t1_id + 1)..num_tensors {
365                let t1 = &self.tensors[t1_id];
366                let t2 = &self.tensors[t2_id];
367                if !(t1 & t2).legs.is_empty() {
368                    uf.union(t1_id, t2_id);
369                }
370            }
371        }
372
373        uf.count_sets() == 1
374    }
375
376    /// Returns `Tensor` with legs in `self` that are not in `other`.
377    ///
378    /// # Examples
379    /// ```
380    /// # use tnc::tensornetwork::tensor::Tensor;
381    /// # use rustc_hash::FxHashMap;
382    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
383    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
384    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
385    /// let diff_tensor = &tensor1 - &tensor2;
386    /// assert_eq!(diff_tensor.legs(), &[1, 3]);
387    /// assert_eq!(diff_tensor.bond_dims(), &[2, 6]);
388    /// ```
389    #[must_use]
390    pub fn difference(&self, other: &Self) -> Self {
391        let mut new_legs = Vec::with_capacity(self.legs.len());
392        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
393        for (leg, dim) in self.edges() {
394            if !other.legs.contains(leg) {
395                new_legs.push(*leg);
396                new_bond_dims.push(*dim);
397            }
398        }
399        Self::new(new_legs, new_bond_dims)
400    }
401
402    /// Returns `Tensor` with union of legs in both `self` and `other`.
403    ///
404    /// # Examples
405    /// ```
406    /// # use tnc::tensornetwork::tensor::Tensor;
407    /// # use rustc_hash::FxHashMap;
408    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
409    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
410    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
411    /// let union_tensor = &tensor1 | &tensor2;
412    /// assert_eq!(union_tensor.legs(), &[1, 2, 3, 4, 5]);
413    /// assert_eq!(union_tensor.bond_dims(), &[2, 4, 6, 3, 9]);
414    /// ```
415    #[must_use]
416    pub fn union(&self, other: &Self) -> Self {
417        let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
418        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
419        new_legs.extend_from_slice(&self.legs);
420        new_bond_dims.extend_from_slice(&self.bond_dims);
421        for (leg, dim) in other.edges() {
422            if !self.legs.contains(leg) {
423                new_legs.push(*leg);
424                new_bond_dims.push(*dim);
425            }
426        }
427        Self::new(new_legs, new_bond_dims)
428    }
429
430    /// Returns `Tensor` with intersection of legs in `self` and `other`.
431    ///
432    /// # Examples
433    /// ```
434    /// # use tnc::tensornetwork::tensor::Tensor;
435    /// # use rustc_hash::FxHashMap;
436    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
437    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
438    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
439    /// let intersection_tensor = &tensor1 & &tensor2;
440    /// assert_eq!(intersection_tensor.legs(), &[2]);
441    /// assert_eq!(intersection_tensor.bond_dims(), &[4]);
442    /// ```
443    #[must_use]
444    pub fn intersection(&self, other: &Self) -> Self {
445        let mut new_legs = Vec::with_capacity(self.legs.len().min(other.legs.len()));
446        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
447        for (leg, dim) in self.edges() {
448            if other.legs.contains(leg) {
449                new_legs.push(*leg);
450                new_bond_dims.push(*dim);
451            }
452        }
453        Self::new(new_legs, new_bond_dims)
454    }
455
456    /// Returns `Tensor` with symmetrical difference of legs in `self` and `other`.
457    ///
458    /// # Examples
459    /// ```
460    /// # use tnc::tensornetwork::tensor::Tensor;
461    /// # use rustc_hash::FxHashMap;
462    /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
463    /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
464    /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
465    /// let sym_dif_tensor = &tensor1 ^ &tensor2;
466    /// assert_eq!(sym_dif_tensor.legs(), &[1, 3, 4, 5]);
467    /// assert_eq!(sym_dif_tensor.bond_dims(), &[2, 6, 3, 9]);
468    /// ```
469    #[must_use]
470    pub fn symmetric_difference(&self, other: &Self) -> Self {
471        let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
472        let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
473        for (leg, dim) in self.edges() {
474            if !other.legs.contains(leg) {
475                new_legs.push(*leg);
476                new_bond_dims.push(*dim);
477            }
478        }
479        for (leg, dim) in other.edges() {
480            if !self.legs.contains(leg) {
481                new_legs.push(*leg);
482                new_bond_dims.push(*dim);
483            }
484        }
485        Self::new(new_legs, new_bond_dims)
486    }
487
488    /// Get output legs after tensor contraction
489    pub fn external_tensor(&self) -> Tensor {
490        if self.is_leaf() {
491            return self.clone();
492        }
493
494        let mut ext_tensor = Self::default();
495        for tensor in &self.tensors {
496            let new_tensor = if tensor.is_composite() {
497                &tensor.external_tensor()
498            } else {
499                tensor
500            };
501            ext_tensor = &ext_tensor ^ new_tensor;
502        }
503
504        ext_tensor
505    }
506}
507
508impl ApproxEq for &Tensor {
509    type Margin = F64Margin;
510
511    fn approx_eq<M: Into<Self::Margin>>(self, other: Self, margin: M) -> bool {
512        let margin = margin.into();
513        if self.legs != other.legs {
514            return false;
515        }
516        if self.bond_dims != other.bond_dims {
517            return false;
518        }
519        if self.tensors.len() != other.tensors.len() {
520            return false;
521        }
522        for (tensor, other_tensor) in zip(&self.tensors, &other.tensors) {
523            if !tensor.approx_eq(other_tensor, margin) {
524                return false;
525            }
526        }
527
528        self.tensordata.approx_eq(&other.tensordata, margin)
529    }
530}
531
532impl BitOr for &Tensor {
533    type Output = Tensor;
534    #[inline]
535    fn bitor(self, rhs: &Tensor) -> Tensor {
536        self.union(rhs)
537    }
538}
539
540impl BitAnd for &Tensor {
541    type Output = Tensor;
542    #[inline]
543    fn bitand(self, rhs: &Tensor) -> Tensor {
544        self.intersection(rhs)
545    }
546}
547
548impl BitXor for &Tensor {
549    type Output = Tensor;
550    #[inline]
551    fn bitxor(self, rhs: &Tensor) -> Tensor {
552        self.symmetric_difference(rhs)
553    }
554}
555
556impl Sub for &Tensor {
557    type Output = Tensor;
558    #[inline]
559    fn sub(self, rhs: &Tensor) -> Tensor {
560        self.difference(rhs)
561    }
562}
563
564impl BitXorAssign<&Tensor> for Tensor {
565    #[inline]
566    fn bitxor_assign(&mut self, rhs: &Tensor) {
567        *self = self.symmetric_difference(rhs);
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use rustc_hash::FxHashMap;

    use crate::tensornetwork::tensordata::TensorData;

    /// Asserts that `$left` matches `$pattern`, printing the actual value on
    /// failure. `$left` is evaluated exactly once (it is bound in the fallback arm
    /// instead of being re-evaluated for the panic message).
    macro_rules! assert_matches {
        ($left:expr, $pattern:pat) => {
            match $left {
                $pattern => (),
                ref other => panic!(
                    "Expected pattern {} but got {:?}",
                    stringify!($pattern),
                    other
                ),
            }
        };
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = Tensor::default();
        assert!(tensor.tensors.is_empty());
        assert!(tensor.legs.is_empty());
        assert!(tensor.bond_dims.is_empty());
        assert!(tensor.is_empty());
    }

    #[test]
    fn test_new() {
        let tensor = Tensor::new(vec![2, 4, 5], vec![4, 2, 6]);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_map() {
        let bond_dims = FxHashMap::from_iter([(1, 1), (2, 4), (3, 7), (4, 2), (5, 6)]);
        let tensor = Tensor::new_from_map(vec![2, 4, 5], &bond_dims);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_const() {
        let tensor = Tensor::new_from_const(vec![9, 2, 5, 1], 3);
        assert_eq!(tensor.legs(), &[9, 2, 5, 1]);
        assert_eq!(tensor.bond_dims(), &[3, 3, 3, 3]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_total_num_tensors() {
        let leaf = Tensor::new_from_const(vec![0, 1], 2);
        assert_eq!(leaf.total_num_tensors(), 1);

        let inner_1 = Tensor::new_composite(vec![
            Tensor::new_from_const(vec![0, 1], 2),
            Tensor::new_from_const(vec![1, 2], 2),
        ]);
        let inner_2 = Tensor::new_composite(vec![Tensor::new_from_const(vec![2, 3], 2)]);
        let nested = Tensor::new_composite(vec![inner_1, inner_2]);
        assert_eq!(nested.total_num_tensors(), 3);
    }

    #[test]
    fn test_external_tensor() {
        let bond_dims = FxHashMap::from_iter([
            (2, 2),
            (3, 4),
            (4, 6),
            (5, 8),
            (6, 10),
            (7, 12),
            (8, 14),
            (9, 16),
        ]);
        let tensor_1 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![2, 3, 5], &bond_dims);
        let tensor_12 = Tensor::new_composite(vec![tensor_1, tensor_2]);

        let tensor_3 = Tensor::new_from_map(vec![6, 7, 8], &bond_dims);
        let tensor_4 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let tensor_34 = Tensor::new_composite(vec![tensor_3, tensor_4]);

        let tensor_1234 = Tensor::new_composite(vec![tensor_12, tensor_34]);

        let external = tensor_1234.external_tensor();
        assert_eq!(external.legs(), &[4, 5, 7, 9]);
        assert_eq!(external.bond_dims(), &[6, 8, 12, 16]);
    }

    #[test]
    fn test_push_tensor() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let mut tensor = Tensor::default();

        // Push tensor 1
        let tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        tensor.push_tensor(tensor_1);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1]) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }

        // Push tensor 2
        let tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        tensor.push_tensor(tensor_2);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1, &ref_tensor_2]) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }

        // Test that other fields are unchanged
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
        assert!(tensor.legs().is_empty());
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensor_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        leaf_tensor.push_tensor(pushed_tensor);
    }

    #[test]
    fn test_push_tensors() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        let mut tensor = Tensor::default();
        tensor.push_tensors(vec![tensor_1, tensor_2, tensor_3]);

        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);

        for (sub_tensor, ref_tensor) in zip(
            tensor.tensors(),
            &vec![ref_tensor_1, ref_tensor_2, ref_tensor_3],
        ) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensors_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let pushed_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        leaf_tensor.push_tensors(vec![pushed_tensor_1, pushed_tensor_2]);
    }
}