tnc/contractionpath/paths/cotengrust.rs

1use std::iter::zip;
2
3use cotengrust::{optimize_greedy_rust, optimize_optimal_rust, optimize_random_greedy_rust};
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::contractionpath::contraction_cost::contract_path_cost;
8use crate::contractionpath::paths::FindPath;
9use crate::contractionpath::{ssa_replace_ordering, ContractionPath, SimplePath};
10use crate::tensornetwork::tensor::Tensor;
11
/// Strategy that cotengrust uses to search for a contraction path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptMethod {
    /// A single greedy pass; fast but not necessarily optimal.
    Greedy,
    /// Runs the given number of randomized greedy trials and keeps the best path found.
    RandomGreedy(usize),
    /// Exhaustive search for the optimal path; can be very slow for large networks.
    Optimal,
}
22
/// A contraction path finder using the `cotengrust` library.
///
/// Composite children of `tensor` are optimized recursively with the same
/// [`OptMethod`]; the outcome of the most recent `find_path` call is stored in
/// the `best_*` fields.
#[derive(Debug, Clone)]
pub struct Cotengrust<'a> {
    /// The tensor network a contraction path is searched for.
    tensor: &'a Tensor,
    /// Best path found (SSA indices); empty until `find_path` has run.
    best_path: ContractionPath,
    /// Operation cost of `best_path`; `f64::INFINITY` until `find_path` has run.
    best_flops: f64,
    /// Memory cost of `best_path`; `f64::INFINITY` until `find_path` has run.
    best_size: f64,
    /// Which cotengrust search algorithm to run.
    opt_method: OptMethod,
}
32
33impl<'a> Cotengrust<'a> {
34    /// Creates a new Cotengrust optimizer using the specified optimization method.
35    pub fn new(tensor: &'a Tensor, opt_method: OptMethod) -> Self {
36        Self {
37            tensor,
38            opt_method,
39            best_path: ContractionPath::default(),
40            best_flops: f64::INFINITY,
41            best_size: f64::INFINITY,
42        }
43    }
44
45    /// Finds a contraction path for a "classical" tensor network, i.e. the inputs
46    /// are all leaf tensors.
47    fn optimize_single(&self, inputs: &[Tensor], output: &Tensor) -> SimplePath {
48        // Check if the inputs are empty (cotengrust does not handle this gracefully)
49        if inputs.is_empty() {
50            return SimplePath::default();
51        }
52
53        // Convert the inputs to the cotengra format
54        let (inputs, output, size_dict) = tensor_legs_to_digit(inputs, output);
55
56        // Find the contraction path
57        let path = match &self.opt_method {
58            OptMethod::Greedy => optimize_greedy_rust(
59                inputs,
60                output,
61                size_dict,
62                None,
63                None,
64                None,
65                Some(42),
66                false,
67                true,
68            ),
69            &OptMethod::RandomGreedy(ntrials) => {
70                optimize_random_greedy_rust(
71                    inputs,
72                    output,
73                    size_dict,
74                    ntrials,
75                    None,
76                    None,
77                    None,
78                    Some(42),
79                    false,
80                    true,
81                )
82                .0
83            }
84            OptMethod::Optimal => {
85                optimize_optimal_rust(inputs, output, size_dict, None, None, None, false, true)
86            }
87        };
88
89        // Convert the path back to our format
90        path.into_iter()
91            .map(|pair| {
92                let [a, b] = pair[..] else {
93                    panic!("Expected two indices in contraction path pair")
94                };
95                (a as _, b as _)
96            })
97            .collect_vec()
98    }
99}
100
101/// Converts tensor leg inputs to chars. Creates new inputs, outputs and size_dict that can be fed to Cotengra.
102fn tensor_legs_to_digit(
103    inputs: &[Tensor],
104    output: &Tensor,
105) -> (Vec<Vec<char>>, Vec<char>, FxHashMap<char, f32>) {
106    fn leg_to_char(leg: usize) -> char {
107        char::from_u32(leg.try_into().unwrap()).unwrap()
108    }
109    let mut new_inputs = vec![Vec::new(); inputs.len()];
110    let new_output = output.legs().iter().copied().map(leg_to_char).collect();
111    let mut new_size_dict = FxHashMap::default();
112
113    for (tensor, labels) in zip(inputs, new_inputs.iter_mut()) {
114        labels.reserve_exact(tensor.legs().len());
115        for (leg, dim) in tensor.edges() {
116            let character = leg_to_char(*leg);
117            labels.push(character);
118            new_size_dict.insert(character, *dim as f32);
119        }
120    }
121    (new_inputs, new_output, new_size_dict)
122}
123
124impl FindPath for Cotengrust<'_> {
125    fn find_path(&mut self) {
126        // Handle nested tensors first
127        let mut nested_paths = FxHashMap::default();
128        let mut inputs = self.tensor.tensors().clone();
129        for (index, input_tensor) in inputs.iter_mut().enumerate() {
130            if input_tensor.is_composite() {
131                let mut ct = Cotengrust::new(input_tensor, self.opt_method);
132                ct.find_path();
133                nested_paths.insert(index, ct.get_best_path().clone());
134                *input_tensor = input_tensor.external_tensor();
135            }
136        }
137
138        // Now handle the outer tensor
139        let external_tensor = self.tensor.external_tensor();
140        let outer_path = self.optimize_single(&inputs, &external_tensor);
141        self.best_path = ContractionPath {
142            nested: nested_paths,
143            toplevel: outer_path,
144        };
145
146        // Compute the cost
147        let (op_cost, mem_cost) =
148            contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
149        self.best_size = mem_cost;
150        self.best_flops = op_cost;
151    }
152
153    fn get_best_path(&self) -> &ContractionPath {
154        &self.best_path
155    }
156
157    fn get_best_replace_path(&self) -> ContractionPath {
158        ssa_replace_ordering(&self.best_path)
159    }
160
161    fn get_best_flops(&self) -> f64 {
162        self.best_flops
163    }
164
165    fn get_best_size(&self) -> f64 {
166        self.best_size
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    /// Runs the greedy cotengrust optimizer on `tn` and hands back the finished optimizer.
    fn greedy(tn: &Tensor) -> Cotengrust<'_> {
        let mut optimizer = Cotengrust::new(tn, OptMethod::Greedy);
        optimizer.find_path();
        optimizer
    }

    /// Optimizes `tn` greedily and checks costs plus both path representations.
    #[track_caller]
    fn check_greedy(
        tn: &Tensor,
        flops: f64,
        size: f64,
        ssa_path: &ContractionPath,
        replace_path: &ContractionPath,
    ) {
        let optimizer = greedy(tn);
        assert_eq!(optimizer.get_best_flops(), flops);
        assert_eq!(optimizer.get_best_size(), size);
        assert_eq!(optimizer.get_best_path(), ssa_path);
        assert_eq!(&optimizer.get_best_replace_path(), replace_path);
    }

    fn setup_simple() -> Tensor {
        let dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let leaves = vec![vec![4, 3, 2], vec![0, 1, 3, 2], vec![4, 5, 6]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(leaves)
    }

    fn setup_complex() -> Tensor {
        let dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let leaves = vec![
            vec![4, 3, 2],
            vec![0, 1, 3, 2],
            vec![4, 5, 6],
            vec![6, 8, 9],
            vec![10, 8, 9],
            vec![5, 1, 0],
        ]
        .into_iter()
        .map(|legs| Tensor::new_from_map(legs, &dims))
        .collect::<Vec<_>>();
        Tensor::new_composite(leaves)
    }

    fn setup_simple_inner_product() -> Tensor {
        let dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let leaves = vec![vec![4, 3, 2], vec![4, 3, 2], vec![0, 1, 5], vec![1, 6]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(leaves)
    }

    fn setup_simple_outer_product() -> Tensor {
        let dims = FxHashMap::from_iter([(0, 3), (1, 2), (2, 2)]);
        let leaves = vec![vec![0], vec![1], vec![2]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(leaves)
    }

    fn setup_complex_outer_product() -> Tensor {
        let dims = FxHashMap::from_iter([(0, 5), (1, 4)]);
        let leaves = vec![vec![0], vec![0], vec![1], vec![1]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(leaves)
    }

    #[test]
    fn test_contract_order_greedy_simple() {
        check_greedy(
            &setup_simple(),
            600.,
            538.,
            &path![(0, 1), (3, 2)],
            &path![(0, 1), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_simple_inner() {
        check_greedy(
            &setup_simple_inner_product(),
            228.,
            121.,
            &path![(0, 1), (2, 3), (4, 5)],
            &path![(0, 1), (2, 3), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_simple_outer() {
        check_greedy(
            &setup_simple_outer_product(),
            16.,
            19.,
            &path![(2, 1), (0, 3)],
            &path![(2, 1), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_complex_outer() {
        check_greedy(
            &setup_complex_outer_product(),
            10.,
            11.,
            &path![(0, 1), (2, 3), (5, 4)],
            &path![(0, 1), (2, 3), (2, 0)],
        );
    }

    #[test]
    fn test_contract_order_greedy_complex() {
        check_greedy(
            &setup_complex(),
            529815.,
            89478.,
            &path![(1, 5), (3, 4), (6, 0), (7, 2), (9, 8)],
            &path![(1, 5), (3, 4), (1, 0), (3, 2), (3, 1)],
        );
    }
}