tnc/contractionpath/paths/branchbound.rs

1use std::collections::BinaryHeap;
2
3use itertools::Itertools;
4use rustc_hash::FxHashMap;
5
6use crate::{
7    contractionpath::{
8        candidates::Candidate,
9        contraction_cost::{contract_cost_tensors, contract_size_tensors},
10        paths::{CostType, FindPath},
11        ssa_ordering, ssa_replace_ordering, ContractionPath,
12    },
13    tensornetwork::tensor::Tensor,
14    utils::traits::HashMapInsertNew,
15};
16
17/// A struct with an [`FindPath`] implementation that explores possible pair contractions in a depth-first manner.
18pub struct BranchBound<'a> {
19    tn: &'a Tensor,
20    nbranch: Option<usize>,
21    cutoff_flops_factor: f64,
22    minimize: CostType,
23    best_flops: f64,
24    best_size: f64,
25    best_path: ContractionPath,
26    best_progress: FxHashMap<usize, f64>,
27    result_cache: FxHashMap<(usize, usize), usize>,
28    flop_cache: FxHashMap<usize, f64>,
29    size_cache: FxHashMap<usize, f64>,
30    tensor_cache: FxHashMap<usize, Tensor>,
31}
32
33impl<'a> BranchBound<'a> {
34    pub fn new(
35        tn: &'a Tensor,
36        nbranch: Option<usize>,
37        cutoff_flops_factor: f64,
38        minimize: CostType,
39    ) -> Self {
40        Self {
41            tn,
42            nbranch,
43            cutoff_flops_factor,
44            minimize,
45            best_flops: f64::INFINITY,
46            best_size: f64::INFINITY,
47            best_path: ContractionPath::default(),
48            best_progress: FxHashMap::default(),
49            result_cache: FxHashMap::default(),
50            flop_cache: FxHashMap::default(),
51            size_cache: FxHashMap::default(),
52            tensor_cache: FxHashMap::default(),
53        }
54    }
55
56    fn assess_candidate(
57        &mut self,
58        mut i: usize,
59        mut j: usize,
60        flops: f64,
61        size: f64,
62        remaining_len: usize,
63    ) -> Option<Candidate> {
64        let flops_12: f64;
65        let size_12: f64;
66        let k12: usize;
67        let k12_tensor: Tensor;
68        let mut current_flops = flops;
69        let mut current_size = size;
70        // Ensure that larger tensor is always to the left.
71        if self.tensor_cache[&j].size() > self.tensor_cache[&i].size() {
72            (i, j) = (j, i);
73        }
74
75        if self.result_cache.contains_key(&(i, j)) {
76            k12 = self.result_cache[&(i, j)];
77            flops_12 = self.flop_cache[&k12];
78            size_12 = self.size_cache[&k12];
79        } else {
80            k12 = self.tensor_cache.len();
81            flops_12 = contract_cost_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
82            size_12 = contract_size_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
83            k12_tensor = &self.tensor_cache[&i] ^ &self.tensor_cache[&j];
84
85            self.result_cache.entry((i, j)).or_insert_with(|| k12);
86            self.flop_cache.entry(k12).or_insert_with(|| flops_12);
87            self.size_cache.entry(k12).or_insert_with(|| size_12);
88            self.tensor_cache.insert_new(k12, k12_tensor);
89        }
90        current_flops += flops_12;
91        current_size = current_size.max(size_12);
92
93        if current_flops > self.best_flops && current_size > self.best_size {
94            return None;
95        }
96        let best_flops = *self
97            .best_progress
98            .entry(remaining_len)
99            .or_insert(current_flops);
100
101        if current_flops < best_flops {
102            self.best_progress.insert(remaining_len, current_flops);
103        } else if current_flops > self.cutoff_flops_factor * best_flops {
104            return None;
105        }
106
107        Some(Candidate {
108            flop_cost: current_flops,
109            size_cost: current_size,
110            parent_ids: (i, j),
111            child_id: k12,
112        })
113    }
114
115    /// Explores possible pair contractions in a depth-first
116    /// recursive manner like the `optimal` approach, but with extra heuristic early pruning of branches
117    /// as well sieving by `memory_limit` and the best path found so far. A rust implementation of
118    /// the Python based `opt_einsum` implementation. Found at <https://github.com/dgasmith/opt_einsum>.
119    fn branch_iterate(
120        &mut self,
121        path: &[(usize, usize, usize)],
122        remaining: &[usize],
123        flops: f64,
124        size: f64,
125    ) {
126        if remaining.len() == 1 {
127            match self.minimize {
128                CostType::Flops => {
129                    if self.best_flops > flops {
130                        self.best_flops = flops;
131                        self.best_size = size;
132                        self.best_path = ssa_ordering(path, self.tn.tensors().len());
133                    }
134                }
135                CostType::Size => {
136                    if self.best_size > size {
137                        self.best_flops = flops;
138                        self.best_size = size;
139                        self.best_path = ssa_ordering(path, self.tn.tensors().len());
140                    }
141                }
142            }
143            return;
144        }
145
146        let mut candidates = BinaryHeap::<Candidate>::new();
147        for i in remaining.iter().copied().combinations(2) {
148            let candidate = self.assess_candidate(i[0], i[1], flops, size, remaining.len());
149            if let Some(new_candidate) = candidate {
150                candidates.push(new_candidate);
151            }
152        }
153        let mut new_remaining;
154        let mut new_path: Vec<(usize, usize, usize)>;
155        let mut bi = 0;
156        while self.nbranch.is_none() || bi < self.nbranch.unwrap() {
157            bi += 1;
158            let Some(Candidate {
159                flop_cost,
160                size_cost,
161                parent_ids,
162                child_id,
163            }) = candidates.pop()
164            else {
165                break;
166            };
167            new_remaining = remaining.to_vec();
168            new_remaining.retain(|e| *e != parent_ids.0 && *e != parent_ids.1);
169            new_remaining.push(child_id);
170            new_path = path.to_vec();
171            new_path.push((parent_ids.0, parent_ids.1, child_id));
172            self.branch_iterate(&new_path, &new_remaining, flop_cost, size_cost);
173        }
174    }
175}
176
177impl FindPath for BranchBound<'_> {
178    fn find_path(&mut self) {
179        if self.tn.is_leaf() {
180            return;
181        }
182        let tensors = self.tn.tensors().clone();
183        self.flop_cache.clear();
184        self.size_cache.clear();
185        self.result_cache.clear();
186        self.tensor_cache.clear();
187        let mut nested_paths = FxHashMap::default();
188        // Get the initial space requirements for uncontracted tensors
189        for (index, mut tensor) in tensors.into_iter().enumerate() {
190            // Check that tensor has sub-tensors and doesn't have external legs set
191            if !tensor.tensors().is_empty() && tensor.legs().is_empty() {
192                let mut bb = BranchBound::new(
193                    &tensor,
194                    self.nbranch,
195                    self.cutoff_flops_factor,
196                    self.minimize,
197                );
198                bb.find_path();
199                nested_paths.insert(index, bb.get_best_path().clone());
200                tensor = tensor.external_tensor();
201            }
202            self.size_cache
203                .entry(index)
204                .or_insert_with(|| tensor.size());
205
206            self.tensor_cache.insert_new(index, tensor);
207        }
208        let remaining = (0..self.tn.tensors().len()).collect_vec();
209        self.branch_iterate(&[], &remaining, 0f64, 0f64);
210        self.best_path = ContractionPath {
211            nested: nested_paths,
212            toplevel: std::mem::take(&mut self.best_path).into_simple(),
213        };
214    }
215
216    fn get_best_flops(&self) -> f64 {
217        self.best_flops
218    }
219
220    fn get_best_size(&self) -> f64 {
221        self.best_size
222    }
223
224    fn get_best_path(&self) -> &ContractionPath {
225        &self.best_path
226    }
227
228    fn get_best_replace_path(&self) -> ContractionPath {
229        ssa_replace_ordering(&self.best_path)
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::contractionpath::paths::{CostType, FindPath};
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Small network of three tensors with low bond dimensions.
    fn setup_simple() -> Tensor {
        let dims: FxHashMap<_, _> = [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
            .into_iter()
            .collect();
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &dims),
            Tensor::new_from_map(vec![4, 5, 6], &dims),
        ])
    }

    /// Larger network of six tensors with a wider range of bond dimensions.
    fn setup_complex() -> Tensor {
        let dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &dims),
            Tensor::new_from_map(vec![4, 5, 6], &dims),
            Tensor::new_from_map(vec![6, 8, 9], &dims),
            Tensor::new_from_map(vec![10, 8, 9], &dims),
            Tensor::new_from_map(vec![5, 1, 0], &dims),
        ])
    }

    #[test]
    fn test_contract_order_simple() {
        let network = setup_simple();
        let mut finder = BranchBound::new(&network, None, 20., CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 4540.);
        assert_eq!(finder.get_best_size(), 538.);
        assert_eq!(finder.get_best_path(), &path![(1, 0), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(1, 0), (2, 1)]);
    }

    #[test]
    fn test_contract_order_complex() {
        let network = setup_complex();
        let mut finder = BranchBound::new(&network, None, 20., CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 2654474.);
        assert_eq!(finder.get_best_size(), 89478.);
        assert_eq!(
            finder.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}