// tnc/contractionpath/contraction_tree.rs

1use std::cell::Ref;
2use std::rc::Rc;
3
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::contractionpath::contraction_tree::node::{
8    child_node, parent_node, Node, NodeRef, WeakNodeRef,
9};
10use crate::contractionpath::paths::validate_path;
11use crate::contractionpath::{ContractionPath, SimplePath, SimplePathRef};
12use crate::tensornetwork::tensor::Tensor;
13
14pub mod balancing;
15mod node;
16mod utils;
17
18/// Struct representing the full contraction path of a given [`Tensor`] object.
19#[derive(Default, Debug, Clone)]
20pub struct ContractionTree {
21    nodes: FxHashMap<usize, NodeRef>,
22    partitions: FxHashMap<usize, Vec<usize>>,
23    root: WeakNodeRef,
24}
25
26impl ContractionTree {
27    /// Returns a reference to the node with the given `node_id`.
28    pub fn node(&self, node_id: usize) -> Ref<'_, Node> {
29        let borrow = &self.nodes[&node_id];
30        borrow.as_ref().borrow()
31    }
32
33    /// Returns the node id of the root node, if any.
34    pub fn root_id(&self) -> Option<usize> {
35        self.root.upgrade().map(|node| node.borrow().id())
36    }
37
38    pub const fn partitions(&self) -> &FxHashMap<usize, Vec<usize>> {
39        &self.partitions
40    }
41
42    /// Populates `nodes` and `partitions` with the tree structure of the contraction
43    /// `path`.
44    fn from_contraction_path_recurse(
45        tensor: &Tensor,
46        path: &ContractionPath,
47        nodes: &mut FxHashMap<usize, NodeRef>,
48        partitions: &mut FxHashMap<usize, Vec<usize>>,
49        prefix: &[usize],
50    ) {
51        let mut scratch = FxHashMap::default();
52
53        // Obtain tree structure from composite tensors
54        for (path_id, path) in path.nested.iter().sorted_by_key(|&(path_id, _)| *path_id) {
55            let composite_tensor = tensor.tensor(*path_id);
56            let mut new_prefix = prefix.to_owned();
57            new_prefix.push(*path_id);
58            Self::from_contraction_path_recurse(
59                composite_tensor,
60                path,
61                nodes,
62                partitions,
63                &new_prefix,
64            );
65            scratch.insert(*path_id, Rc::clone(&nodes[&(nodes.len() - 1)]));
66        }
67
68        // Add nodes for leaf tensors
69        for (tensor_idx, tensor) in tensor.tensors().iter().enumerate() {
70            if tensor.is_leaf() {
71                let mut nested_tensor_idx = prefix.to_owned();
72                nested_tensor_idx.push(tensor_idx);
73                let new_node = child_node(nodes.len(), nested_tensor_idx);
74                scratch.insert(tensor_idx, Rc::clone(&new_node));
75                nodes.insert(nodes.len(), new_node);
76            }
77        }
78
79        // Build tree based on contraction path
80        for (i_path, j_path) in &path.toplevel {
81            let i = &scratch[i_path];
82            let j = &scratch[j_path];
83            let parent = parent_node(nodes.len(), i, j);
84
85            scratch.insert(*i_path, Rc::clone(&parent));
86            nodes.insert(nodes.len(), parent);
87            scratch.remove(j_path);
88        }
89        partitions
90            .entry(prefix.len())
91            .or_default()
92            .push(nodes.len() - 1);
93    }
94
95    /// Creates a `ContractionTree` from `tensor` and contract `path`. The tree
96    /// represents all intermediate tensors and costs of given contraction path and
97    /// tensor network.
98    #[must_use]
99    pub fn from_contraction_path(tensor: &Tensor, path: &ContractionPath) -> Self {
100        validate_path(path);
101        let mut nodes = FxHashMap::default();
102        let mut partitions = FxHashMap::default();
103        Self::from_contraction_path_recurse(tensor, path, &mut nodes, &mut partitions, &[]);
104        let root = Rc::downgrade(&nodes[&(nodes.len() - 1)]);
105        Self {
106            nodes,
107            partitions,
108            root,
109        }
110    }
111
112    fn leaf_ids_recurse(node: &Node, leaf_indices: &mut Vec<usize>) {
113        if node.is_leaf() {
114            leaf_indices.push(node.id());
115        } else {
116            Self::leaf_ids_recurse(&node.left_child().unwrap().as_ref().borrow(), leaf_indices);
117            Self::leaf_ids_recurse(&node.right_child().unwrap().as_ref().borrow(), leaf_indices);
118        }
119    }
120
121    /// Returns the id of all leaf nodes in subtree with root at `node_id`.
122    pub fn leaf_ids(&self, node_id: usize) -> Vec<usize> {
123        let mut leaf_indices = Vec::new();
124        let node = self.node(node_id);
125        Self::leaf_ids_recurse(&node, &mut leaf_indices);
126        leaf_indices
127    }
128
129    /// Removes subtree with root at `node_id`.
130    fn remove_subtree_recurse(&mut self, node_id: usize) {
131        if self.node(node_id).is_leaf() {
132            // Leaf nodes are not removed. We need to manually clear parent/children relations
133            let node = &self.nodes[&node_id];
134            if let Some(parent_id) = node.borrow().parent_id() {
135                if self.nodes.contains_key(&parent_id) {
136                    self.nodes[&parent_id]
137                        .borrow_mut()
138                        .remove_child(node.borrow().id());
139                }
140            }
141            node.borrow_mut().remove_parent();
142            return;
143        }
144
145        let node = self.nodes.remove(&node_id).unwrap();
146        let node = node.borrow();
147
148        if let Some(id) = node.left_child_id() {
149            self.remove_subtree_recurse(id);
150        }
151        if let Some(id) = node.right_child_id() {
152            self.remove_subtree_recurse(id);
153        }
154    }
155
156    /// Removes subtree with root at `node_id`.
157    pub(crate) fn remove_subtree(&mut self, node_id: usize) {
158        self.remove_subtree_recurse(node_id);
159    }
160
161    /// Converts a contraction path into a ContractionTree, then attaches this as a subtree at `parent_id`
162    /// The ContractionTree should already contain the leaf nodes of the
163    pub(crate) fn add_path_as_subtree(
164        &mut self,
165        path: &ContractionPath,
166        parent_id: usize,
167        leaf_tensor_indices: &[usize],
168    ) -> usize {
169        validate_path(path);
170        assert!(self.nodes.contains_key(&parent_id));
171
172        let mut index = 0;
173        // Utilize a scratch hashmap to store intermediate tensor information
174        let mut scratch = FxHashMap::default();
175
176        // Fill scratch with leaf tensors, these should already be present in self.nodes.
177        for &tensor_index in leaf_tensor_indices {
178            scratch.insert(tensor_index, Rc::clone(&self.nodes[&tensor_index]));
179        }
180
181        // Generate intermediate tensors by looping over contraction operations, fill and update scratch as needed.
182        assert!(
183            path.is_simple(),
184            "Constructor not implemented for nested Tensors"
185        );
186        for (i_path, j_path) in &path.toplevel {
187            // Always keep track of latest added tensor. Last index will be the root of the subtree.
188            index = self.next_id(index);
189            let i = &scratch[i_path];
190            let j = &scratch[j_path];
191
192            // Ensure that we are not reusing nodes that are already in another contraction path
193            assert!(
194                i.borrow().parent_id().is_none(),
195                "Tensor {i_path} is already used in another contraction"
196            );
197            assert!(
198                j.borrow().parent_id().is_none(),
199                "Tensor {j_path} is already used in another contraction"
200            );
201
202            let parent = parent_node(index, i, j);
203            scratch.insert(*i_path, Rc::clone(&parent));
204            scratch.remove(j_path);
205            // Ensure that intermediate tensor information is stored in internal HashMap for reference
206            self.nodes.insert(index, parent);
207        }
208
209        // Add the root of the subtree to the indicated node `parent_id` in larger contraction tree.
210        let new_parent = &self.nodes[&parent_id];
211        new_parent
212            .borrow_mut()
213            .add_child(Rc::downgrade(&self.nodes[&index]));
214
215        let new_child = &self.nodes[&index];
216        new_child.borrow_mut().set_parent(Rc::downgrade(new_parent));
217
218        index
219    }
220
221    fn remove_communication_path(&mut self, partition_ids: &[usize]) {
222        for partition_id in partition_ids {
223            let mut parent_id = self.node(*partition_id).parent_id();
224            while let Some(tensor_id) = parent_id {
225                parent_id = self.node(tensor_id).parent_id();
226                self.nodes.remove(&tensor_id);
227            }
228        }
229    }
230
231    fn replace_communication_path(
232        &mut self,
233        partition_ids: Vec<usize>,
234        communication_path: SimplePathRef,
235    ) {
236        // Remove all nodes involved in communication path
237        self.remove_communication_path(&partition_ids);
238
239        // Rebuild the communication-part of the tree
240        let mut communication_ids = partition_ids;
241        let mut next_id = self.next_id(0);
242        for (i, j) in communication_path {
243            let left_child = communication_ids[*i];
244            let right_child = communication_ids[*j];
245            let new_parent =
246                parent_node(next_id, &self.nodes[&left_child], &self.nodes[&right_child]);
247            self.nodes.insert(next_id, new_parent);
248
249            communication_ids[*i] = next_id;
250            next_id = self.next_id(next_id);
251        }
252
253        // Update root
254        self.root = Rc::downgrade(self.nodes.iter().max_by_key(|entry| entry.0).unwrap().1);
255    }
256
257    fn tree_weights_recurse(
258        node: &Node,
259        tn: &Tensor,
260        weights: &mut FxHashMap<usize, f64>,
261        scratch: &mut FxHashMap<usize, Tensor>,
262        cost_function: fn(&Tensor, &Tensor) -> f64,
263    ) {
264        if node.is_leaf() {
265            let Some(tensor_index) = &node.tensor_index() else {
266                panic!("All leaf nodes should have a tensor index")
267            };
268            weights.insert(node.id(), 0f64);
269            scratch.insert(node.id(), tn.nested_tensor(tensor_index).clone());
270            return;
271        }
272
273        let left_child = &node.left_child().unwrap();
274        let right_child = &node.right_child().unwrap();
275        let left_ref = left_child.as_ref().borrow();
276        let right_ref = right_child.as_ref().borrow();
277
278        // Recurse first because weights of leaves are needed for further computation.
279        Self::tree_weights_recurse(&left_ref, tn, weights, scratch, cost_function);
280        Self::tree_weights_recurse(&right_ref, tn, weights, scratch, cost_function);
281
282        let t1 = &scratch[&left_ref.id()];
283        let t2 = &scratch[&right_ref.id()];
284
285        let cost = weights[&left_ref.id()] + weights[&right_ref.id()] + cost_function(t1, t2);
286
287        weights.insert(node.id(), cost);
288        scratch.insert(node.id(), t1 ^ t2);
289    }
290
291    /// Returns `HashMap` storing resultant tensor and its respective contraction costs calculated via `cost_function`.
292    ///
293    /// # Arguments
294    /// * `node_id` - root of Node to start calculating contraction costs
295    /// * `tn` - [`Tensor`] object containing bond dimension and leaf node information
296    /// * `cost_function` - cost function returning contraction cost
297    pub fn tree_weights(
298        &self,
299        node_id: usize,
300        tn: &Tensor,
301        cost_function: fn(&Tensor, &Tensor) -> f64,
302    ) -> FxHashMap<usize, f64> {
303        let mut weights = FxHashMap::default();
304        let mut scratch = FxHashMap::default();
305        let node = self.node(node_id);
306        Self::tree_weights_recurse(&node, tn, &mut weights, &mut scratch, cost_function);
307        weights
308    }
309
310    /// Populates given vector with contractions path of contraction tree starting at `node`.
311    ///
312    /// # Arguments
313    /// * `node` - pointer to Node object
314    /// * `path` - vec to store contraction path in
315    /// * `replace` - if set to `true` returns replace path, otherwise, returns in SSA format
316    fn to_contraction_path_recurse(node: &Node, path: &mut SimplePath, replace: bool) -> usize {
317        if node.is_leaf() {
318            return node.id();
319        }
320
321        // Get children
322        let (Some(left_child), Some(right_child)) = (node.left_child(), node.right_child()) else {
323            panic!("All parents should have two children")
324        };
325
326        // Get right and left child tensor ids
327        let mut t1_id =
328            Self::to_contraction_path_recurse(&left_child.as_ref().borrow(), path, replace);
329        let mut t2_id =
330            Self::to_contraction_path_recurse(&right_child.as_ref().borrow(), path, replace);
331        if t2_id < t1_id {
332            (t1_id, t2_id) = (t2_id, t1_id);
333        }
334
335        // Add pair to path
336        path.push((t1_id, t2_id));
337
338        // Return id of contracted tensor
339        if replace {
340            t1_id
341        } else {
342            node.id()
343        }
344    }
345
346    /// Populates given vector with contractions path of contraction tree starting at `node_id`.
347    /// # Arguments
348    /// * `node` - pointer to Node object
349    /// * `replace` - if set to `true` returns replace path, otherwise, returns in SSA format
350    pub fn to_flat_contraction_path(&self, node_id: usize, replace: bool) -> SimplePath {
351        let node = self.node(node_id);
352        let mut path = Vec::new();
353        Self::to_contraction_path_recurse(&node, &mut path, replace);
354        path
355    }
356
357    fn next_id(&self, mut init: usize) -> usize {
358        while self.nodes.contains_key(&init) {
359            init += 1;
360        }
361        init
362    }
363
364    /// Returns intermediate [`Tensor`] object corresponding to `node_id`.
365    ///
366    /// # Arguments
367    /// * `node_id` - id of Node corresponding to [`Tensor`] of interest
368    /// * `tensor` - tensor containing bond dimension and leaf node information
369    ///
370    /// # Returns
371    /// Empty tensor with legs (dimensions) of data after fully contracted.
372    pub fn tensor(&self, node_id: usize, tensor: &Tensor) -> Tensor {
373        let leaf_nodes = self.leaf_ids(node_id);
374        let mut new_tensor = Tensor::default();
375
376        for leaf_id in leaf_nodes {
377            new_tensor = &new_tensor
378                ^ tensor.nested_tensor(self.node(leaf_id).tensor_index().as_ref().unwrap());
379        }
380        new_tensor
381    }
382}
383
384fn populate_subtree_tensor_map_recursive(
385    contraction_tree: &ContractionTree,
386    node_id: usize,
387    node_tensor_map: &mut FxHashMap<usize, Tensor>,
388    tensor_network: &Tensor,
389    height_limit: Option<usize>,
390) -> (Tensor, usize) {
391    let node = contraction_tree.node(node_id);
392
393    if node.is_leaf() {
394        let tensor_index = node.tensor_index().unwrap();
395        let t = tensor_network.nested_tensor(tensor_index);
396        node_tensor_map.insert(node.id(), t.clone());
397        (t.clone(), 0)
398    } else {
399        let (t1, new_height1) = populate_subtree_tensor_map_recursive(
400            contraction_tree,
401            node.left_child_id().unwrap(),
402            node_tensor_map,
403            tensor_network,
404            height_limit,
405        );
406        let (t2, new_height2) = populate_subtree_tensor_map_recursive(
407            contraction_tree,
408            node.right_child_id().unwrap(),
409            node_tensor_map,
410            tensor_network,
411            height_limit,
412        );
413        let t12 = &t1 ^ &t2;
414        if let Some(height_limit) = height_limit {
415            if new_height1 < height_limit && new_height2 < height_limit {
416                node_tensor_map.insert(node.id(), t12.clone());
417            }
418        } else {
419            node_tensor_map.insert(node.id(), t12.clone());
420        }
421
422        (t12, new_height1.max(new_height2) + 1)
423    }
424}
425
426/// Populates `node_tensor_map` with all intermediate and leaf node ids and corresponding [`Tensor`] object, with root at `node_id`.
427/// Only inserts Tensors with up to `height_limit` number of contractions.
428///
429/// # Arguments
430/// * `contraction_tree` - [`ContractionTree`] object
431/// * `node_id` - root of subtree to examine
432/// * `node_tensor_map` - empty HashMap to populate
433/// * `tensor_network` - [`Tensor`] object containing bond dimension and leaf node information
434///
435///
436/// # Returns
437/// Populated HashMap mapping intermediate node ids up to `height_limit` to Tensor objects.
438fn populate_subtree_tensor_map(
439    contraction_tree: &ContractionTree,
440    node_id: usize,
441    tensor_network: &Tensor,
442    height_limit: Option<usize>,
443) -> FxHashMap<usize, Tensor> {
444    let mut node_tensor_map = FxHashMap::default();
445    let _ = populate_subtree_tensor_map_recursive(
446        contraction_tree,
447        node_id,
448        &mut node_tensor_map,
449        tensor_network,
450        height_limit,
451    );
452    node_tensor_map
453}
454
455/// Populates `node_tensor_map` with all leaf node ids and corresponding [`Tensor`] object, with root at `node_id`.
456///
457/// # Arguments
458/// * `contraction_tree` - [`ContractionTree`] object
459/// * `node_id` - root of subtree to examine
460/// * `tensor_network` - [`Tensor`] object containing bond dimension and leaf node information
461///
462/// # Returns
463/// Populated HashMap mapping leaf node ids to Tensor objects.
464fn populate_leaf_node_tensor_map(
465    contraction_tree: &ContractionTree,
466    node_id: usize,
467    tensor_network: &Tensor,
468) -> FxHashMap<usize, Tensor> {
469    let mut node_tensor_map = FxHashMap::default();
470    for leaf_node_id in contraction_tree.leaf_ids(node_id) {
471        node_tensor_map.insert(
472            leaf_node_id,
473            contraction_tree.tensor(leaf_node_id, tensor_network),
474        );
475    }
476    node_tensor_map
477}
478
479#[cfg(test)]
480mod tests {
481    use super::*;
482
483    use std::cell::RefCell;
484    use std::iter::zip;
485    use std::rc::Weak;
486
487    use itertools::Itertools;
488
489    use crate::contractionpath::contraction_cost::contract_cost_tensors;
490    use crate::contractionpath::contraction_tree::node::{child_node, parent_node};
491    use crate::contractionpath::contraction_tree::{ContractionTree, Node};
492    use crate::contractionpath::ssa_replace_ordering;
493    use crate::path;
494    use crate::tensornetwork::tensor::{EdgeIndex, Tensor};
495
496    fn setup_simple() -> (Tensor, ContractionPath, FxHashMap<EdgeIndex, u64>) {
497        let bond_dims =
498            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
499        (
500            Tensor::new_composite(vec![
501                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
502                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
503                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
504            ]),
505            path![(0, 1), (2, 0)],
506            bond_dims,
507        )
508    }
509
510    fn setup_complex() -> (Tensor, ContractionPath, FxHashMap<EdgeIndex, u64>) {
511        let bond_dims = FxHashMap::from_iter([
512            (0, 27),
513            (1, 18),
514            (2, 12),
515            (3, 15),
516            (4, 5),
517            (5, 3),
518            (6, 18),
519            (7, 22),
520            (8, 45),
521            (9, 65),
522            (10, 5),
523        ]);
524        (
525            Tensor::new_composite(vec![
526                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
527                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
528                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
529                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
530                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
531                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
532            ]),
533            path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)],
534            bond_dims,
535        )
536    }
537
538    fn setup_unbalanced() -> (Tensor, ContractionPath) {
539        let bond_dims = FxHashMap::from_iter([
540            (0, 27),
541            (1, 18),
542            (2, 12),
543            (3, 15),
544            (4, 5),
545            (5, 3),
546            (6, 18),
547            (7, 22),
548            (8, 45),
549            (9, 65),
550            (10, 5),
551            (11, 17),
552        ]);
553        (
554            Tensor::new_composite(vec![
555                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
556                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
557                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
558                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
559                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
560                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
561            ]),
562            path![(0, 1), (2, 0), (3, 2), (4, 3), (5, 4)],
563        )
564    }
565
566    fn setup_nested() -> (Tensor, ContractionPath) {
567        let bond_dims = FxHashMap::from_iter([
568            (0, 27),
569            (1, 18),
570            (2, 12),
571            (3, 15),
572            (4, 5),
573            (5, 3),
574            (6, 18),
575            (7, 22),
576            (8, 45),
577            (9, 65),
578            (10, 5),
579            (11, 17),
580        ]);
581
582        let t0 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
583        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
584        let t2 = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
585        let t3 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
586        let t4 = Tensor::new_from_map(vec![5, 1, 0], &bond_dims);
587        let t5 = Tensor::new_from_map(vec![10, 8, 9], &bond_dims);
588
589        let t01 = Tensor::new_composite(vec![t0, t1]);
590        let t23 = Tensor::new_composite(vec![t2, t3]);
591        let t45 = Tensor::new_composite(vec![t4, t5]);
592        let tensor_network = Tensor::new_composite(vec![t01, t23, t45]);
593        (
594            tensor_network,
595            path![{(0, [(0, 1)]), (1, [(0, 1)]), (2, [(0, 1)])}, (0, 1), (0, 2)],
596        )
597    }
598
599    fn setup_double_nested() -> (Tensor, ContractionPath) {
600        let bond_dims = FxHashMap::from_iter([
601            (0, 27),
602            (1, 18),
603            (2, 12),
604            (3, 15),
605            (4, 5),
606            (5, 3),
607            (6, 18),
608            (7, 22),
609            (8, 45),
610            (9, 65),
611            (10, 5),
612            (11, 17),
613        ]);
614
615        let t0 = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
616        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
617        let t2 = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
618        let t3 = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
619        let t4 = Tensor::new_from_map(vec![5, 1, 0], &bond_dims);
620        let t5 = Tensor::new_from_map(vec![10, 8, 9], &bond_dims);
621
622        let t01 = Tensor::new_composite(vec![t0, t1]);
623        let t012 = Tensor::new_composite(vec![t01, t2]);
624        let t34 = Tensor::new_composite(vec![t3, t4]);
625        let t345 = Tensor::new_composite(vec![t34, t5]);
626        let tensor_network = Tensor::new_composite(vec![t012, t345]);
627        (
628            tensor_network,
629            path![
630                {
631                (0, [{(0, [(0, 1)])}, (0, 1)]),
632                (1, [{(0, [(0, 1)])}, (0, 1)]),
633                },
634                (0, 1)
635            ],
636        )
637    }
638
639    impl PartialEq for Node {
640        fn eq(&self, other: &Self) -> bool {
641            self.id() == other.id()
642                && self.left_child_id() == other.left_child_id()
643                && self.right_child_id() == other.right_child_id()
644                && self.parent_id() == other.parent_id()
645                && self.tensor_index() == other.tensor_index()
646        }
647    }
648
649    #[test]
650    fn test_from_contraction_path_simple() {
651        let (tensor, path, _) = setup_simple();
652        let ContractionTree { nodes, root, .. } =
653            ContractionTree::from_contraction_path(&tensor, &path);
654
655        let node0 = child_node(0, vec![0]);
656        let node1 = child_node(1, vec![1]);
657        let node2 = child_node(2, vec![2]);
658
659        let node3 = parent_node(3, &node0, &node1);
660        let node4 = parent_node(4, &node2, &node3);
661
662        let ref_root = Rc::clone(&node4);
663        let ref_nodes = [node0, node1, node2, node3, node4];
664
665        for (key, ref_node) in ref_nodes.iter().enumerate() {
666            let node = &nodes[&key];
667            assert_eq!(node, ref_node);
668        }
669        assert_eq!(root.upgrade().unwrap(), ref_root);
670    }
671
672    #[test]
673    fn test_from_contraction_path_complex() {
674        let (tensor, path, _) = setup_complex();
675        let ContractionTree { nodes, root, .. } =
676            ContractionTree::from_contraction_path(&tensor, &path);
677
678        let node0 = child_node(0, vec![0]);
679        let node1 = child_node(1, vec![1]);
680        let node2 = child_node(2, vec![2]);
681        let node3 = child_node(3, vec![3]);
682        let node4 = child_node(4, vec![4]);
683        let node5 = child_node(5, vec![5]);
684
685        let node6 = parent_node(6, &node1, &node5);
686        let node7 = parent_node(7, &node0, &node6);
687        let node8 = parent_node(8, &node3, &node4);
688        let node9 = parent_node(9, &node2, &node8);
689        let node10 = parent_node(10, &node7, &node9);
690
691        let ref_root = Rc::clone(&node10);
692        let ref_nodes = [
693            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
694        ];
695
696        for (key, ref_node) in ref_nodes.iter().enumerate() {
697            let node = &nodes[&key];
698            assert_eq!(node, ref_node);
699        }
700        assert_eq!(root.upgrade().unwrap(), ref_root);
701    }
702
703    #[test]
704    fn test_from_contraction_path_nested() {
705        let (tensor, path) = setup_nested();
706        let ContractionTree { nodes, root, .. } =
707            ContractionTree::from_contraction_path(&tensor, &path);
708
709        let node0 = child_node(0, vec![0, 0]);
710        let node1 = child_node(1, vec![0, 1]);
711        let node3 = child_node(3, vec![1, 0]);
712        let node4 = child_node(4, vec![1, 1]);
713        let node6 = child_node(6, vec![2, 0]);
714        let node7 = child_node(7, vec![2, 1]);
715
716        let node2 = parent_node(2, &node0, &node1);
717        let node5 = parent_node(5, &node3, &node4);
718        let node8 = parent_node(8, &node6, &node7);
719        let node9 = parent_node(9, &node2, &node5);
720        let node10 = parent_node(10, &node9, &node8);
721
722        let ref_root = Rc::clone(&node10);
723        let ref_nodes = [
724            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
725        ];
726
727        for (key, ref_node) in ref_nodes.iter().enumerate() {
728            let node = &nodes[&key];
729            assert_eq!(node, ref_node);
730        }
731        assert_eq!(root.upgrade().unwrap(), ref_root);
732    }
733
734    #[test]
735    fn test_from_contraction_path_double_nested() {
736        let (tensor, path) = setup_double_nested();
737        let ContractionTree { nodes, root, .. } =
738            ContractionTree::from_contraction_path(&tensor, &path);
739
740        let node0 = child_node(0, vec![0, 0, 0]);
741        let node1 = child_node(1, vec![0, 0, 1]);
742        let node3 = child_node(3, vec![0, 1]);
743        let node5 = child_node(5, vec![1, 0, 0]);
744        let node6 = child_node(6, vec![1, 0, 1]);
745        let node8 = child_node(8, vec![1, 1]);
746
747        let node2 = parent_node(2, &node0, &node1);
748        let node4 = parent_node(4, &node2, &node3);
749        let node7 = parent_node(7, &node5, &node6);
750        let node9 = parent_node(9, &node7, &node8);
751        let node10 = parent_node(10, &node4, &node9);
752
753        let ref_root = Rc::clone(&node10);
754        let ref_nodes = [
755            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
756        ];
757
758        for (key, ref_node) in ref_nodes.iter().enumerate() {
759            let node = &nodes[&key];
760            assert_eq!(node, ref_node);
761        }
762        assert_eq!(root.upgrade().unwrap(), ref_root);
763    }
764
765    #[test]
766    fn test_leaf_ids_simple() {
767        let (tn, path, _) = setup_simple();
768        let tree = ContractionTree::from_contraction_path(&tn, &path);
769
770        assert_eq!(tree.leaf_ids(4), vec![2, 0, 1]);
771        assert_eq!(tree.leaf_ids(3), vec![0, 1]);
772        assert_eq!(tree.leaf_ids(2), vec![2]);
773    }
774
775    #[test]
776    fn test_leaf_ids_complex() {
777        let (tn, path, _) = setup_complex();
778        let tree = ContractionTree::from_contraction_path(&tn, &path);
779
780        assert_eq!(tree.leaf_ids(10), vec![0, 1, 5, 2, 3, 4]);
781        assert_eq!(tree.leaf_ids(9), vec![2, 3, 4]);
782        assert_eq!(tree.leaf_ids(8), vec![3, 4]);
783        assert_eq!(tree.leaf_ids(7), vec![0, 1, 5]);
784        assert_eq!(tree.leaf_ids(6), vec![1, 5]);
785        assert_eq!(tree.leaf_ids(3), vec![3]);
786    }
787
788    #[test]
789    fn test_leaf_ids_nested() {
790        let (tn, path) = setup_nested();
791        let tree = ContractionTree::from_contraction_path(&tn, &path);
792        assert_eq!(tree.leaf_ids(10), vec![0, 1, 3, 4, 6, 7]);
793        assert_eq!(tree.leaf_ids(9), vec![0, 1, 3, 4]);
794        assert_eq!(tree.leaf_ids(8), vec![6, 7]);
795        assert_eq!(tree.leaf_ids(5), vec![3, 4]);
796        assert_eq!(tree.leaf_ids(2), vec![0, 1]);
797    }
798
799    #[test]
800    fn test_leaf_ids_double_nested() {
801        let (tn, path) = setup_double_nested();
802        let tree = ContractionTree::from_contraction_path(&tn, &path);
803
804        assert_eq!(tree.leaf_ids(10), vec![0, 1, 3, 5, 6, 8]);
805        assert_eq!(tree.leaf_ids(9), vec![5, 6, 8]);
806        assert_eq!(tree.leaf_ids(7), vec![5, 6]);
807        assert_eq!(tree.leaf_ids(4), vec![0, 1, 3]);
808        assert_eq!(tree.leaf_ids(2), vec![0, 1]);
809    }
810
811    #[test]
812    fn test_remove_subtree() {
813        let (tn, path) = setup_nested();
814        let mut tree = ContractionTree::from_contraction_path(&tn, &path);
815
816        tree.remove_subtree(8);
817
818        let ContractionTree { nodes, root, .. } = tree;
819
820        let node0 = child_node(0, vec![0, 0]);
821        let node1 = child_node(1, vec![0, 1]);
822        let node3 = child_node(3, vec![1, 0]);
823        let node4 = child_node(4, vec![1, 1]);
824        let node6 = child_node(6, vec![2, 0]);
825        let node7 = child_node(7, vec![2, 1]);
826        let node2 = parent_node(2, &node0, &node1);
827        let node5 = parent_node(5, &node3, &node4);
828        let node9 = parent_node(9, &node2, &node5);
829        let node10 = Rc::new(RefCell::new(Node::new(
830            10,
831            Rc::downgrade(&node9),
832            Weak::new(),
833            Weak::new(),
834            None,
835        )));
836        node9.borrow_mut().set_parent(Rc::downgrade(&node10));
837
838        let ref_root = Rc::clone(&node10);
839        let ref_nodes = [
840            node0, node1, node2, node3, node4, node5, node6, node7, node9, node10,
841        ];
842        let mut range = (0..8).collect_vec();
843        range.extend(9..11);
844        for (key, ref_node) in zip(range.iter(), ref_nodes.iter()) {
845            let node = &nodes[key];
846            assert_eq!(node, ref_node);
847        }
848        assert_eq!(root.upgrade().unwrap(), ref_root);
849    }
850
851    #[test]
852    fn test_remove_trivial_subtree() {
853        let (tensor, path) = setup_nested();
854        let mut tree = ContractionTree::from_contraction_path(&tensor, &path);
855
856        tree.remove_subtree(7);
857
858        let ContractionTree { nodes, root, .. } = tree;
859
860        let node0 = child_node(0, vec![0, 0]);
861        let node1 = child_node(1, vec![0, 1]);
862        let node3 = child_node(3, vec![1, 0]);
863        let node4 = child_node(4, vec![1, 1]);
864        let node6 = child_node(6, vec![2, 0]);
865        let node7 = child_node(7, vec![2, 1]);
866        let node2 = parent_node(2, &node0, &node1);
867        let node5 = parent_node(5, &node3, &node4);
868        let node8 = Rc::new(RefCell::new(Node::new(
869            8,
870            Rc::downgrade(&node6),
871            Weak::new(),
872            Weak::new(),
873            None,
874        )));
875        let node9 = parent_node(9, &node2, &node5);
876        let node10 = parent_node(10, &node9, &node8);
877        node6.borrow_mut().set_parent(Rc::downgrade(&node8));
878
879        let ref_root = Rc::clone(&node10);
880        let ref_nodes = [
881            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
882        ];
883
884        for (key, ref_node) in ref_nodes.iter().enumerate() {
885            let node = &nodes[&key];
886            assert_eq!(node, ref_node);
887        }
888        assert_eq!(root.upgrade().unwrap(), ref_root);
889    }
890
891    #[test]
892    fn test_tree_weights_simple() {
893        let (tensor, path, _) = setup_simple();
894        let tree = ContractionTree::from_contraction_path(&tensor, &path);
895        let ref_weights = FxHashMap::from_iter([(1, 0.), (0, 0.), (2, 0.), (3, 3820.), (4, 4540.)]);
896        let weights = tree.tree_weights(4, &tensor, contract_cost_tensors);
897
898        assert_eq!(weights, ref_weights);
899        let ref_weights = FxHashMap::from_iter([(1, 0.), (0, 0.), (3, 3820.)]);
900        let weights = tree.tree_weights(3, &tensor, contract_cost_tensors);
901        assert_eq!(weights, ref_weights);
902
903        assert_eq!(weights, ref_weights);
904        let ref_weights = FxHashMap::from_iter([(2, 0.)]);
905        let weights = tree.tree_weights(2, &tensor, contract_cost_tensors);
906        assert_eq!(weights, ref_weights);
907    }
908
909    #[test]
910    fn test_tree_weights_complex() {
911        let (tensor, path, _) = setup_complex();
912        let tree = ContractionTree::from_contraction_path(&tensor, &path);
913        let ref_weights = FxHashMap::from_iter([
914            (0, 0.),
915            (1, 0.),
916            (2, 0.),
917            (3, 0.),
918            (4, 0.),
919            (5, 0.),
920            (6, 2098440.),
921            (7, 2120010.),
922            (8, 2105820.),
923            (9, 2116470.),
924            (10, 4237070.),
925        ]);
926        let weights = tree.tree_weights(10, &tensor, contract_cost_tensors);
927
928        assert_eq!(weights, ref_weights);
929    }
930
931    #[test]
932    fn test_to_contraction_path_simple() {
933        let (tensor, ref_path, _) = setup_simple();
934        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
935        let path = tree.to_flat_contraction_path(4, false);
936        let path = ssa_replace_ordering(&ContractionPath::simple(path));
937        assert_eq!(path, ref_path);
938    }
939
940    #[test]
941    fn test_to_contraction_path_complex() {
942        let (tensor, ref_path, _) = setup_complex();
943        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
944        let path = tree.to_flat_contraction_path(10, false);
945        let path = ssa_replace_ordering(&ContractionPath::simple(path));
946        assert_eq!(path, ref_path);
947    }
948
949    #[test]
950    fn test_to_contraction_path_unbalanced() {
951        let (tensor, ref_path) = setup_unbalanced();
952        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
953        let path = tree.to_flat_contraction_path(10, false);
954        let path = ssa_replace_ordering(&ContractionPath::simple(path));
955        assert_eq!(path, ref_path);
956    }
957
958    #[test]
959    fn test_populate_subtree_tensor_map_simple() {
960        let (tensor, ref_path, bond_dims) = setup_simple();
961        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
962        let mut node_tensor_map = FxHashMap::default();
963        populate_subtree_tensor_map_recursive(&tree, 4, &mut node_tensor_map, &tensor, None);
964
965        let ref_node_tensor_map = FxHashMap::from_iter([
966            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
967            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
968            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
969            (3, Tensor::new_from_map(vec![4, 0, 1], &bond_dims)),
970            (4, Tensor::new_from_map(vec![5, 6, 0, 1], &bond_dims)),
971        ]);
972
973        for (key, value) in ref_node_tensor_map {
974            assert_eq!(node_tensor_map[&key].legs(), value.legs());
975        }
976    }
977
978    #[test]
979    fn test_populate_subtree_tensor_map_complex() {
980        let (tensor, ref_path, bond_dims) = setup_complex();
981        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
982        let mut node_tensor_map = FxHashMap::default();
983        populate_subtree_tensor_map_recursive(&tree, 10, &mut node_tensor_map, &tensor, None);
984
985        let ref_node_tensor_map = FxHashMap::from_iter([
986            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
987            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
988            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
989            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
990            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
991            (5, Tensor::new_from_map(vec![5, 1, 0], &bond_dims)),
992            (6, Tensor::new_from_map(vec![3, 2, 5], &bond_dims)),
993            (7, Tensor::new_from_map(vec![4, 5], &bond_dims)),
994            (8, Tensor::new_from_map(vec![6, 10], &bond_dims)),
995            (9, Tensor::new_from_map(vec![4, 5, 10], &bond_dims)),
996            (10, Tensor::new_from_map(vec![10], &bond_dims)),
997        ]);
998
999        for (key, value) in ref_node_tensor_map {
1000            assert_eq!(node_tensor_map[&key].legs(), value.legs());
1001        }
1002    }
1003
1004    #[test]
1005    fn test_populate_subtree_tensor_map_height_limit() {
1006        let (tensor, ref_path, bond_dims) = setup_complex();
1007        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
1008        let node_tensor_map = populate_subtree_tensor_map(&tree, 10, &tensor, Some(1));
1009
1010        let ref_node_tensor_map = FxHashMap::from_iter([
1011            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
1012            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
1013            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
1014            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
1015            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
1016            (5, Tensor::new_from_map(vec![5, 1, 0], &bond_dims)),
1017            (6, Tensor::new_from_map(vec![3, 2, 5], &bond_dims)),
1018            (8, Tensor::new_from_map(vec![6, 10], &bond_dims)),
1019        ]);
1020
1021        for (key, value) in ref_node_tensor_map {
1022            assert_eq!(node_tensor_map[&key].legs(), value.legs());
1023        }
1024    }
1025
1026    #[test]
1027    fn test_populate_leaf_node_tensor_map_simple() {
1028        let (tensor, ref_path, bond_dims) = setup_simple();
1029        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
1030
1031        let node_tensor_map = populate_leaf_node_tensor_map(&tree, 4, &tensor);
1032
1033        let ref_node_tensor_map = FxHashMap::from_iter([
1034            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
1035            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
1036            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
1037        ]);
1038
1039        for (key, value) in ref_node_tensor_map {
1040            assert_eq!(node_tensor_map[&key].legs(), value.legs());
1041        }
1042    }
1043
1044    #[test]
1045    fn test_populate_leaf_node_tensor_map_complex() {
1046        let (tensor, ref_path, bond_dims) = setup_complex();
1047        let tree = ContractionTree::from_contraction_path(&tensor, &ref_path);
1048        let node_tensor_map = populate_subtree_tensor_map(&tree, 10, &tensor, None);
1049
1050        let ref_node_tensor_map = FxHashMap::from_iter([
1051            (0, Tensor::new_from_map(vec![4, 3, 2], &bond_dims)),
1052            (1, Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims)),
1053            (2, Tensor::new_from_map(vec![4, 5, 6], &bond_dims)),
1054            (3, Tensor::new_from_map(vec![6, 8, 9], &bond_dims)),
1055            (4, Tensor::new_from_map(vec![10, 8, 9], &bond_dims)),
1056        ]);
1057
1058        for (key, value) in ref_node_tensor_map {
1059            assert_eq!(node_tensor_map[&key].legs(), value.legs());
1060        }
1061    }
1062
1063    #[test]
1064    fn test_add_path_as_subtree() {
1065        let (tensor, path, _) = setup_complex();
1066
1067        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
1068        complex_tree.remove_subtree(9);
1069        let new_path = path![(4, 2), (4, 3)];
1070
1071        complex_tree.add_path_as_subtree(&new_path, 10, &[3, 4, 2]);
1072
1073        let ContractionTree { nodes, root, .. } = complex_tree;
1074
1075        let node0 = child_node(0, vec![0]);
1076        let node1 = child_node(1, vec![1]);
1077        let node2 = child_node(2, vec![2]);
1078        let node3 = child_node(3, vec![3]);
1079        let node4 = child_node(4, vec![4]);
1080        let node5 = child_node(5, vec![5]);
1081        let node6 = parent_node(6, &node1, &node5);
1082        let node7 = parent_node(7, &node0, &node6);
1083        let node8 = parent_node(8, &node4, &node2);
1084        let node9 = parent_node(9, &node8, &node3);
1085        let node10 = parent_node(10, &node7, &node9);
1086
1087        let ref_root = Rc::clone(&node10);
1088        let ref_nodes = [
1089            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
1090        ];
1091
1092        for (key, ref_node) in ref_nodes.iter().enumerate() {
1093            let node = &nodes[&key];
1094            assert_eq!(node, ref_node);
1095        }
1096        assert_eq!(root.upgrade().unwrap(), ref_root);
1097    }
1098
1099    #[test]
1100    #[should_panic = "Tensor 2 is already used in another contraction"]
1101    fn test_add_path_as_subtree_invalid_path() {
1102        let (tensor, path, _) = setup_complex();
1103
1104        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
1105        complex_tree.remove_subtree(8);
1106        let new_path = path![(4, 2), (4, 3)];
1107
1108        complex_tree.add_path_as_subtree(&new_path, 9, &[3, 4, 2]);
1109    }
1110
1111    #[test]
1112    fn test_remove_communication_path() {
1113        let (tensor, path) = setup_nested();
1114        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
1115        let partition_ids = vec![2, 5, 8];
1116        complex_tree.remove_communication_path(&partition_ids);
1117        assert!(!complex_tree.nodes.contains_key(&9));
1118        assert!(!complex_tree.nodes.contains_key(&10));
1119        assert!(complex_tree.root_id().is_none());
1120    }
1121
1122    #[test]
1123    fn test_replace_communication_path() {
1124        let (tensor, path) = setup_nested();
1125        let mut complex_tree = ContractionTree::from_contraction_path(&tensor, &path);
1126        let partition_ids = vec![2, 5, 8];
1127        complex_tree.replace_communication_path(partition_ids, &[(0, 2), (1, 0)]);
1128
1129        let ContractionTree { nodes, root, .. } = complex_tree;
1130
1131        let node2 = child_node(2, vec![]);
1132        let node5 = child_node(5, vec![]);
1133        let node8 = child_node(8, vec![]);
1134        let node9 = parent_node(9, &node2, &node8);
1135        let node10 = parent_node(10, &node5, &node9);
1136
1137        assert_eq!(nodes[&9], node9);
1138        assert_eq!(nodes[&10], node10);
1139        assert_eq!(root.upgrade().unwrap(), node10);
1140    }
1141}