tnc/tensornetwork/tensor.rs
1use std::iter::zip;
2use std::num::TryFromIntError;
3use std::ops::{BitAnd, BitOr, BitXor, BitXorAssign, Sub};
4
5use approx::AbsDiffEq;
6use rustc_hash::FxHashMap;
7use serde::{Deserialize, Serialize};
8
9use crate::tensornetwork::tensordata::TensorData;
10use crate::utils::datastructures::UnionFind;
11
12/// Unique index of a leg.
13pub type EdgeIndex = usize;
14
15/// Index of a tensor in a tensor network.
16pub type TensorIndex = usize;
17
18/// Abstract representation of a tensor. Can be a *composite* tensor (that is, a
19/// tensor network) or a *leaf* tensor.
20#[derive(Default, Debug, Clone, Serialize, Deserialize)]
21pub struct Tensor {
22 /// The inner tensors that make up this tensor. If non-empty, this tensor is
23 /// called a *composite* tensor.
24 pub(crate) tensors: Vec<Tensor>,
25
26 /// The legs of the tensor. Each leg should have a unique id. Connected tensors
27 /// are recognized by having at least one leg id in common.
28 pub(crate) legs: Vec<EdgeIndex>,
29
30 /// The bond dimensions of the legs, same length and order as `legs`. It is
31 /// assumed (but not checked!) that the bond dimensions of different tensors that
32 /// connect to the same leg match.
33 pub(crate) bond_dims: Vec<u64>,
34
35 /// The data of the tensor.
36 pub(crate) tensordata: TensorData,
37}
38
39impl Tensor {
40 /// Constructs a Tensor object with the given `legs` (edge ids) and corresponding
41 /// `bond_dims`. The tensor doesn't have underlying data.
42 #[inline]
43 pub(crate) fn new(legs: Vec<EdgeIndex>, bond_dims: Vec<u64>) -> Self {
44 assert_eq!(legs.len(), bond_dims.len());
45 Self {
46 legs,
47 bond_dims,
48 ..Default::default()
49 }
50 }
51
52 /// Constructs a Tensor using with the given edge ids and a mapping of edge ids
53 /// to corresponding bond dimension.
54 ///
55 /// # Examples
56 /// ```
57 /// # use tnc::tensornetwork::tensor::Tensor;
58 /// # use rustc_hash::FxHashMap;
59 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
60 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
61 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
62 /// assert_eq!(tensor.bond_dims(), &[2, 4, 6]);
63 /// ```
64 #[inline]
65 pub fn new_from_map(legs: Vec<EdgeIndex>, bond_dims_map: &FxHashMap<EdgeIndex, u64>) -> Self {
66 let bond_dims = legs.iter().map(|l| bond_dims_map[l]).collect();
67 Self::new(legs, bond_dims)
68 }
69
70 /// Constructs a Tensor with the given edge ids and the same bond dimension for
71 /// all edges.
72 ///
73 /// # Examples
74 /// ```
75 /// # use tnc::tensornetwork::tensor::Tensor;
76 /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 2);
77 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
78 /// assert_eq!(tensor.bond_dims(), &[2, 2, 2]);
79 /// ```
80 #[inline]
81 pub fn new_from_const(legs: Vec<EdgeIndex>, bond_dim: u64) -> Self {
82 let bond_dims = vec![bond_dim; legs.len()];
83 Self::new(legs, bond_dims)
84 }
85
86 /// Creates a new composite tensor with the given nested tensors.
87 #[inline]
88 pub fn new_composite(tensors: Vec<Self>) -> Self {
89 Self {
90 tensors,
91 ..Default::default()
92 }
93 }
94
95 /// Returns edge ids of Tensor object.
96 ///
97 /// # Examples
98 /// ```
99 /// # use tnc::tensornetwork::tensor::Tensor;
100 /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 3);
101 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
102 /// ```
103 #[inline]
104 pub fn legs(&self) -> &Vec<EdgeIndex> {
105 &self.legs
106 }
107
108 /// Internal method to set legs. Needs pub(crate) for contraction order finding for hierarchies.
109 #[inline]
110 pub(crate) fn set_legs(&mut self, legs: Vec<EdgeIndex>) {
111 self.legs = legs;
112 }
113
114 /// Returns an iterator of tuples of leg ids and their corresponding bond size.
115 #[inline]
116 pub fn edges(&self) -> impl Iterator<Item = (&EdgeIndex, &u64)> + '_ {
117 std::iter::zip(&self.legs, &self.bond_dims)
118 }
119
120 /// Returns the nested tensors of a composite tensor.
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// # use tnc::tensornetwork::tensor::Tensor;
126 /// # use rustc_hash::FxHashMap;
127 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
128 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
129 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
130 /// let tn = Tensor::new_composite(vec![v1.clone(), v2.clone()]);
131 /// for (tensor, ref_tensor) in std::iter::zip(tn.tensors(), vec![v1, v2]){
132 /// assert_eq!(tensor.legs(), ref_tensor.legs());
133 /// }
134 /// ```
135 #[inline]
136 pub fn tensors(&self) -> &Vec<Self> {
137 &self.tensors
138 }
139
140 /// Gets a nested `Tensor` based on the `nested_indices` which specify the index
141 /// of the tensor at each level of the hierarchy.
142 ///
143 /// # Examples
144 /// ```
145 /// # use tnc::tensornetwork::tensor::Tensor;
146 /// # use rustc_hash::FxHashMap;
147 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 2), (4, 1)]);
148 /// let mut v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
149 /// let mut v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
150 /// let mut v3 = Tensor::new_from_map(vec![2, 3], &bond_dims);
151 /// let mut v4 = Tensor::new_from_map(vec![3, 4], &bond_dims);
152 /// let tn1 = Tensor::new_composite(vec![v1, v2]);
153 /// let tn2 = Tensor::new_composite(vec![v3.clone(), v4]);
154 /// let nested_tn = Tensor::new_composite(vec![tn1, tn2]);
155 ///
156 /// assert_eq!(nested_tn.nested_tensor(&[1, 0]).legs(), v3.legs());
157 /// ```
158 pub fn nested_tensor(&self, nested_indices: &[usize]) -> &Tensor {
159 let mut tensor = self;
160 for index in nested_indices {
161 tensor = tensor.tensor(*index);
162 }
163 tensor
164 }
165
166 /// Returns the total number of leaf tensors in the hierarchy.
167 pub fn total_num_tensors(&self) -> usize {
168 if self.is_composite() {
169 self.tensors.iter().map(Self::total_num_tensors).sum()
170 } else {
171 1
172 }
173 }
174
175 /// Get ith Tensor.
176 ///
177 /// # Examples
178 ///
179 /// ```
180 /// # use tnc::tensornetwork::tensor::Tensor;
181 /// # use rustc_hash::FxHashMap;
182 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
183 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
184 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
185 /// let tn = Tensor::new_composite(vec![v1.clone(), v2]);
186 /// assert_eq!(tn.tensor(0).legs(), v1.legs());
187 /// ```
188 #[inline]
189 pub fn tensor(&self, i: TensorIndex) -> &Self {
190 &self.tensors[i]
191 }
192
193 /// Getter for bond dimensions.
194 #[inline]
195 pub fn bond_dims(&self) -> &Vec<u64> {
196 assert!(self.is_leaf());
197 &self.bond_dims
198 }
199
200 /// Returns the shape of tensor. This is the same as the bond dimensions, but as
201 /// `usize`. The conversion can fail, hence a [`Result`] is returned.
202 pub fn shape(&self) -> Result<Vec<usize>, TryFromIntError> {
203 self.bond_dims.iter().map(|&dim| dim.try_into()).collect()
204 }
205
206 /// Returns the number of dimensions.
207 ///
208 /// # Examples
209 /// ```
210 /// # use tnc::tensornetwork::tensor::Tensor;
211 /// # use rustc_hash::FxHashMap;
212 /// let bond_dims = FxHashMap::from_iter([(1, 4), (2, 6), (3, 2)]);
213 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
214 /// assert_eq!(tensor.dims(), 3);
215 /// ```
216 #[inline]
217 pub fn dims(&self) -> usize {
218 self.legs.len()
219 }
220
221 /// Returns the number of elements. This is a f64 to avoid overflow in large
222 /// tensors.
223 ///
224 /// # Examples
225 /// ```
226 /// # use tnc::tensornetwork::tensor::Tensor;
227 /// # use rustc_hash::FxHashMap;
228 /// let bond_dims = FxHashMap::from_iter([(1, 5), (2, 15), (3, 8)]);
229 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
230 /// assert_eq!(tensor.size(), 600.0);
231 /// ```
232 #[inline]
233 pub fn size(&self) -> f64 {
234 self.bond_dims.iter().map(|v| *v as f64).product()
235 }
236
237 /// Returns true if Tensor is a leaf tensor, without any nested tensors.
238 ///
239 /// # Examples
240 /// ```
241 /// # use tnc::tensornetwork::tensor::Tensor;
242 /// # use rustc_hash::FxHashMap;
243 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
244 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
245 /// assert_eq!(tensor.is_leaf(), true);
246 /// let comp = Tensor::new_composite(vec![tensor]);
247 /// assert_eq!(comp.is_leaf(), false);
248 /// ```
249 #[inline]
250 pub fn is_leaf(&self) -> bool {
251 self.tensors.is_empty()
252 }
253
254 /// Returns true if Tensor is composite.
255 ///
256 /// # Examples
257 /// ```
258 /// # use tnc::tensornetwork::tensor::Tensor;
259 /// # use rustc_hash::FxHashMap;
260 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
261 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
262 /// assert_eq!(tensor.is_composite(), false);
263 /// let comp = Tensor::new_composite(vec![tensor]);
264 /// assert_eq!(comp.is_composite(), true);
265 /// ```
266 #[inline]
267 pub fn is_composite(&self) -> bool {
268 !self.tensors.is_empty()
269 }
270
271 /// Returns true if Tensor is empty. This means, it doesn't have any subtensors,
272 /// has no legs and is doesn't have any data (e.g., is not a scalar).
273 ///
274 /// # Examples
275 /// ```
276 /// # use tnc::tensornetwork::tensor::Tensor;
277 /// let tensor = Tensor::default();
278 /// assert_eq!(tensor.is_empty(), true);
279 /// ```
280 #[inline]
281 pub fn is_empty(&self) -> bool {
282 self.tensors.is_empty()
283 && self.legs.is_empty()
284 && matches!(*self.tensor_data(), TensorData::Uncontracted)
285 }
286
287 /// Consumes the `Tensor` and returns its `TensorData`.
288 ///
289 /// # Examples
290 /// ```
291 /// # use tnc::tensornetwork::tensor::Tensor;
292 /// # use tnc::tensornetwork::tensordata::TensorData;
293 /// let tensor = Tensor::new_from_const(vec![1, 2], 2);
294 /// let tensordata = tensor.into_tensor_data();
295 /// assert_eq!(tensordata, TensorData::Uncontracted);
296 /// ```
297 #[inline]
298 pub fn into_tensor_data(self) -> TensorData {
299 self.tensordata
300 }
301
302 /// Pushes additional `tensor` into this tensor, which must be a composite tensor.
303 #[inline]
304 pub fn push_tensor(&mut self, tensor: Self) {
305 assert!(
306 self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
307 "Cannot push tensors into a leaf tensor"
308 );
309 self.tensors.push(tensor);
310 }
311
312 /// Pushes additional `tensors` into this tensor, which must be a composite tensor.
313 #[inline]
314 pub fn push_tensors(&mut self, mut tensors: Vec<Self>) {
315 assert!(
316 self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
317 "Cannot push tensors into a leaf tensor"
318 );
319 self.tensors.append(&mut tensors);
320 }
321
322 /// Getter for tensor data.
323 #[inline]
324 pub fn tensor_data(&self) -> &TensorData {
325 &self.tensordata
326 }
327
328 /// Setter for tensor data.
329 ///
330 /// # Examples
331 ///
332 /// ```
333 /// # use tnc::tensornetwork::tensor::Tensor;
334 /// # use tnc::tensornetwork::tensordata::TensorData;
335 /// let mut tensor = Tensor::new_from_const(vec![0, 1], 2);
336 /// let tensordata = TensorData::Gate((String::from("x"), vec![], false));
337 /// tensor.set_tensor_data(tensordata);
338 /// ```
339 #[inline]
340 pub fn set_tensor_data(&mut self, tensordata: TensorData) {
341 assert!(
342 self.is_leaf() || matches!(tensordata, TensorData::Uncontracted),
343 "Cannot add data to composite tensor"
344 );
345 self.tensordata = tensordata;
346 }
347
348 /// Returns whether all tensors inside this tensor are connected.
349 /// This only checks the top-level, not recursing into composite tensors.
350 ///
351 /// # Examples
352 /// ```
353 /// # use tnc::tensornetwork::tensor::Tensor;
354 /// # use rustc_hash::FxHashMap;
355 /// // Create a tensor network with two connected tensors
356 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 5)]);
357 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
358 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
359 /// let mut tn = Tensor::new_composite(vec![v1, v2]);
360 /// assert!(tn.is_connected());
361 ///
362 /// // Introduce a new tensor that is not connected
363 /// let v3 = Tensor::new_from_map(vec![3], &bond_dims);
364 /// tn.push_tensor(v3);
365 /// assert!(!tn.is_connected());
366 /// ```
367 pub fn is_connected(&self) -> bool {
368 let num_tensors = self.tensors.len();
369 let mut uf = UnionFind::new(num_tensors);
370
371 for t1_id in 0..num_tensors {
372 for t2_id in (t1_id + 1)..num_tensors {
373 let t1 = &self.tensors[t1_id];
374 let t2 = &self.tensors[t2_id];
375 if !(t1 & t2).legs.is_empty() {
376 uf.union(t1_id, t2_id);
377 }
378 }
379 }
380
381 uf.count_sets() == 1
382 }
383
384 /// Returns `Tensor` with legs in `self` that are not in `other`.
385 ///
386 /// # Examples
387 /// ```
388 /// # use tnc::tensornetwork::tensor::Tensor;
389 /// # use rustc_hash::FxHashMap;
390 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
391 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
392 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
393 /// let diff_tensor = &tensor1 - &tensor2;
394 /// assert_eq!(diff_tensor.legs(), &[1, 3]);
395 /// assert_eq!(diff_tensor.bond_dims(), &[2, 6]);
396 /// ```
397 #[must_use]
398 pub fn difference(&self, other: &Self) -> Self {
399 let mut new_legs = Vec::with_capacity(self.legs.len());
400 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
401 for (leg, dim) in self.edges() {
402 if !other.legs.contains(leg) {
403 new_legs.push(*leg);
404 new_bond_dims.push(*dim);
405 }
406 }
407 Self::new(new_legs, new_bond_dims)
408 }
409
410 /// Returns `Tensor` with union of legs in both `self` and `other`.
411 ///
412 /// # Examples
413 /// ```
414 /// # use tnc::tensornetwork::tensor::Tensor;
415 /// # use rustc_hash::FxHashMap;
416 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
417 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
418 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
419 /// let union_tensor = &tensor1 | &tensor2;
420 /// assert_eq!(union_tensor.legs(), &[1, 2, 3, 4, 5]);
421 /// assert_eq!(union_tensor.bond_dims(), &[2, 4, 6, 3, 9]);
422 /// ```
423 #[must_use]
424 pub fn union(&self, other: &Self) -> Self {
425 let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
426 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
427 new_legs.extend_from_slice(&self.legs);
428 new_bond_dims.extend_from_slice(&self.bond_dims);
429 for (leg, dim) in other.edges() {
430 if !self.legs.contains(leg) {
431 new_legs.push(*leg);
432 new_bond_dims.push(*dim);
433 }
434 }
435 Self::new(new_legs, new_bond_dims)
436 }
437
438 /// Returns `Tensor` with intersection of legs in `self` and `other`.
439 ///
440 /// # Examples
441 /// ```
442 /// # use tnc::tensornetwork::tensor::Tensor;
443 /// # use rustc_hash::FxHashMap;
444 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
445 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
446 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
447 /// let intersection_tensor = &tensor1 & &tensor2;
448 /// assert_eq!(intersection_tensor.legs(), &[2]);
449 /// assert_eq!(intersection_tensor.bond_dims(), &[4]);
450 /// ```
451 #[must_use]
452 pub fn intersection(&self, other: &Self) -> Self {
453 let mut new_legs = Vec::with_capacity(self.legs.len().min(other.legs.len()));
454 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
455 for (leg, dim) in self.edges() {
456 if other.legs.contains(leg) {
457 new_legs.push(*leg);
458 new_bond_dims.push(*dim);
459 }
460 }
461 Self::new(new_legs, new_bond_dims)
462 }
463
464 /// Returns `Tensor` with symmetrical difference of legs in `self` and `other`.
465 ///
466 /// # Examples
467 /// ```
468 /// # use tnc::tensornetwork::tensor::Tensor;
469 /// # use rustc_hash::FxHashMap;
470 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
471 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
472 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
473 /// let sym_dif_tensor = &tensor1 ^ &tensor2;
474 /// assert_eq!(sym_dif_tensor.legs(), &[1, 3, 4, 5]);
475 /// assert_eq!(sym_dif_tensor.bond_dims(), &[2, 6, 3, 9]);
476 /// ```
477 #[must_use]
478 pub fn symmetric_difference(&self, other: &Self) -> Self {
479 let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
480 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
481 for (leg, dim) in self.edges() {
482 if !other.legs.contains(leg) {
483 new_legs.push(*leg);
484 new_bond_dims.push(*dim);
485 }
486 }
487 for (leg, dim) in other.edges() {
488 if !self.legs.contains(leg) {
489 new_legs.push(*leg);
490 new_bond_dims.push(*dim);
491 }
492 }
493 Self::new(new_legs, new_bond_dims)
494 }
495
496 /// Get output legs after tensor contraction
497 pub fn external_tensor(&self) -> Tensor {
498 if self.is_leaf() {
499 return self.clone();
500 }
501
502 let mut ext_tensor = Self::default();
503 for tensor in &self.tensors {
504 let new_tensor = if tensor.is_composite() {
505 &tensor.external_tensor()
506 } else {
507 tensor
508 };
509 ext_tensor = &ext_tensor ^ new_tensor;
510 }
511
512 ext_tensor
513 }
514}
515
516impl PartialEq for Tensor {
517 fn eq(&self, other: &Self) -> bool {
518 self.legs == other.legs
519 && self.bond_dims == other.bond_dims
520 && self.tensors == other.tensors
521 && self.tensordata == other.tensordata
522 }
523}
524
525impl AbsDiffEq for Tensor {
526 type Epsilon = f64;
527
528 fn default_epsilon() -> Self::Epsilon {
529 f64::default_epsilon()
530 }
531
532 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
533 if self.legs != other.legs {
534 return false;
535 }
536 if self.bond_dims != other.bond_dims {
537 return false;
538 }
539 if self.tensors.len() != other.tensors.len() {
540 return false;
541 }
542 for (tensor, other_tensor) in zip(&self.tensors, &other.tensors) {
543 if !Tensor::abs_diff_eq(tensor, other_tensor, epsilon) {
544 return false;
545 }
546 }
547
548 TensorData::abs_diff_eq(&self.tensordata, &other.tensordata, epsilon)
549 }
550}
551
552impl BitOr for &Tensor {
553 type Output = Tensor;
554 #[inline]
555 fn bitor(self, rhs: &Tensor) -> Tensor {
556 self.union(rhs)
557 }
558}
559
560impl BitAnd for &Tensor {
561 type Output = Tensor;
562 #[inline]
563 fn bitand(self, rhs: &Tensor) -> Tensor {
564 self.intersection(rhs)
565 }
566}
567
568impl BitXor for &Tensor {
569 type Output = Tensor;
570 #[inline]
571 fn bitxor(self, rhs: &Tensor) -> Tensor {
572 self.symmetric_difference(rhs)
573 }
574}
575
576impl Sub for &Tensor {
577 type Output = Tensor;
578 #[inline]
579 fn sub(self, rhs: &Tensor) -> Tensor {
580 self.difference(rhs)
581 }
582}
583
584impl BitXorAssign<&Tensor> for Tensor {
585 #[inline]
586 fn bitxor_assign(&mut self, rhs: &Tensor) {
587 *self = self.symmetric_difference(rhs);
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use approx::assert_abs_diff_eq;
    use rustc_hash::FxHashMap;

    use crate::tensornetwork::tensordata::TensorData;

    // Asserts that `$left` matches `$pattern`. The expression is evaluated
    // exactly once, so consuming expressions (e.g. `tensor.into_tensor_data()`)
    // can be used as well.
    macro_rules! assert_matches {
        ($left:expr, $pattern:pat) => {
            match $left {
                $pattern => (),
                ref other => panic!(
                    "Expected pattern {} but got {:?}",
                    stringify!($pattern),
                    other
                ),
            }
        };
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = Tensor::default();
        assert!(tensor.tensors.is_empty());
        assert!(tensor.legs.is_empty());
        assert!(tensor.bond_dims.is_empty());
        assert!(tensor.is_empty());
    }

    #[test]
    fn test_new() {
        let tensor = Tensor::new(vec![2, 4, 5], vec![4, 2, 6]);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_map() {
        let bond_dims = FxHashMap::from_iter([(1, 1), (2, 4), (3, 7), (4, 2), (5, 6)]);
        let tensor = Tensor::new_from_map(vec![2, 4, 5], &bond_dims);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_const() {
        let tensor = Tensor::new_from_const(vec![9, 2, 5, 1], 3);
        assert_eq!(tensor.legs(), &[9, 2, 5, 1]);
        assert_eq!(tensor.bond_dims(), &[3, 3, 3, 3]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_into_tensor_data_matches() {
        // Consuming expression: only valid because the macro evaluates once.
        let tensor = Tensor::new_from_const(vec![1, 2], 2);
        assert_matches!(tensor.into_tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_external_tensor() {
        let bond_dims = FxHashMap::from_iter([
            (2, 2),
            (3, 4),
            (4, 6),
            (5, 8),
            (6, 10),
            (7, 12),
            (8, 14),
            (9, 16),
        ]);
        let tensor_1 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![2, 3, 5], &bond_dims);
        let tensor_12 = Tensor::new_composite(vec![tensor_1, tensor_2]);

        let tensor_3 = Tensor::new_from_map(vec![6, 7, 8], &bond_dims);
        let tensor_4 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let tensor_34 = Tensor::new_composite(vec![tensor_3, tensor_4]);

        let tensor_1234 = Tensor::new_composite(vec![tensor_12, tensor_34]);

        let external = tensor_1234.external_tensor();

        let ref_tensor = Tensor::new_from_map(vec![4, 5, 7, 9], &bond_dims);
        assert_abs_diff_eq!(&external, &ref_tensor);
    }

    #[test]
    fn test_external_tensor_of_leaf() {
        let leaf = Tensor::new_from_const(vec![1, 2], 3);
        assert_eq!(leaf.external_tensor(), leaf);
    }

    #[test]
    fn test_total_num_tensors_and_nested_tensor() {
        let leaf = Tensor::new_from_const(vec![0], 2);
        assert_eq!(leaf.total_num_tensors(), 1);

        let inner = Tensor::new_composite(vec![leaf.clone(), leaf.clone()]);
        let outer = Tensor::new_composite(vec![inner, leaf.clone()]);
        assert_eq!(outer.total_num_tensors(), 3);
        assert_eq!(outer.nested_tensor(&[0, 1]), &leaf);
    }

    #[test]
    fn test_is_connected_chain() {
        // The first two tensors are only connected through the third one.
        let tn = Tensor::new_composite(vec![
            Tensor::new_from_const(vec![0, 1], 2),
            Tensor::new_from_const(vec![2, 3], 2),
            Tensor::new_from_const(vec![1, 2], 2),
        ]);
        assert!(tn.is_connected());
    }

    #[test]
    fn test_is_connected_hyperedge() {
        let tn = Tensor::new_composite(vec![
            Tensor::new_from_const(vec![0, 5], 2),
            Tensor::new_from_const(vec![1, 5], 2),
            Tensor::new_from_const(vec![5, 2], 2),
        ]);
        assert!(tn.is_connected());
    }

    #[test]
    fn test_is_connected_two_components() {
        let tn = Tensor::new_composite(vec![
            Tensor::new_from_const(vec![0, 1], 2),
            Tensor::new_from_const(vec![1, 2], 2),
            Tensor::new_from_const(vec![3, 4], 2),
            Tensor::new_from_const(vec![4, 5], 2),
        ]);
        assert!(!tn.is_connected());
    }

    #[test]
    fn test_set_operations_with_empty() {
        let tensor = Tensor::new_from_const(vec![3, 1, 2], 4);
        let empty = Tensor::default();

        assert_eq!((&tensor | &empty).legs(), &[3, 1, 2]);
        assert_eq!((&empty | &tensor).legs(), &[3, 1, 2]);
        assert!((&tensor & &empty).legs().is_empty());
        assert_eq!((&tensor - &empty).legs(), &[3, 1, 2]);
        assert!((&empty - &tensor).legs().is_empty());

        let sym = &tensor ^ &empty;
        assert_eq!(sym.legs(), &[3, 1, 2]);
        assert_eq!(sym.bond_dims(), &[4, 4, 4]);
    }

    #[test]
    fn test_push_tensor() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let mut tensor = Tensor::default();

        // Push tensor 1
        let tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        tensor.push_tensor(tensor_1);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1]) {
            assert_abs_diff_eq!(sub_tensor, ref_tensor);
        }

        // Push tensor 2
        let tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        tensor.push_tensor(tensor_2);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1, &ref_tensor_2]) {
            assert_abs_diff_eq!(sub_tensor, ref_tensor);
        }

        // Test that other fields are unchanged
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
        assert!(tensor.legs().is_empty());
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensor_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        leaf_tensor.push_tensor(pushed_tensor);
    }

    #[test]
    fn test_push_tensors() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        let mut tensor = Tensor::default();
        tensor.push_tensors(vec![tensor_1, tensor_2, tensor_3]);

        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);

        for (sub_tensor, ref_tensor) in zip(
            tensor.tensors(),
            &vec![ref_tensor_1, ref_tensor_2, ref_tensor_3],
        ) {
            assert_abs_diff_eq!(sub_tensor, ref_tensor);
        }
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensors_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let pushed_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        leaf_tensor.push_tensors(vec![pushed_tensor_1, pushed_tensor_2]);
    }
}