tnc/tensornetwork/tensor.rs
1use std::hash::{Hash, Hasher};
2use std::iter::zip;
3use std::num::TryFromIntError;
4use std::ops::{BitAnd, BitOr, BitXor, BitXorAssign, Sub};
5
6use float_cmp::{ApproxEq, F64Margin};
7use rustc_hash::FxHashMap;
8use serde::{Deserialize, Serialize};
9
10use crate::tensornetwork::tensordata::TensorData;
11use crate::utils::datastructures::UnionFind;
12
/// Unique id of a leg (edge). Tensors that share a leg id are connected
/// through that edge.
pub type EdgeIndex = usize;

/// Position of a tensor among the nested tensors of a composite tensor.
pub type TensorIndex = usize;
18
/// Abstract representation of a tensor. It is either a *composite* tensor
/// (that is, a tensor network made of nested tensors) or a *leaf* tensor.
// NOTE: field order is significant for serde (de)serialization; do not reorder.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Tensor {
    /// The inner tensors that make up this tensor. If non-empty, this tensor is
    /// called a *composite* tensor; otherwise it is a *leaf* tensor.
    pub(crate) tensors: Vec<Tensor>,

    /// The legs (edge ids) of the tensor. Each leg should have a unique id.
    /// Connected tensors are recognized by having at least one leg id in common.
    pub(crate) legs: Vec<EdgeIndex>,

    /// The bond dimensions of the legs, in the same order as `legs` and with the
    /// same length. It is assumed (but not checked!) that the bond dimensions of
    /// different tensors that connect to the same leg match.
    pub(crate) bond_dims: Vec<u64>,

    /// The data of the tensor. Tensors created without data hold
    /// `TensorData::Uncontracted`, which is what `Default` produces here.
    pub(crate) tensordata: TensorData,
}
39
40impl Hash for Tensor {
41 fn hash<H: Hasher>(&self, state: &mut H) {
42 self.legs.hash(state);
43 }
44}
45
46impl Tensor {
47 /// Constructs a Tensor object with the given `legs` (edge ids) and corresponding
48 /// `bond_dims`. The tensor doesn't have underlying data.
49 #[inline]
50 pub(crate) fn new(legs: Vec<EdgeIndex>, bond_dims: Vec<u64>) -> Self {
51 assert_eq!(legs.len(), bond_dims.len());
52 Self {
53 legs,
54 bond_dims,
55 ..Default::default()
56 }
57 }
58
59 /// Constructs a Tensor using with the given edge ids and a mapping of edge ids
60 /// to corresponding bond dimension.
61 ///
62 /// # Examples
63 /// ```
64 /// # use tnc::tensornetwork::tensor::Tensor;
65 /// # use rustc_hash::FxHashMap;
66 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
67 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
68 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
69 /// assert_eq!(tensor.bond_dims(), &[2, 4, 6]);
70 /// ```
71 #[inline]
72 pub fn new_from_map(legs: Vec<EdgeIndex>, bond_dims_map: &FxHashMap<EdgeIndex, u64>) -> Self {
73 let bond_dims = legs.iter().map(|l| bond_dims_map[l]).collect();
74 Self::new(legs, bond_dims)
75 }
76
77 /// Constructs a Tensor with the given edge ids and the same bond dimension for
78 /// all edges.
79 ///
80 /// # Examples
81 /// ```
82 /// # use tnc::tensornetwork::tensor::Tensor;
83 /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 2);
84 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
85 /// assert_eq!(tensor.bond_dims(), &[2, 2, 2]);
86 /// ```
87 #[inline]
88 pub fn new_from_const(legs: Vec<EdgeIndex>, bond_dim: u64) -> Self {
89 let bond_dims = vec![bond_dim; legs.len()];
90 Self::new(legs, bond_dims)
91 }
92
93 /// Creates a new composite tensor with the given nested tensors.
94 #[inline]
95 pub fn new_composite(tensors: Vec<Self>) -> Self {
96 Self {
97 tensors,
98 ..Default::default()
99 }
100 }
101
102 /// Returns edge ids of Tensor object.
103 ///
104 /// # Examples
105 /// ```
106 /// # use tnc::tensornetwork::tensor::Tensor;
107 /// let tensor = Tensor::new_from_const(vec![1, 2, 3], 3);
108 /// assert_eq!(tensor.legs(), &[1, 2, 3]);
109 /// ```
110 #[inline]
111 pub fn legs(&self) -> &Vec<EdgeIndex> {
112 &self.legs
113 }
114
115 /// Internal method to set legs. Needs pub(crate) for contraction order finding for hierarchies.
116 #[inline]
117 pub(crate) fn set_legs(&mut self, legs: Vec<EdgeIndex>) {
118 self.legs = legs;
119 }
120
121 /// Returns an iterator of tuples of leg ids and their corresponding bond size.
122 #[inline]
123 pub fn edges(&self) -> impl Iterator<Item = (&EdgeIndex, &u64)> + '_ {
124 std::iter::zip(&self.legs, &self.bond_dims)
125 }
126
127 /// Returns the nested tensors of a composite tensor.
128 ///
129 /// # Examples
130 ///
131 /// ```
132 /// # use tnc::tensornetwork::tensor::Tensor;
133 /// # use rustc_hash::FxHashMap;
134 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
135 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
136 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
137 /// let tn = Tensor::new_composite(vec![v1.clone(), v2.clone()]);
138 /// for (tensor, ref_tensor) in std::iter::zip(tn.tensors(), vec![v1, v2]){
139 /// assert_eq!(tensor.legs(), ref_tensor.legs());
140 /// }
141 /// ```
142 #[inline]
143 pub fn tensors(&self) -> &Vec<Self> {
144 &self.tensors
145 }
146
147 /// Gets a nested `Tensor` based on the `nested_indices` which specify the index
148 /// of the tensor at each level of the hierarchy.
149 ///
150 /// # Examples
151 /// ```
152 /// # use tnc::tensornetwork::tensor::Tensor;
153 /// # use rustc_hash::FxHashMap;
154 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 2), (4, 1)]);
155 /// let mut v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
156 /// let mut v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
157 /// let mut v3 = Tensor::new_from_map(vec![2, 3], &bond_dims);
158 /// let mut v4 = Tensor::new_from_map(vec![3, 4], &bond_dims);
159 /// let tn1 = Tensor::new_composite(vec![v1, v2]);
160 /// let tn2 = Tensor::new_composite(vec![v3.clone(), v4]);
161 /// let nested_tn = Tensor::new_composite(vec![tn1, tn2]);
162 ///
163 /// assert_eq!(nested_tn.nested_tensor(&[1, 0]).legs(), v3.legs());
164 /// ```
165 pub fn nested_tensor(&self, nested_indices: &[usize]) -> &Tensor {
166 let mut tensor = self;
167 for index in nested_indices {
168 tensor = tensor.tensor(*index);
169 }
170 tensor
171 }
172
173 /// Returns the total number of leaf tensors in the hierarchy.
174 pub fn total_num_tensors(&self) -> usize {
175 if self.is_composite() {
176 self.tensors.iter().map(Self::total_num_tensors).sum()
177 } else {
178 1
179 }
180 }
181
182 /// Get ith Tensor.
183 ///
184 /// # Examples
185 ///
186 /// ```
187 /// # use tnc::tensornetwork::tensor::Tensor;
188 /// # use rustc_hash::FxHashMap;
189 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8)]);
190 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
191 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
192 /// let tn = Tensor::new_composite(vec![v1.clone(), v2]);
193 /// assert_eq!(tn.tensor(0).legs(), v1.legs());
194 /// ```
195 #[inline]
196 pub fn tensor(&self, i: TensorIndex) -> &Self {
197 &self.tensors[i]
198 }
199
200 /// Getter for bond dimensions.
201 #[inline]
202 pub fn bond_dims(&self) -> &Vec<u64> {
203 assert!(self.is_leaf());
204 &self.bond_dims
205 }
206
207 /// Returns the shape of tensor. This is the same as the bond dimensions, but as
208 /// `usize`. The conversion can fail, hence a [`Result`] is returned.
209 pub fn shape(&self) -> Result<Vec<usize>, TryFromIntError> {
210 self.bond_dims.iter().map(|&dim| dim.try_into()).collect()
211 }
212
213 /// Returns the number of dimensions.
214 ///
215 /// # Examples
216 /// ```
217 /// # use tnc::tensornetwork::tensor::Tensor;
218 /// # use rustc_hash::FxHashMap;
219 /// let bond_dims = FxHashMap::from_iter([(1, 4), (2, 6), (3, 2)]);
220 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
221 /// assert_eq!(tensor.dims(), 3);
222 /// ```
223 #[inline]
224 pub fn dims(&self) -> usize {
225 self.legs.len()
226 }
227
228 /// Returns the number of elements. This is a f64 to avoid overflow in large
229 /// tensors.
230 ///
231 /// # Examples
232 /// ```
233 /// # use tnc::tensornetwork::tensor::Tensor;
234 /// # use rustc_hash::FxHashMap;
235 /// let bond_dims = FxHashMap::from_iter([(1, 5), (2, 15), (3, 8)]);
236 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
237 /// assert_eq!(tensor.size(), 600.0);
238 /// ```
239 #[inline]
240 pub fn size(&self) -> f64 {
241 self.bond_dims.iter().map(|v| *v as f64).product()
242 }
243
244 /// Returns true if Tensor is a leaf tensor, without any nested tensors.
245 ///
246 /// # Examples
247 /// ```
248 /// # use tnc::tensornetwork::tensor::Tensor;
249 /// # use rustc_hash::FxHashMap;
250 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
251 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
252 /// assert_eq!(tensor.is_leaf(), true);
253 /// let comp = Tensor::new_composite(vec![tensor]);
254 /// assert_eq!(comp.is_leaf(), false);
255 /// ```
256 #[inline]
257 pub fn is_leaf(&self) -> bool {
258 self.tensors.is_empty()
259 }
260
261 /// Returns true if Tensor is composite.
262 ///
263 /// # Examples
264 /// ```
265 /// # use tnc::tensornetwork::tensor::Tensor;
266 /// # use rustc_hash::FxHashMap;
267 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6)]);
268 /// let tensor = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
269 /// assert_eq!(tensor.is_composite(), false);
270 /// let comp = Tensor::new_composite(vec![tensor]);
271 /// assert_eq!(comp.is_composite(), true);
272 /// ```
273 #[inline]
274 pub fn is_composite(&self) -> bool {
275 !self.tensors.is_empty()
276 }
277
278 /// Returns true if Tensor is empty. This means, it doesn't have any subtensors,
279 /// has no legs and is doesn't have any data (e.g., is not a scalar).
280 ///
281 /// # Examples
282 /// ```
283 /// # use tnc::tensornetwork::tensor::Tensor;
284 /// let tensor = Tensor::default();
285 /// assert_eq!(tensor.is_empty(), true);
286 /// ```
287 #[inline]
288 pub fn is_empty(&self) -> bool {
289 self.tensors.is_empty()
290 && self.legs.is_empty()
291 && matches!(*self.tensor_data(), TensorData::Uncontracted)
292 }
293
294 /// Pushes additional `tensor` into this tensor, which must be a composite tensor.
295 #[inline]
296 pub fn push_tensor(&mut self, tensor: Self) {
297 assert!(
298 self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
299 "Cannot push tensors into a leaf tensor"
300 );
301 self.tensors.push(tensor);
302 }
303
304 /// Pushes additional `tensors` into this tensor, which must be a composite tensor.
305 #[inline]
306 pub fn push_tensors(&mut self, mut tensors: Vec<Self>) {
307 assert!(
308 self.legs.is_empty() && matches!(self.tensordata, TensorData::Uncontracted),
309 "Cannot push tensors into a leaf tensor"
310 );
311 self.tensors.append(&mut tensors);
312 }
313
314 /// Getter for tensor data.
315 #[inline]
316 pub fn tensor_data(&self) -> &TensorData {
317 &self.tensordata
318 }
319
320 /// Setter for tensor data.
321 ///
322 /// # Examples
323 ///
324 /// ```
325 /// # use tnc::tensornetwork::tensor::Tensor;
326 /// # use tnc::tensornetwork::tensordata::TensorData;
327 /// let mut tensor = Tensor::new_from_const(vec![0, 1], 2);
328 /// let tensordata = TensorData::Gate((String::from("x"), vec![], false));
329 /// tensor.set_tensor_data(tensordata);
330 /// ```
331 #[inline]
332 pub fn set_tensor_data(&mut self, tensordata: TensorData) {
333 assert!(
334 self.is_leaf() || matches!(tensordata, TensorData::Uncontracted),
335 "Cannot add data to composite tensor"
336 );
337 self.tensordata = tensordata;
338 }
339
340 /// Returns whether all tensors inside this tensor are connected.
341 /// This only checks the top-level, not recursing into composite tensors.
342 ///
343 /// # Examples
344 /// ```
345 /// # use tnc::tensornetwork::tensor::Tensor;
346 /// # use rustc_hash::FxHashMap;
347 /// // Create a tensor network with two connected tensors
348 /// let bond_dims = FxHashMap::from_iter([(0, 17), (1, 19), (2, 8), (3, 5)]);
349 /// let v1 = Tensor::new_from_map(vec![0, 1], &bond_dims);
350 /// let v2 = Tensor::new_from_map(vec![1, 2], &bond_dims);
351 /// let mut tn = Tensor::new_composite(vec![v1, v2]);
352 /// assert!(tn.is_connected());
353 ///
354 /// // Introduce a new tensor that is not connected
355 /// let v3 = Tensor::new_from_map(vec![3], &bond_dims);
356 /// tn.push_tensor(v3);
357 /// assert!(!tn.is_connected());
358 /// ```
359 pub fn is_connected(&self) -> bool {
360 let num_tensors = self.tensors.len();
361 let mut uf = UnionFind::new(num_tensors);
362
363 for t1_id in 0..num_tensors - 1 {
364 for t2_id in (t1_id + 1)..num_tensors {
365 let t1 = &self.tensors[t1_id];
366 let t2 = &self.tensors[t2_id];
367 if !(t1 & t2).legs.is_empty() {
368 uf.union(t1_id, t2_id);
369 }
370 }
371 }
372
373 uf.count_sets() == 1
374 }
375
376 /// Returns `Tensor` with legs in `self` that are not in `other`.
377 ///
378 /// # Examples
379 /// ```
380 /// # use tnc::tensornetwork::tensor::Tensor;
381 /// # use rustc_hash::FxHashMap;
382 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
383 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
384 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
385 /// let diff_tensor = &tensor1 - &tensor2;
386 /// assert_eq!(diff_tensor.legs(), &[1, 3]);
387 /// assert_eq!(diff_tensor.bond_dims(), &[2, 6]);
388 /// ```
389 #[must_use]
390 pub fn difference(&self, other: &Self) -> Self {
391 let mut new_legs = Vec::with_capacity(self.legs.len());
392 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
393 for (leg, dim) in self.edges() {
394 if !other.legs.contains(leg) {
395 new_legs.push(*leg);
396 new_bond_dims.push(*dim);
397 }
398 }
399 Self::new(new_legs, new_bond_dims)
400 }
401
402 /// Returns `Tensor` with union of legs in both `self` and `other`.
403 ///
404 /// # Examples
405 /// ```
406 /// # use tnc::tensornetwork::tensor::Tensor;
407 /// # use rustc_hash::FxHashMap;
408 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
409 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
410 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
411 /// let union_tensor = &tensor1 | &tensor2;
412 /// assert_eq!(union_tensor.legs(), &[1, 2, 3, 4, 5]);
413 /// assert_eq!(union_tensor.bond_dims(), &[2, 4, 6, 3, 9]);
414 /// ```
415 #[must_use]
416 pub fn union(&self, other: &Self) -> Self {
417 let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
418 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
419 new_legs.extend_from_slice(&self.legs);
420 new_bond_dims.extend_from_slice(&self.bond_dims);
421 for (leg, dim) in other.edges() {
422 if !self.legs.contains(leg) {
423 new_legs.push(*leg);
424 new_bond_dims.push(*dim);
425 }
426 }
427 Self::new(new_legs, new_bond_dims)
428 }
429
430 /// Returns `Tensor` with intersection of legs in `self` and `other`.
431 ///
432 /// # Examples
433 /// ```
434 /// # use tnc::tensornetwork::tensor::Tensor;
435 /// # use rustc_hash::FxHashMap;
436 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
437 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
438 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
439 /// let intersection_tensor = &tensor1 & &tensor2;
440 /// assert_eq!(intersection_tensor.legs(), &[2]);
441 /// assert_eq!(intersection_tensor.bond_dims(), &[4]);
442 /// ```
443 #[must_use]
444 pub fn intersection(&self, other: &Self) -> Self {
445 let mut new_legs = Vec::with_capacity(self.legs.len().min(other.legs.len()));
446 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
447 for (leg, dim) in self.edges() {
448 if other.legs.contains(leg) {
449 new_legs.push(*leg);
450 new_bond_dims.push(*dim);
451 }
452 }
453 Self::new(new_legs, new_bond_dims)
454 }
455
456 /// Returns `Tensor` with symmetrical difference of legs in `self` and `other`.
457 ///
458 /// # Examples
459 /// ```
460 /// # use tnc::tensornetwork::tensor::Tensor;
461 /// # use rustc_hash::FxHashMap;
462 /// let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
463 /// let tensor1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
464 /// let tensor2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);
465 /// let sym_dif_tensor = &tensor1 ^ &tensor2;
466 /// assert_eq!(sym_dif_tensor.legs(), &[1, 3, 4, 5]);
467 /// assert_eq!(sym_dif_tensor.bond_dims(), &[2, 6, 3, 9]);
468 /// ```
469 #[must_use]
470 pub fn symmetric_difference(&self, other: &Self) -> Self {
471 let mut new_legs = Vec::with_capacity(self.legs.len() + other.legs.len());
472 let mut new_bond_dims = Vec::with_capacity(new_legs.capacity());
473 for (leg, dim) in self.edges() {
474 if !other.legs.contains(leg) {
475 new_legs.push(*leg);
476 new_bond_dims.push(*dim);
477 }
478 }
479 for (leg, dim) in other.edges() {
480 if !self.legs.contains(leg) {
481 new_legs.push(*leg);
482 new_bond_dims.push(*dim);
483 }
484 }
485 Self::new(new_legs, new_bond_dims)
486 }
487
488 /// Get output legs after tensor contraction
489 pub fn external_tensor(&self) -> Tensor {
490 if self.is_leaf() {
491 return self.clone();
492 }
493
494 let mut ext_tensor = Self::default();
495 for tensor in &self.tensors {
496 let new_tensor = if tensor.is_composite() {
497 &tensor.external_tensor()
498 } else {
499 tensor
500 };
501 ext_tensor = &ext_tensor ^ new_tensor;
502 }
503
504 ext_tensor
505 }
506}
507
508impl ApproxEq for &Tensor {
509 type Margin = F64Margin;
510
511 fn approx_eq<M: Into<Self::Margin>>(self, other: Self, margin: M) -> bool {
512 let margin = margin.into();
513 if self.legs != other.legs {
514 return false;
515 }
516 if self.bond_dims != other.bond_dims {
517 return false;
518 }
519 if self.tensors.len() != other.tensors.len() {
520 return false;
521 }
522 for (tensor, other_tensor) in zip(&self.tensors, &other.tensors) {
523 if !tensor.approx_eq(other_tensor, margin) {
524 return false;
525 }
526 }
527
528 self.tensordata.approx_eq(&other.tensordata, margin)
529 }
530}
531
532impl BitOr for &Tensor {
533 type Output = Tensor;
534 #[inline]
535 fn bitor(self, rhs: &Tensor) -> Tensor {
536 self.union(rhs)
537 }
538}
539
540impl BitAnd for &Tensor {
541 type Output = Tensor;
542 #[inline]
543 fn bitand(self, rhs: &Tensor) -> Tensor {
544 self.intersection(rhs)
545 }
546}
547
548impl BitXor for &Tensor {
549 type Output = Tensor;
550 #[inline]
551 fn bitxor(self, rhs: &Tensor) -> Tensor {
552 self.symmetric_difference(rhs)
553 }
554}
555
556impl Sub for &Tensor {
557 type Output = Tensor;
558 #[inline]
559 fn sub(self, rhs: &Tensor) -> Tensor {
560 self.difference(rhs)
561 }
562}
563
564impl BitXorAssign<&Tensor> for Tensor {
565 #[inline]
566 fn bitxor_assign(&mut self, rhs: &Tensor) {
567 *self = self.symmetric_difference(rhs);
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use rustc_hash::FxHashMap;

    use crate::tensornetwork::tensordata::TensorData;

    /// Asserts that `$left` matches `$pattern`, printing the value otherwise.
    macro_rules! assert_matches {
        ($left:expr, $pattern:pat) => {
            match $left {
                $pattern => (),
                _ => panic!(
                    "Expected pattern {} but got {:?}",
                    stringify!($pattern),
                    $left
                ),
            }
        };
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = Tensor::default();
        assert!(tensor.tensors.is_empty());
        assert!(tensor.legs.is_empty());
        assert!(tensor.bond_dims.is_empty());
        assert!(tensor.is_empty());
    }

    #[test]
    fn test_new() {
        let tensor = Tensor::new(vec![2, 4, 5], vec![4, 2, 6]);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_map() {
        let bond_dims = FxHashMap::from_iter([(1, 1), (2, 4), (3, 7), (4, 2), (5, 6)]);
        let tensor = Tensor::new_from_map(vec![2, 4, 5], &bond_dims);
        assert_eq!(tensor.legs(), &[2, 4, 5]);
        assert_eq!(tensor.bond_dims(), &[4, 2, 6]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_new_from_const() {
        let tensor = Tensor::new_from_const(vec![9, 2, 5, 1], 3);
        assert_eq!(tensor.legs(), &[9, 2, 5, 1]);
        assert_eq!(tensor.bond_dims(), &[3, 3, 3, 3]);
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
    }

    #[test]
    fn test_external_tensor() {
        let bond_dims = FxHashMap::from_iter([
            (2, 2),
            (3, 4),
            (4, 6),
            (5, 8),
            (6, 10),
            (7, 12),
            (8, 14),
            (9, 16),
        ]);
        let tensor_1 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![2, 3, 5], &bond_dims);
        let tensor_12 = Tensor::new_composite(vec![tensor_1, tensor_2]);

        let tensor_3 = Tensor::new_from_map(vec![6, 7, 8], &bond_dims);
        let tensor_4 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let tensor_34 = Tensor::new_composite(vec![tensor_3, tensor_4]);

        let tensor_1234 = Tensor::new_composite(vec![tensor_12, tensor_34]);

        let external = tensor_1234.external_tensor();
        assert_eq!(external.legs(), &[4, 5, 7, 9]);
        assert_eq!(external.bond_dims(), &[6, 8, 12, 16]);
    }

    #[test]
    fn test_external_tensor_of_leaf_is_clone() {
        let tensor = Tensor::new(vec![3, 1], vec![2, 5]);
        let external = tensor.external_tensor();
        assert_eq!(external.legs(), tensor.legs());
        assert_eq!(external.bond_dims(), tensor.bond_dims());
    }

    #[test]
    fn test_push_tensor() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let mut tensor = Tensor::default();

        // Push tensor 1
        let tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        tensor.push_tensor(tensor_1);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1]) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }

        // Push tensor 2
        let tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        tensor.push_tensor(tensor_2);

        for (sub_tensor, ref_tensor) in zip(tensor.tensors(), [&ref_tensor_1, &ref_tensor_2]) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }

        // Test that other fields are unchanged
        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);
        assert!(tensor.legs().is_empty());
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensor_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        leaf_tensor.push_tensor(pushed_tensor);
    }

    #[test]
    fn test_push_tensors() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let ref_tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let ref_tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let ref_tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        let tensor_1 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let tensor_3 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);
        let mut tensor = Tensor::default();
        tensor.push_tensors(vec![tensor_1, tensor_2, tensor_3]);

        assert_matches!(tensor.tensor_data(), TensorData::Uncontracted);

        for (sub_tensor, ref_tensor) in zip(
            tensor.tensors(),
            &vec![ref_tensor_1, ref_tensor_2, ref_tensor_3],
        ) {
            assert_eq!(sub_tensor.legs(), ref_tensor.legs());
            assert_eq!(sub_tensor.bond_dims(), ref_tensor.bond_dims());
        }
    }

    #[test]
    #[should_panic(expected = "Cannot push tensors into a leaf tensor")]
    fn test_push_tensors_to_leaf() {
        let bond_dims =
            FxHashMap::from_iter([(2, 17), (3, 1), (4, 11), (8, 3), (9, 20), (7, 7), (10, 14)]);
        let mut leaf_tensor = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let pushed_tensor_1 = Tensor::new_from_map(vec![8, 4, 9], &bond_dims);
        let pushed_tensor_2 = Tensor::new_from_map(vec![7, 10, 2], &bond_dims);

        leaf_tensor.push_tensors(vec![pushed_tensor_1, pushed_tensor_2]);
    }

    #[test]
    fn test_is_connected_transitive() {
        // tensor_a and tensor_c share no leg; they are only linked via tensor_b.
        let tensor_a = Tensor::new_from_const(vec![0, 1], 2);
        let tensor_c = Tensor::new_from_const(vec![2, 3], 2);
        let tensor_b = Tensor::new_from_const(vec![1, 2], 2);
        let mut tn = Tensor::new_composite(vec![tensor_a, tensor_c, tensor_b]);
        assert!(tn.is_connected());

        tn.push_tensor(Tensor::new_from_const(vec![4], 2));
        assert!(!tn.is_connected());
    }

    #[test]
    fn test_set_operators() {
        let bond_dims = FxHashMap::from_iter([(1, 2), (2, 4), (3, 6), (4, 3), (5, 9)]);
        let tensor_1 = Tensor::new_from_map(vec![1, 2, 3], &bond_dims);
        let tensor_2 = Tensor::new_from_map(vec![4, 2, 5], &bond_dims);

        let difference = &tensor_1 - &tensor_2;
        assert_eq!(difference.legs(), &[1, 3]);
        assert_eq!(difference.bond_dims(), &[2, 6]);

        let union = &tensor_1 | &tensor_2;
        assert_eq!(union.legs(), &[1, 2, 3, 4, 5]);
        assert_eq!(union.bond_dims(), &[2, 4, 6, 3, 9]);

        let intersection = &tensor_1 & &tensor_2;
        assert_eq!(intersection.legs(), &[2]);
        assert_eq!(intersection.bond_dims(), &[4]);

        let symmetric_difference = &tensor_1 ^ &tensor_2;
        assert_eq!(symmetric_difference.legs(), &[1, 3, 4, 5]);
        assert_eq!(symmetric_difference.bond_dims(), &[2, 6, 3, 9]);

        let mut accumulated = tensor_1.clone();
        accumulated ^= &tensor_2;
        assert_eq!(accumulated.legs(), &[1, 3, 4, 5]);
        assert_eq!(accumulated.bond_dims(), &[2, 6, 3, 9]);
    }

    #[test]
    fn test_nested_access() {
        let tensor_1 = Tensor::new_from_const(vec![0, 1], 2);
        let tensor_2 = Tensor::new_from_const(vec![1, 2], 2);
        let tensor_3 = Tensor::new_from_const(vec![2, 3], 2);
        let inner = Tensor::new_composite(vec![tensor_2, tensor_3.clone()]);
        let outer = Tensor::new_composite(vec![tensor_1.clone(), inner]);

        assert_eq!(outer.total_num_tensors(), 3);
        assert_eq!(outer.nested_tensor(&[]).tensors().len(), 2);
        assert_eq!(outer.nested_tensor(&[0]).legs(), tensor_1.legs());
        assert_eq!(outer.nested_tensor(&[1, 1]).legs(), tensor_3.legs());
    }

    #[test]
    fn test_shape_and_size() {
        let tensor = Tensor::new(vec![1, 2, 3], vec![5, 15, 8]);
        assert_eq!(tensor.shape().unwrap(), vec![5, 15, 8]);
        assert_eq!(tensor.size(), 600.0);
        assert_eq!(tensor.dims(), 3);

        let scalar_like = Tensor::new(vec![], vec![]);
        assert_eq!(scalar_like.size(), 1.0);
        assert_eq!(scalar_like.dims(), 0);
    }

    #[test]
    #[should_panic]
    fn test_bond_dims_of_composite_panics() {
        let composite = Tensor::new_composite(vec![Tensor::new_from_const(vec![0], 2)]);
        let _ = composite.bond_dims();
    }
}