tnc/contractionpath/paths/tree_annealing.rs

1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{cotengra_check, cotengra_sa_tree, replace_to_ssa_path, tensor_legs_to_digit};
4
5use crate::{
6    contractionpath::{
7        contraction_cost::contract_path_cost,
8        paths::{CostType, FindPath},
9        ssa_replace_ordering, ContractionPath,
10    },
11    tensornetwork::tensor::Tensor,
12};
13
14/// Creates an interface to `rustengra` an interface to access `Cotengra` methods in
15/// Rust. Specifically exposes `simulated_anneal_tree` method.
16pub struct TreeAnnealing<'a> {
17    tensor: &'a Tensor,
18    temperature_steps: Option<usize>,
19    numiter: Option<usize>,
20    best_flops: f64,
21    best_size: f64,
22    best_path: ContractionPath,
23    seed: Option<u64>,
24}
25
26impl<'a> TreeAnnealing<'a> {
27    pub fn new(
28        tensor: &'a Tensor,
29        seed: Option<u64>,
30        minimize: CostType,
31        temperature_steps: Option<usize>,
32        numiter: Option<usize>,
33    ) -> Self {
34        assert!(cotengra_check().is_ok());
35        assert_eq!(
36            minimize,
37            CostType::Flops,
38            "Currently, only Flops is supported"
39        );
40        Self {
41            tensor,
42            temperature_steps,
43            numiter,
44            best_flops: f64::INFINITY,
45            best_size: f64::INFINITY,
46            best_path: ContractionPath::default(),
47            seed,
48        }
49    }
50}
51
52impl FindPath for TreeAnnealing<'_> {
53    fn find_path(&mut self) {
54        // Map tensors to legs
55        let inputs = self
56            .tensor
57            .tensors()
58            .iter()
59            .map(|tensor| tensor.legs().clone())
60            .collect_vec();
61        let outputs = self.tensor.external_tensor();
62        let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
63            FxHashMap::default(),
64            |mut acc, edges| {
65                acc.extend(edges);
66                acc
67            },
68        );
69
70        let (inputs, outputs, size_dict) =
71            tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
72
73        let replace_path = cotengra_sa_tree(
74            &inputs,
75            outputs,
76            self.temperature_steps,
77            self.numiter,
78            size_dict,
79            self.seed,
80        )
81        .unwrap();
82
83        let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
84
85        self.best_path = ContractionPath::simple(best_path);
86
87        let (op_cost, mem_cost) =
88            contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
89
90        self.best_flops = op_cost;
91        self.best_size = mem_cost;
92    }
93
94    fn get_best_flops(&self) -> f64 {
95        self.best_flops
96    }
97
98    fn get_best_size(&self) -> f64 {
99        self.best_size
100    }
101
102    fn get_best_path(&self) -> &ContractionPath {
103        &self.best_path
104    }
105
106    fn get_best_replace_path(&self) -> ContractionPath {
107        ssa_replace_ordering(&self.best_path)
108    }
109}
110
#[cfg(test)]
mod tests {
    // `super::*` already brings in `FxHashMap`, `CostType`, `FindPath` and `Tensor`.
    use super::*;
    use crate::path;

    /// Three-tensor network used by the small regression test.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        Tensor::new_composite(children)
    }

    /// Six-tensor network used by the larger regression test.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ];
        Tensor::new_composite(children)
    }

    /// Builds an optimizer with the fixed seed/schedule shared by the tests and runs it.
    fn anneal(tn: &Tensor) -> TreeAnnealing<'_> {
        let mut optimizer = TreeAnnealing::new(tn, Some(8), CostType::Flops, Some(100), Some(50));
        optimizer.find_path();
        optimizer
    }

    #[test]
    fn test_anneal_tree_contract_order_simple() {
        let network = setup_simple();
        let optimizer = anneal(&network);

        assert_eq!(optimizer.get_best_flops(), 600.);
        assert_eq!(optimizer.get_best_size(), 538.);
        assert_eq!(optimizer.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(optimizer.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    fn test_anneal_tree_contract_order_complex() {
        let network = setup_complex();
        let optimizer = anneal(&network);

        assert_eq!(optimizer.get_best_flops(), 332685.);
        assert_eq!(optimizer.get_best_size(), 89478.);
        assert_eq!(
            optimizer.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            optimizer.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}