// tnc/contractionpath/paths/tree_reconfiguration.rs

1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{
4    cotengra_check, cotengra_optimized_greedy, replace_to_ssa_path, tensor_legs_to_digit,
5};
6
7use crate::{
8    contractionpath::{
9        contraction_cost::contract_path_cost,
10        paths::{CostType, FindPath},
11        ssa_replace_ordering, ContractionPath,
12    },
13    tensornetwork::tensor::Tensor,
14};
15
16/// Creates an interface to `rustengra` an interface to access `Cotengra` methods in
17/// Rust. Specifically exposes `subtree_reconfigure` method.
18pub struct TreeReconfigure<'a> {
19    tensor: &'a Tensor,
20    subtree_size: usize,
21    best_flops: f64,
22    best_size: f64,
23    best_path: ContractionPath,
24}
25
26impl<'a> TreeReconfigure<'a> {
27    /// Creates a new [`TreeReconfigure`] instance. `subtree_size` is the
28    /// size of subtrees that is considered (increases the optimization cost
29    /// exponentially!).
30    pub fn new(tensor: &'a Tensor, subtree_size: usize, minimize: CostType) -> Self {
31        assert!(cotengra_check().is_ok());
32        assert_eq!(
33            minimize,
34            CostType::Flops,
35            "Currently, only Flops is supported"
36        );
37        Self {
38            tensor,
39            subtree_size,
40            best_flops: f64::INFINITY,
41            best_size: f64::INFINITY,
42            best_path: ContractionPath::default(),
43        }
44    }
45}
46
47impl FindPath for TreeReconfigure<'_> {
48    fn find_path(&mut self) {
49        // Map tensors to legs
50        let inputs = self
51            .tensor
52            .tensors()
53            .iter()
54            .map(|tensor| tensor.legs().clone())
55            .collect_vec();
56        let outputs = self.tensor.external_tensor();
57        let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
58            FxHashMap::default(),
59            |mut acc, edges| {
60                acc.extend(edges);
61                acc
62            },
63        );
64
65        let (inputs, outputs, size_dict) =
66            tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
67
68        let replace_path =
69            cotengra_optimized_greedy(&inputs, outputs, size_dict, self.subtree_size).unwrap();
70
71        let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
72
73        self.best_path = ContractionPath::simple(best_path);
74
75        let (op_cost, mem_cost) =
76            contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
77
78        self.best_flops = op_cost;
79        self.best_size = mem_cost;
80    }
81
82    fn get_best_flops(&self) -> f64 {
83        self.best_flops
84    }
85
86    fn get_best_size(&self) -> f64 {
87        self.best_size
88    }
89
90    fn get_best_path(&self) -> &ContractionPath {
91        &self.best_path
92    }
93
94    fn get_best_replace_path(&self) -> ContractionPath {
95        ssa_replace_ordering(&self.best_path)
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Three tensors: children 0 and 1 share legs 2 and 3, children 0 and 2
    /// share leg 4.
    fn setup_simple() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 5),
            (1, 2),
            (2, 6),
            (3, 8),
            (4, 1),
            (5, 3),
            (6, 4),
        ]);
        let children = [vec![4, 3, 2], vec![0, 1, 3, 2], vec![4, 5, 6]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &bond_dims))
            .collect_vec();
        Tensor::new_composite(children)
    }

    /// Six tensors forming a larger network over legs 0..=10.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let children = [
            vec![4, 3, 2],
            vec![0, 1, 3, 2],
            vec![4, 5, 6],
            vec![6, 8, 9],
            vec![10, 8, 9],
            vec![5, 1, 0],
        ]
        .into_iter()
        .map(|legs| Tensor::new_from_map(legs, &bond_dims))
        .collect_vec();
        Tensor::new_composite(children)
    }

    #[test]
    fn test_tree_contract_order_simple() {
        let network = setup_simple();
        let mut finder = TreeReconfigure::new(&network, 8, CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 600.);
        assert_eq!(finder.get_best_size(), 538.);
        assert_eq!(finder.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    fn test_tree_contract_order_complex() {
        let network = setup_complex();
        let mut finder = TreeReconfigure::new(&network, 8, CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 332685.);
        assert_eq!(finder.get_best_size(), 89478.);
        assert_eq!(
            finder.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}