// tnc/contractionpath/paths/tree_tempering.rs

1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{
4    cotengra_check, cotengra_tree_tempering, replace_to_ssa_path, tensor_legs_to_digit,
5};
6
7use crate::{
8    contractionpath::{
9        contraction_cost::contract_path_cost,
10        paths::{CostType, FindPath},
11        ssa_replace_ordering, ContractionPath,
12    },
13    tensornetwork::tensor::Tensor,
14};
15
16/// Creates an interface to `rustengra` an interface to access `Cotengra` methods in
17/// Rust. Specifically exposes `parallel_temper_tree` method.
18pub struct TreeTempering<'a> {
19    tensor: &'a Tensor,
20    numiter: Option<usize>,
21    best_flops: f64,
22    best_size: f64,
23    best_path: ContractionPath,
24    seed: Option<u64>,
25}
26
27impl<'a> TreeTempering<'a> {
28    pub fn new(
29        tensor: &'a Tensor,
30        seed: Option<u64>,
31        minimize: CostType,
32        numiter: Option<usize>,
33    ) -> Self {
34        assert!(cotengra_check().is_ok());
35        assert_eq!(
36            minimize,
37            CostType::Flops,
38            "Currently, only Flops is supported"
39        );
40        Self {
41            tensor,
42            numiter,
43            best_flops: f64::INFINITY,
44            best_size: f64::INFINITY,
45            best_path: ContractionPath::default(),
46            seed,
47        }
48    }
49}
50
51impl FindPath for TreeTempering<'_> {
52    fn find_path(&mut self) {
53        // Map tensors to legs
54        let inputs = self
55            .tensor
56            .tensors()
57            .iter()
58            .map(|tensor| tensor.legs().clone())
59            .collect_vec();
60        let outputs = self.tensor.external_tensor();
61        let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
62            FxHashMap::default(),
63            |mut acc, edges| {
64                acc.extend(edges);
65                acc
66            },
67        );
68
69        let (inputs, outputs, size_dict) =
70            tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
71
72        let replace_path =
73            cotengra_tree_tempering(&inputs, outputs, self.numiter, size_dict, self.seed).unwrap();
74
75        let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
76
77        self.best_path = ContractionPath::simple(best_path);
78
79        let (op_cost, mem_cost) =
80            contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
81
82        self.best_flops = op_cost;
83        self.best_size = mem_cost;
84    }
85
86    fn get_best_flops(&self) -> f64 {
87        self.best_flops
88    }
89
90    fn get_best_size(&self) -> f64 {
91        self.best_size
92    }
93
94    fn get_best_path(&self) -> &ContractionPath {
95        &self.best_path
96    }
97
98    fn get_best_replace_path(&self) -> ContractionPath {
99        ssa_replace_ordering(&self.best_path)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Builds a composite tensor of three children over seven bond dimensions.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let leg_lists = [vec![4, 3, 2], vec![0, 1, 3, 2], vec![4, 5, 6]];
        let children = leg_lists
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &bond_dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(children)
    }

    /// Builds a composite tensor of six children over twelve bond dimensions.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let leg_lists = [
            vec![4, 3, 2],
            vec![0, 1, 3, 2],
            vec![4, 5, 6],
            vec![6, 8, 9],
            vec![10, 8, 9],
            vec![5, 1, 0],
        ];
        let children = leg_lists
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &bond_dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(children)
    }

    #[test]
    #[ignore = "flaky test due to a bug in cotengra"]
    fn test_temper_tree_contract_order_simple() {
        let network = setup_simple();
        let mut finder = TreeTempering::new(&network, Some(8), CostType::Flops, Some(100));
        finder.find_path();

        // Results are read through the `FindPath` getters.
        assert_eq!(finder.get_best_flops(), 600.);
        assert_eq!(finder.get_best_size(), 538.);
        assert_eq!(finder.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    #[ignore = "flaky test due to a bug in cotengra"]
    fn test_temper_tree_contract_order_complex() {
        let network = setup_complex();
        let mut finder = TreeTempering::new(&network, Some(8), CostType::Flops, Some(100));
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 332685.);
        assert_eq!(finder.get_best_size(), 89478.);
        assert_eq!(
            finder.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}
178}