1use std::cell::Ref;
2use std::rc::Rc;
3
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::contractionpath::contraction_tree::node::{
8 child_node, parent_node, Node, NodeRef, WeakNodeRef,
9};
10use crate::contractionpath::paths::validate_path;
11use crate::contractionpath::{ContractionPath, SimplePath, SimplePathRef};
12use crate::tensornetwork::tensor::Tensor;
13
14pub mod balancing;
15mod node;
16mod utils;
17
/// A binary contraction tree for a (possibly nested) tensor network.
///
/// Leaf nodes refer to tensors of the network through their nested tensor index;
/// every internal node represents the pairwise contraction of its two children.
///
/// NOTE: the derived `Clone` copies the `Rc` handles stored in `nodes`, so a clone
/// shares its nodes with the original tree rather than deep-copying them.
#[derive(Default, Debug, Clone)]
pub struct ContractionTree {
    /// All nodes of the tree keyed by node id. Node-to-node links are weak
    /// references, so this map holds the strong references that keep nodes alive.
    nodes: FxHashMap<usize, NodeRef>,
    /// Ids of subtree roots grouped by nesting depth (`0` = top-level path).
    partitions: FxHashMap<usize, Vec<usize>>,
    /// Weak handle to the root node; it no longer upgrades once the root node has
    /// been removed from `nodes`.
    root: WeakNodeRef,
}
25
26impl ContractionTree {
27 pub fn node(&self, node_id: usize) -> Ref<'_, Node> {
29 let borrow = &self.nodes[&node_id];
30 borrow.as_ref().borrow()
31 }
32
33 pub fn root_id(&self) -> Option<usize> {
35 self.root.upgrade().map(|node| node.borrow().id())
36 }
37
38 pub const fn partitions(&self) -> &FxHashMap<usize, Vec<usize>> {
39 &self.partitions
40 }
41
42 fn from_contraction_path_recurse(
45 tensor: &Tensor,
46 path: &ContractionPath,
47 nodes: &mut FxHashMap<usize, NodeRef>,
48 partitions: &mut FxHashMap<usize, Vec<usize>>,
49 prefix: &[usize],
50 ) {
51 let mut scratch = FxHashMap::default();
52
53 for (path_id, path) in path.nested.iter().sorted_by_key(|&(path_id, _)| *path_id) {
55 let composite_tensor = tensor.tensor(*path_id);
56 let mut new_prefix = prefix.to_owned();
57 new_prefix.push(*path_id);
58 Self::from_contraction_path_recurse(
59 composite_tensor,
60 path,
61 nodes,
62 partitions,
63 &new_prefix,
64 );
65 scratch.insert(*path_id, Rc::clone(&nodes[&(nodes.len() - 1)]));
66 }
67
68 for (tensor_idx, tensor) in tensor.tensors().iter().enumerate() {
70 if tensor.is_leaf() {
71 let mut nested_tensor_idx = prefix.to_owned();
72 nested_tensor_idx.push(tensor_idx);
73 let new_node = child_node(nodes.len(), nested_tensor_idx);
74 scratch.insert(tensor_idx, Rc::clone(&new_node));
75 nodes.insert(nodes.len(), new_node);
76 }
77 }
78
79 for (i_path, j_path) in &path.toplevel {
81 let i = &scratch[i_path];
82 let j = &scratch[j_path];
83 let parent = parent_node(nodes.len(), i, j);
84
85 scratch.insert(*i_path, Rc::clone(&parent));
86 nodes.insert(nodes.len(), parent);
87 scratch.remove(j_path);
88 }
89 partitions
90 .entry(prefix.len())
91 .or_default()
92 .push(nodes.len() - 1);
93 }
94
95 #[must_use]
99 pub fn from_contraction_path(tensor: &Tensor, path: &ContractionPath) -> Self {
100 validate_path(path);
101 let mut nodes = FxHashMap::default();
102 let mut partitions = FxHashMap::default();
103 Self::from_contraction_path_recurse(tensor, path, &mut nodes, &mut partitions, &[]);
104 let root = Rc::downgrade(&nodes[&(nodes.len() - 1)]);
105 Self {
106 nodes,
107 partitions,
108 root,
109 }
110 }
111
112 fn leaf_ids_recurse(node: &Node, leaf_indices: &mut Vec<usize>) {
113 if node.is_leaf() {
114 leaf_indices.push(node.id());
115 } else {
116 Self::leaf_ids_recurse(&node.left_child().unwrap().as_ref().borrow(), leaf_indices);
117 Self::leaf_ids_recurse(&node.right_child().unwrap().as_ref().borrow(), leaf_indices);
118 }
119 }
120
121 pub fn leaf_ids(&self, node_id: usize) -> Vec<usize> {
123 let mut leaf_indices = Vec::new();
124 let node = self.node(node_id);
125 Self::leaf_ids_recurse(&node, &mut leaf_indices);
126 leaf_indices
127 }
128
129 fn remove_subtree_recurse(&mut self, node_id: usize) {
131 if self.node(node_id).is_leaf() {
132 let node = &self.nodes[&node_id];
134 if let Some(parent_id) = node.borrow().parent_id() {
135 if self.nodes.contains_key(&parent_id) {
136 self.nodes[&parent_id]
137 .borrow_mut()
138 .remove_child(node.borrow().id());
139 }
140 }
141 node.borrow_mut().remove_parent();
142 return;
143 }
144
145 let node = self.nodes.remove(&node_id).unwrap();
146 let node = node.borrow();
147
148 if let Some(id) = node.left_child_id() {
149 self.remove_subtree_recurse(id);
150 }
151 if let Some(id) = node.right_child_id() {
152 self.remove_subtree_recurse(id);
153 }
154 }
155
156 pub(crate) fn remove_subtree(&mut self, node_id: usize) {
158 self.remove_subtree_recurse(node_id);
159 }
160
161 pub(crate) fn add_path_as_subtree(
164 &mut self,
165 path: &ContractionPath,
166 parent_id: usize,
167 leaf_tensor_indices: &[usize],
168 ) -> usize {
169 validate_path(path);
170 assert!(self.nodes.contains_key(&parent_id));
171
172 let mut index = 0;
173 let mut scratch = FxHashMap::default();
175
176 for &tensor_index in leaf_tensor_indices {
178 scratch.insert(tensor_index, Rc::clone(&self.nodes[&tensor_index]));
179 }
180
181 assert!(
183 path.is_simple(),
184 "Constructor not implemented for nested Tensors"
185 );
186 for (i_path, j_path) in &path.toplevel {
187 index = self.next_id(index);
189 let i = &scratch[i_path];
190 let j = &scratch[j_path];
191
192 assert!(
194 i.borrow().parent_id().is_none(),
195 "Tensor {i_path} is already used in another contraction"
196 );
197 assert!(
198 j.borrow().parent_id().is_none(),
199 "Tensor {j_path} is already used in another contraction"
200 );
201
202 let parent = parent_node(index, i, j);
203 scratch.insert(*i_path, Rc::clone(&parent));
204 scratch.remove(j_path);
205 self.nodes.insert(index, parent);
207 }
208
209 let new_parent = &self.nodes[&parent_id];
211 new_parent
212 .borrow_mut()
213 .add_child(Rc::downgrade(&self.nodes[&index]));
214
215 let new_child = &self.nodes[&index];
216 new_child.borrow_mut().set_parent(Rc::downgrade(new_parent));
217
218 index
219 }
220
221 fn remove_communication_path(&mut self, partition_ids: &[usize]) {
222 for partition_id in partition_ids {
223 let mut parent_id = self.node(*partition_id).parent_id();
224 while let Some(tensor_id) = parent_id {
225 parent_id = self.node(tensor_id).parent_id();
226 self.nodes.remove(&tensor_id);
227 }
228 }
229 }
230
231 fn replace_communication_path(
232 &mut self,
233 partition_ids: Vec<usize>,
234 communication_path: SimplePathRef,
235 ) {
236 self.remove_communication_path(&partition_ids);
238
239 let mut communication_ids = partition_ids;
241 let mut next_id = self.next_id(0);
242 for (i, j) in communication_path {
243 let left_child = communication_ids[*i];
244 let right_child = communication_ids[*j];
245 let new_parent =
246 parent_node(next_id, &self.nodes[&left_child], &self.nodes[&right_child]);
247 self.nodes.insert(next_id, new_parent);
248
249 communication_ids[*i] = next_id;
250 next_id = self.next_id(next_id);
251 }
252
253 self.root = Rc::downgrade(self.nodes.iter().max_by_key(|entry| entry.0).unwrap().1);
255 }
256
257 fn tree_weights_recurse(
258 node: &Node,
259 tn: &Tensor,
260 weights: &mut FxHashMap<usize, f64>,
261 scratch: &mut FxHashMap<usize, Tensor>,
262 cost_function: fn(&Tensor, &Tensor) -> f64,
263 ) {
264 if node.is_leaf() {
265 let Some(tensor_index) = &node.tensor_index() else {
266 panic!("All leaf nodes should have a tensor index")
267 };
268 weights.insert(node.id(), 0f64);
269 scratch.insert(node.id(), tn.nested_tensor(tensor_index).clone());
270 return;
271 }
272
273 let left_child = &node.left_child().unwrap();
274 let right_child = &node.right_child().unwrap();
275 let left_ref = left_child.as_ref().borrow();
276 let right_ref = right_child.as_ref().borrow();
277
278 Self::tree_weights_recurse(&left_ref, tn, weights, scratch, cost_function);
280 Self::tree_weights_recurse(&right_ref, tn, weights, scratch, cost_function);
281
282 let t1 = &scratch[&left_ref.id()];
283 let t2 = &scratch[&right_ref.id()];
284
285 let cost = weights[&left_ref.id()] + weights[&right_ref.id()] + cost_function(t1, t2);
286
287 weights.insert(node.id(), cost);
288 scratch.insert(node.id(), t1 ^ t2);
289 }
290
291 pub fn tree_weights(
298 &self,
299 node_id: usize,
300 tn: &Tensor,
301 cost_function: fn(&Tensor, &Tensor) -> f64,
302 ) -> FxHashMap<usize, f64> {
303 let mut weights = FxHashMap::default();
304 let mut scratch = FxHashMap::default();
305 let node = self.node(node_id);
306 Self::tree_weights_recurse(&node, tn, &mut weights, &mut scratch, cost_function);
307 weights
308 }
309
310 fn to_contraction_path_recurse(node: &Node, path: &mut SimplePath, replace: bool) -> usize {
317 if node.is_leaf() {
318 return node.id();
319 }
320
321 let (Some(left_child), Some(right_child)) = (node.left_child(), node.right_child()) else {
323 panic!("All parents should have two children")
324 };
325
326 let mut t1_id =
328 Self::to_contraction_path_recurse(&left_child.as_ref().borrow(), path, replace);
329 let mut t2_id =
330 Self::to_contraction_path_recurse(&right_child.as_ref().borrow(), path, replace);
331 if t2_id < t1_id {
332 (t1_id, t2_id) = (t2_id, t1_id);
333 }
334
335 path.push((t1_id, t2_id));
337
338 if replace {
340 t1_id
341 } else {
342 node.id()
343 }
344 }
345
346 pub fn to_flat_contraction_path(&self, node_id: usize, replace: bool) -> SimplePath {
351 let node = self.node(node_id);
352 let mut path = Vec::new();
353 Self::to_contraction_path_recurse(&node, &mut path, replace);
354 path
355 }
356
357 fn next_id(&self, mut init: usize) -> usize {
358 while self.nodes.contains_key(&init) {
359 init += 1;
360 }
361 init
362 }
363
364 pub fn tensor(&self, node_id: usize, tensor: &Tensor) -> Tensor {
373 let leaf_nodes = self.leaf_ids(node_id);
374 let mut new_tensor = Tensor::default();
375
376 for leaf_id in leaf_nodes {
377 new_tensor = &new_tensor
378 ^ tensor.nested_tensor(self.node(leaf_id).tensor_index().as_ref().unwrap());
379 }
380 new_tensor
381 }
382}
383
384fn populate_subtree_tensor_map_recursive(
385 contraction_tree: &ContractionTree,
386 node_id: usize,
387 node_tensor_map: &mut FxHashMap<usize, Tensor>,
388 tensor_network: &Tensor,
389 height_limit: Option<usize>,
390) -> (Tensor, usize) {
391 let node = contraction_tree.node(node_id);
392
393 if node.is_leaf() {
394 let tensor_index = node.tensor_index().unwrap();
395 let t = tensor_network.nested_tensor(tensor_index);
396 node_tensor_map.insert(node.id(), t.clone());
397 (t.clone(), 0)
398 } else {
399 let (t1, new_height1) = populate_subtree_tensor_map_recursive(
400 contraction_tree,
401 node.left_child_id().unwrap(),
402 node_tensor_map,
403 tensor_network,
404 height_limit,
405 );
406 let (t2, new_height2) = populate_subtree_tensor_map_recursive(
407 contraction_tree,
408 node.right_child_id().unwrap(),
409 node_tensor_map,
410 tensor_network,
411 height_limit,
412 );
413 let t12 = &t1 ^ &t2;
414 if let Some(height_limit) = height_limit {
415 if new_height1 < height_limit && new_height2 < height_limit {
416 node_tensor_map.insert(node.id(), t12.clone());
417 }
418 } else {
419 node_tensor_map.insert(node.id(), t12.clone());
420 }
421
422 (t12, new_height1.max(new_height2) + 1)
423 }
424}
425
426fn populate_subtree_tensor_map(
439 contraction_tree: &ContractionTree,
440 node_id: usize,
441 tensor_network: &Tensor,
442 height_limit: Option<usize>,
443) -> FxHashMap<usize, Tensor> {
444 let mut node_tensor_map = FxHashMap::default();
445 let _ = populate_subtree_tensor_map_recursive(
446 contraction_tree,
447 node_id,
448 &mut node_tensor_map,
449 tensor_network,
450 height_limit,
451 );
452 node_tensor_map
453}
454
455fn populate_leaf_node_tensor_map(
465 contraction_tree: &ContractionTree,
466 node_id: usize,
467 tensor_network: &Tensor,
468) -> FxHashMap<usize, Tensor> {
469 let mut node_tensor_map = FxHashMap::default();
470 for leaf_node_id in contraction_tree.leaf_ids(node_id) {
471 node_tensor_map.insert(
472 leaf_node_id,
473 contraction_tree.tensor(leaf_node_id, tensor_network),
474 );
475 }
476 node_tensor_map
477}
478
/// Unit tests for building, querying and editing contraction trees.
#[cfg(test)]
mod tests {
    use super::*;

    use std::cell::RefCell;
    use std::iter::zip;
    use std::rc::Weak;

    use itertools::Itertools;

    use crate::contractionpath::contraction_cost::contract_cost_tensors;
    use crate::contractionpath::contraction_tree::node::{child_node, parent_node};
    use crate::contractionpath::contraction_tree::{ContractionTree, Node};
    use crate::contractionpath::ssa_replace_ordering;
    use crate::path;
    use crate::tensornetwork::tensor::{EdgeIndex, Tensor};

    /// Three flat tensors contracted as `((0, 1), 2)`.
    fn setup_simple() -> (Tensor, ContractionPath, FxHashMap<EdgeIndex, u64>) {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        (
            Tensor::new_composite(vec![
                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            ]),
            path![(0, 1), (2, 0)],
            bond_dims,
        )
    }

    /// Six flat tensors contracted along a balanced path.
    fn setup_complex() -> (Tensor, ContractionPath, FxHashMap<EdgeIndex, u64>) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
        ]);
        (
            Tensor::new_composite(vec![
                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
            ]),
            path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)],
            bond_dims,
        )
    }

    /// Six flat tensors contracted along a fully sequential (unbalanced) path.
    fn setup_unbalanced() -> (Tensor, ContractionPath) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        (
            Tensor::new_composite(vec![
                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
            ]),
            path![(0, 1), (2, 0), (3, 2), (4, 3), (5, 4)],
        )
    }

    /// Three composite tensors of two leaves each, one nesting level deep.
    fn setup_nested() -> (Tensor, ContractionPath) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);

        let t0 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
        let t2 = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
        let t3 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let t4 = Tensor::new_from_map(vec![5, 1, 0], &bond_dims);
        let t5 = Tensor::new_from_map(vec![10, 8, 9], &bond_dims);

        let t01 = Tensor::new_composite(vec![t0, t1]);
        let t23 = Tensor::new_composite(vec![t2, t3]);
        let t45 = Tensor::new_composite(vec![t4, t5]);
        let tensor_network = Tensor::new_composite(vec![t01, t23, t45]);
        (
            tensor_network,
            path![{(0, [(0, 1)]), (1, [(0, 1)]), (2, [(0, 1)])}, (0, 1), (0, 2)],
        )
    }

    /// Two composite tensors that themselves contain a composite tensor.
    fn setup_double_nested() -> (Tensor, ContractionPath) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);

        let t0 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
        let t2 = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
        let t3 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let t4 = Tensor::new_from_map(vec![5, 1, 0], &bond_dims);
        let t5 = Tensor::new_from_map(vec![10, 8, 9], &bond_dims);

        let t01 = Tensor::new_composite(vec![t0, t1]);
        let t012 = Tensor::new_composite(vec![t01, t2]);
        let t34 = Tensor::new_composite(vec![t3, t4]);
        let t345 = Tensor::new_composite(vec![t34, t5]);
        let tensor_network = Tensor::new_composite(vec![t012, t345]);
        (
            tensor_network,
            path![
                {
                    (0, [{(0, [(0, 1)])}, (0, 1)]),
                    (1, [{(0, [(0, 1)])}, (0, 1)]),
                },
                (0, 1)
            ],
        )
    }

    /// Structural equality: same id, same child/parent ids and same tensor index.
    impl PartialEq for Node {
        fn eq(&self, other: &Self) -> bool {
            self.id() == other.id()
                && self.left_child_id() == other.left_child_id()
                && self.right_child_id() == other.right_child_id()
                && self.parent_id() == other.parent_id()
                && self.tensor_index() == other.tensor_index()
        }
    }

    #[test]
    fn test_from_contraction_path_simple() {
        let (tensor, path, _) = setup_simple();
        let ContractionTree { nodes, root, .. } =
            ContractionTree::from_contraction_path(&tensor, &path);

        let node0 = child_node(0, vec![0]);
        let node1 = child_node(1, vec![1]);
        let node2 = child_node(2, vec![2]);

        let node3 = parent_node(3, &node0, &node1);
        let node4 = parent_node(4, &node2, &node3);

        let ref_root = Rc::clone(&node4);
        let ref_nodes = [node0, node1, node2, node3, node4];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_from_contraction_path_complex() {
        let (tensor, path, _) = setup_complex();
        let ContractionTree { nodes, root, .. } =
            ContractionTree::from_contraction_path(&tensor, &path);

        let node0 = child_node(0, vec![0]);
        let node1 = child_node(1, vec![1]);
        let node2 = child_node(2, vec![2]);
        let node3 = child_node(3, vec![3]);
        let node4 = child_node(4, vec![4]);
        let node5 = child_node(5, vec![5]);

        let node6 = parent_node(6, &node1, &node5);
        let node7 = parent_node(7, &node0, &node6);
        let node8 = parent_node(8, &node3, &node4);
        let node9 = parent_node(9, &node2, &node8);
        let node10 = parent_node(10, &node7, &node9);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_from_contraction_path_nested() {
        let (tensor, path) = setup_nested();
        let ContractionTree { nodes, root, .. } =
            ContractionTree::from_contraction_path(&tensor, &path);

        let node0 = child_node(0, vec![0, 0]);
        let node1 = child_node(1, vec![0, 1]);
        let node3 = child_node(3, vec![1, 0]);
        let node4 = child_node(4, vec![1, 1]);
        let node6 = child_node(6, vec![2, 0]);
        let node7 = child_node(7, vec![2, 1]);

        let node2 = parent_node(2, &node0, &node1);
        let node5 = parent_node(5, &node3, &node4);
        let node8 = parent_node(8, &node6, &node7);
        let node9 = parent_node(9, &node2, &node5);
        let node10 = parent_node(10, &node9, &node8);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_from_contraction_path_double_nested() {
        let (tensor, path) = setup_double_nested();
        let ContractionTree { nodes, root, .. } =
            ContractionTree::from_contraction_path(&tensor, &path);

        let node0 = child_node(0, vec![0, 0, 0]);
        let node1 = child_node(1, vec![0, 0, 1]);
        let node3 = child_node(3, vec![0, 1]);
        let node5 = child_node(5, vec![1, 0, 0]);
        let node6 = child_node(6, vec![1, 0, 1]);
        let node8 = child_node(8, vec![1, 1]);

        let node2 = parent_node(2, &node0, &node1);
        let node4 = parent_node(4, &node2, &node3);
        let node7 = parent_node(7, &node5, &node6);
        let node9 = parent_node(9, &node7, &node8);
        let node10 = parent_node(10, &node4, &node9);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_leaf_ids_simple() {
        let (tn, path, _) = setup_simple();
        let tree = ContractionTree::from_contraction_path(&tn, &path);

        assert_eq!(tree.leaf_ids(4), vec![2, 0, 1]);
        assert_eq!(tree.leaf_ids(3), vec![0, 1]);
        assert_eq!(tree.leaf_ids(2), vec![2]);
    }

    #[test]
    fn test_leaf_ids_complex() {
        let (tn, path, _) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tn, &path);

        assert_eq!(tree.leaf_ids(10), vec![0, 1, 5, 2, 3, 4]);
        assert_eq!(tree.leaf_ids(9), vec![2, 3, 4]);
        assert_eq!(tree.leaf_ids(8), vec![3, 4]);
        assert_eq!(tree.leaf_ids(7), vec![0, 1, 5]);
        assert_eq!(tree.leaf_ids(6), vec![1, 5]);
        assert_eq!(tree.leaf_ids(3), vec![3]);
    }

    #[test]
    fn test_leaf_ids_nested() {
        let (tn, path) = setup_nested();
        let tree = ContractionTree::from_contraction_path(&tn, &path);
        assert_eq!(tree.leaf_ids(10), vec![0, 1, 3, 4, 6, 7]);
        assert_eq!(tree.leaf_ids(9), vec![0, 1, 3, 4]);
        assert_eq!(tree.leaf_ids(8), vec![6, 7]);
        assert_eq!(tree.leaf_ids(5), vec![3, 4]);
        assert_eq!(tree.leaf_ids(2), vec![0, 1]);
    }

    #[test]
    fn test_leaf_ids_double_nested() {
        let (tn, path) = setup_double_nested();
        let tree = ContractionTree::from_contraction_path(&tn, &path);

        assert_eq!(tree.leaf_ids(10), vec![0, 1, 3, 5, 6, 8]);
        assert_eq!(tree.leaf_ids(9), vec![5, 6, 8]);
        assert_eq!(tree.leaf_ids(7), vec![5, 6]);
        assert_eq!(tree.leaf_ids(4), vec![0, 1, 3]);
        assert_eq!(tree.leaf_ids(2), vec![0, 1]);
    }

    #[test]
    fn test_remove_subtree() {
        let (tn, path) = setup_nested();
        let mut tree = ContractionTree::from_contraction_path(&tn, &path);

        tree.remove_subtree(8);

        let ContractionTree { nodes, root, .. } = tree;

        let node0 = child_node(0, vec![0, 0]);
        let node1 = child_node(1, vec![0, 1]);
        let node3 = child_node(3, vec![1, 0]);
        let node4 = child_node(4, vec![1, 1]);
        let node6 = child_node(6, vec![2, 0]);
        let node7 = child_node(7, vec![2, 1]);
        let node2 = parent_node(2, &node0, &node1);
        let node5 = parent_node(5, &node3, &node4);
        let node9 = parent_node(9, &node2, &node5);
        let node10 = Rc::new(RefCell::new(Node::new(
            10,
            Rc::downgrade(&node9),
            Weak::new(),
            Weak::new(),
            None,
        )));
        node9.borrow_mut().set_parent(Rc::downgrade(&node10));

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node9, node10,
        ];
        // Node 8 was removed, so the expected keys skip it.
        let mut range = (0..8).collect_vec();
        range.extend(9..11);
        for (key, ref_node) in zip(range.iter(), ref_nodes.iter()) {
            let node = &nodes[key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_remove_trivial_subtree() {
        let (tensor, path) = setup_nested();
        let mut tree = ContractionTree::from_contraction_path(&tensor, &path);

        tree.remove_subtree(7);

        let ContractionTree { nodes, root, .. } = tree;

        let node0 = child_node(0, vec![0, 0]);
        let node1 = child_node(1, vec![0, 1]);
        let node3 = child_node(3, vec![1, 0]);
        let node4 = child_node(4, vec![1, 1]);
        let node6 = child_node(6, vec![2, 0]);
        let node7 = child_node(7, vec![2, 1]);
        let node2 = parent_node(2, &node0, &node1);
        let node5 = parent_node(5, &node3, &node4);
        let node8 = Rc::new(RefCell::new(Node::new(
            8,
            Rc::downgrade(&node6),
            Weak::new(),
            Weak::new(),
            None,
        )));
        let node9 = parent_node(9, &node2, &node5);
        let node10 = parent_node(10, &node9, &node8);
        node6.borrow_mut().set_parent(Rc::downgrade(&node8));

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    fn test_tree_weights_simple() {
        let (tensor, path, _) = setup_simple();
        let tree = ContractionTree::from_contraction_path(&tensor, &path);

        let ref_weights = FxHashMap::from_iter([(1, 0.), (0, 0.), (2, 0.), (3, 3820.), (4, 4540.)]);
        let weights = tree.tree_weights(4, &tensor, contract_cost_tensors);
        assert_eq!(weights, ref_weights);

        let ref_weights = FxHashMap::from_iter([(1, 0.), (0, 0.), (3, 3820.)]);
        let weights = tree.tree_weights(3, &tensor, contract_cost_tensors);
        assert_eq!(weights, ref_weights);

        let ref_weights = FxHashMap::from_iter([(2, 0.)]);
        let weights = tree.tree_weights(2, &tensor, contract_cost_tensors);
        assert_eq!(weights, ref_weights);
    }

    #[test]
    fn test_tree_weights_complex() {
        let (tensor, path, _) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tensor, &path);
        let ref_weights = FxHashMap::from_iter([
            (0, 0.),
            (1, 0.),
            (2, 0.),
            (3, 0.),
            (4, 0.),
            (5, 0.),
            (6, 2098440.),
            (7, 2120010.),
            (8, 2105820.),
            (9, 2116470.),
            (10, 4237070.),
        ]);
        let weights = tree.tree_weights(10, &tensor, contract_cost_tensors);

        assert_eq!(weights, ref_weights);
    }

    #[test]
    fn test_to_contraction_path_simple() {
        let (tensor, ref_path, _) = setup_simple();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let path = tree.to_flat_contraction_path(4, false);
        let path = ssa_replace_ordering(&ContractionPath::simple(path));
        assert_eq!(path, ref_path);
    }

    #[test]
    fn test_to_contraction_path_complex() {
        let (tensor, ref_path, _) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let path = tree.to_flat_contraction_path(10, false);
        let path = ssa_replace_ordering(&ContractionPath::simple(path));
        assert_eq!(path, ref_path);
    }

    #[test]
    fn test_to_contraction_path_unbalanced() {
        let (tensor, ref_path) = setup_unbalanced();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let path = tree.to_flat_contraction_path(10, false);
        let path = ssa_replace_ordering(&ContractionPath::simple(path));
        assert_eq!(path, ref_path);
    }

    #[test]
    fn test_populate_subtree_tensor_map_simple() {
        let (tensor, ref_path, bond_dims) = setup_simple();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let mut node_tensor_map = FxHashMap::default();
        populate_subtree_tensor_map_recursive(&tree, 4, &mut node_tensor_map, &tensor, None);

        let ref_node_tensor_map = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
            (3, Tensor::new_from_map(vec![4, 0, 1], &bond_dims)),
            (4, Tensor::new_from_map(vec![5, 6, 0, 1], &bond_dims)),
        ]);

        for (key, value) in ref_node_tensor_map {
            assert_eq!(node_tensor_map[&key].legs(), value.legs());
        }
    }

    #[test]
    fn test_populate_subtree_tensor_map_complex() {
        let (tensor, ref_path, bond_dims) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let mut node_tensor_map = FxHashMap::default();
        populate_subtree_tensor_map_recursive(&tree, 10, &mut node_tensor_map, &tensor, None);

        let ref_node_tensor_map = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
            (5, Tensor::new_from_map(vec![5, 1, 0], &bond_dims)),
            (6, Tensor::new_from_map(vec![3, 2, 5], &bond_dims)),
            (7, Tensor::new_from_map(vec![4, 5], &bond_dims)),
            (8, Tensor::new_from_map(vec![6, 10], &bond_dims)),
            (9, Tensor::new_from_map(vec![4, 5, 10], &bond_dims)),
            (10, Tensor::new_from_map(vec![10], &bond_dims)),
        ]);

        for (key, value) in ref_node_tensor_map {
            assert_eq!(node_tensor_map[&key].legs(), value.legs());
        }
    }

    #[test]
    fn test_populate_subtree_tensor_map_height_limit() {
        let (tensor, ref_path, bond_dims) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let node_tensor_map = populate_subtree_tensor_map(&tree, 10, &tensor, Some(1));

        let ref_node_tensor_map = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
            (5, Tensor::new_from_map(vec![5, 1, 0], &bond_dims)),
            (6, Tensor::new_from_map(vec![3, 2, 5], &bond_dims)),
            (8, Tensor::new_from_map(vec![6, 10], &bond_dims)),
        ]);

        for (key, value) in ref_node_tensor_map {
            assert_eq!(node_tensor_map[&key].legs(), value.legs());
        }
    }

    #[test]
    fn test_populate_leaf_node_tensor_map_simple() {
        let (tensor, ref_path, bond_dims) = setup_simple();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);

        let node_tensor_map = populate_leaf_node_tensor_map(&tree, 4, &tensor);

        let ref_node_tensor_map = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
        ]);

        for (key, value) in ref_node_tensor_map {
            assert_eq!(node_tensor_map[&key].legs(), value.legs());
        }
    }

    #[test]
    fn test_populate_leaf_node_tensor_map_complex() {
        let (tensor, ref_path, bond_dims) = setup_complex();
        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
        let node_tensor_map = populate_leaf_node_tensor_map(&tree, 10, &tensor);

        let ref_node_tensor_map = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
            (5, Tensor::new_from_map(vec![5, 1, 0], &bond_dims)),
        ]);

        // Only leaves are included.
        assert_eq!(node_tensor_map.len(), ref_node_tensor_map.len());
        for (key, value) in ref_node_tensor_map {
            assert_eq!(node_tensor_map[&key].legs(), value.legs());
        }
    }

    #[test]
    fn test_add_path_as_subtree() {
        let (tensor, path, _) = setup_complex();

        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
        complex_tree.remove_subtree(9);
        let new_path = path![(4, 2), (4, 3)];

        complex_tree.add_path_as_subtree(&new_path, 10, &[3, 4, 2]);

        let ContractionTree { nodes, root, .. } = complex_tree;

        let node0 = child_node(0, vec![0]);
        let node1 = child_node(1, vec![1]);
        let node2 = child_node(2, vec![2]);
        let node3 = child_node(3, vec![3]);
        let node4 = child_node(4, vec![4]);
        let node5 = child_node(5, vec![5]);
        let node6 = parent_node(6, &node1, &node5);
        let node7 = parent_node(7, &node0, &node6);
        let node8 = parent_node(8, &node4, &node2);
        let node9 = parent_node(9, &node8, &node3);
        let node10 = parent_node(10, &node7, &node9);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    #[test]
    #[should_panic = "Tensor 2 is already used in another contraction"]
    fn test_add_path_as_subtree_invalid_path() {
        let (tensor, path, _) = setup_complex();

        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
        complex_tree.remove_subtree(8);
        let new_path = path![(4, 2), (4, 3)];

        complex_tree.add_path_as_subtree(&new_path, 9, &[3, 4, 2]);
    }

    #[test]
    fn test_remove_communication_path() {
        let (tensor, path) = setup_nested();
        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
        let partition_ids = vec![2, 5, 8];
        complex_tree.remove_communication_path(&partition_ids);
        assert!(!complex_tree.nodes.contains_key(&9));
        assert!(!complex_tree.nodes.contains_key(&10));
        assert!(complex_tree.root_id().is_none());
    }

    #[test]
    fn test_replace_communication_path() {
        let (tensor, path) = setup_nested();
        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
        let partition_ids = vec![2, 5, 8];
        complex_tree.replace_communication_path(partition_ids, &[(0, 2), (1, 0)]);

        let ContractionTree { nodes, root, .. } = complex_tree;

        let node2 = child_node(2, vec![]);
        let node5 = child_node(5, vec![]);
        let node8 = child_node(8, vec![]);
        let node9 = parent_node(9, &node2, &node8);
        let node10 = parent_node(10, &node5, &node9);

        assert_eq!(nodes[&9], node9);
        assert_eq!(nodes[&10], node10);
        assert_eq!(root.upgrade().unwrap(), node10);
    }
}
1141}