tnc/contractionpath/paths/weighted_branchbound.rs

1use std::collections::BinaryHeap;
2
3use itertools::Itertools;
4use rustc_hash::FxHashMap;
5
6use crate::{
7    contractionpath::{
8        candidates::Candidate,
9        contraction_cost::{contract_op_cost_tensors, contract_size_tensors},
10        paths::{CostType, FindPath},
11        ssa_ordering, ssa_replace_ordering, ContractionPath,
12    },
13    tensornetwork::tensor::Tensor,
14    utils::traits::HashMapInsertNew,
15};
16
17/// A struct with an [`FindPath`] implementation that explores possible pair contractions in a depth-first manner.
18pub struct WeightedBranchBound<'a> {
19    tn: &'a Tensor,
20    nbranch: Option<usize>,
21    cutoff_flops_factor: f64,
22    minimize: CostType,
23    best_flops: f64,
24    best_size: f64,
25    best_path: ContractionPath,
26    best_progress: FxHashMap<usize, f64>,
27    largest_latency: f64,
28    result_cache: FxHashMap<(usize, usize), (usize, f64, f64)>,
29    comm_cache: FxHashMap<usize, f64>,
30    tensor_cache: FxHashMap<usize, Tensor>,
31}
32
33impl<'a> WeightedBranchBound<'a> {
34    pub fn new(
35        tn: &'a Tensor,
36        nbranch: Option<usize>,
37        cutoff_flops_factor: f64,
38        latency_map: FxHashMap<usize, f64>,
39        minimize: CostType,
40    ) -> Self {
41        Self {
42            tn,
43            nbranch,
44            cutoff_flops_factor,
45            minimize,
46            best_flops: f64::INFINITY,
47            best_size: f64::INFINITY,
48            best_path: ContractionPath::default(),
49            best_progress: FxHashMap::default(),
50            largest_latency: Default::default(),
51            result_cache: FxHashMap::default(),
52            comm_cache: latency_map,
53            tensor_cache: FxHashMap::default(),
54        }
55    }
56
57    fn assess_candidate(
58        &mut self,
59        mut i: usize,
60        mut j: usize,
61        size: f64,
62        remaining_len: usize,
63    ) -> Option<Candidate> {
64        if self.tensor_cache[&j].size() > self.tensor_cache[&i].size() {
65            (i, j) = (j, i);
66        }
67
68        let &mut (k12, flops_12, size_12) = self.result_cache.entry((i, j)).or_insert_with(|| {
69            let k12 = self.tensor_cache.len();
70            let flops_12 = contract_op_cost_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
71            let size_12 = contract_size_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
72            let k12_tensor = &self.tensor_cache[&i] ^ &self.tensor_cache[&j];
73            self.tensor_cache.insert_new(k12, k12_tensor);
74            (k12, flops_12, size_12)
75        });
76
77        let current_flops = if let Some(total_flops) = self.comm_cache.get(&k12) {
78            *total_flops
79        } else {
80            let total_flops = flops_12 + self.comm_cache[&i].max(self.comm_cache[&j]);
81            self.comm_cache.insert(k12, total_flops);
82            total_flops
83        };
84        let current_size = size.max(size_12);
85
86        if current_flops > self.best_flops && current_size > self.best_size {
87            return None;
88        }
89        let best_flops = *self
90            .best_progress
91            .entry(remaining_len)
92            .or_insert(current_flops);
93
94        if current_flops < best_flops {
95            self.best_progress.insert(remaining_len, current_flops);
96        } else if current_flops > (self.cutoff_flops_factor * best_flops + self.largest_latency) {
97            return None;
98        }
99
100        Some(Candidate {
101            flop_cost: current_flops,
102            size_cost: current_size,
103            parent_ids: (i, j),
104            child_id: k12,
105        })
106    }
107
108    /// Explores possible pair contractions in a depth-first
109    /// recursive manner like the `optimal` approach, but with extra heuristic early pruning of branches
110    /// as well sieving by `memory_limit` and the best path found so far. A rust implementation of
111    /// the Python based `opt_einsum` implementation. Found at <https://github.com/dgasmith/opt_einsum>.
112    fn branch_iterate(
113        &mut self,
114        path: &[(usize, usize, usize)],
115        remaining: &[usize],
116        flops: f64,
117        size: f64,
118    ) {
119        if remaining.len() == 1 {
120            match self.minimize {
121                CostType::Flops => {
122                    if self.best_flops > flops {
123                        self.best_flops = flops;
124                        self.best_size = size;
125                        self.best_path = ssa_ordering(path, self.tn.tensors().len());
126                    }
127                }
128                CostType::Size => {
129                    if self.best_size > size {
130                        self.best_flops = flops;
131                        self.best_size = size;
132                        self.best_path = ssa_ordering(path, self.tn.tensors().len());
133                    }
134                }
135            }
136            return;
137        }
138
139        let mut candidates = BinaryHeap::with_capacity(remaining.len() * (remaining.len() - 1) / 2);
140        for pair in remaining.iter().copied().combinations(2) {
141            let candidate = self.assess_candidate(pair[0], pair[1], size, remaining.len());
142            if let Some(new_candidate) = candidate {
143                candidates.push(new_candidate);
144            }
145        }
146        let mut candidates = candidates.into_sorted_vec();
147        if let Some(limit) = self.nbranch {
148            candidates.truncate(limit);
149        }
150
151        let mut new_path = Vec::with_capacity(path.len() + 1);
152        new_path.extend_from_slice(path);
153
154        for candidate in candidates.into_iter().rev() {
155            let Candidate {
156                flop_cost,
157                size_cost,
158                parent_ids,
159                child_id,
160            } = candidate;
161            let mut new_remaining = remaining.to_vec();
162            new_remaining.retain(|e| *e != parent_ids.0 && *e != parent_ids.1);
163            new_remaining.push(child_id);
164            new_path.push((parent_ids.0, parent_ids.1, child_id));
165            self.branch_iterate(&new_path, &new_remaining, flop_cost, size_cost);
166            new_path.pop();
167        }
168    }
169}
170
171impl FindPath for WeightedBranchBound<'_> {
172    fn find_path(&mut self) {
173        if self.tn.is_leaf() {
174            return;
175        }
176        let tensors = self.tn.tensors().clone();
177        self.result_cache.clear();
178        self.tensor_cache.clear();
179        self.largest_latency = *self
180            .comm_cache
181            .iter()
182            .max_by(|a, b| a.1.partial_cmp(b.1).expect("Tried to compare NaN"))
183            .unwrap()
184            .1;
185        let mut nested_paths = FxHashMap::default();
186        // Get the initial space requirements for uncontracted tensors
187        for (index, mut tensor) in tensors.into_iter().enumerate() {
188            // Check that tensor has sub-tensors and doesn't have external legs set
189            if tensor.is_composite() && tensor.legs().is_empty() {
190                let mut bb = WeightedBranchBound::new(
191                    &tensor,
192                    self.nbranch,
193                    self.cutoff_flops_factor,
194                    self.comm_cache.clone(),
195                    self.minimize,
196                );
197                bb.find_path();
198                nested_paths.insert(index, bb.get_best_path().clone());
199                tensor = tensor.external_tensor();
200            }
201            self.tensor_cache.insert_new(index, tensor);
202        }
203        let remaining = (0..self.tn.tensors().len()).collect_vec();
204        self.branch_iterate(&[], &remaining, 0f64, 0f64);
205        self.best_path = ContractionPath {
206            nested: nested_paths,
207            toplevel: std::mem::take(&mut self.best_path).into_simple(),
208        };
209    }
210
211    fn get_best_flops(&self) -> f64 {
212        self.best_flops
213    }
214
215    fn get_best_size(&self) -> f64 {
216        self.best_size
217    }
218
219    fn get_best_path(&self) -> &ContractionPath {
220        &self.best_path
221    }
222
223    fn get_best_replace_path(&self) -> ContractionPath {
224        ssa_replace_ordering(&self.best_path)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::contractionpath::paths::{CostType, FindPath};
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// A three-tensor network together with a latency for each of its tensors.
    fn simple_network() -> (Tensor, FxHashMap<usize, f64>) {
        let bond_dims: FxHashMap<_, _> = [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
            .into_iter()
            .collect();
        let network = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ]);
        let latencies = [(0, 20.), (1, 40.), (2, 85.)].into_iter().collect();
        (network, latencies)
    }

    /// A six-tensor network together with a latency for each of its tensors.
    fn complex_network() -> (Tensor, FxHashMap<usize, f64>) {
        let bond_dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();
        let network = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ]);
        let latencies = [(0, 120.), (1, 0.), (2, 15.), (3, 15.), (4, 85.), (5, 15.)]
            .into_iter()
            .collect();
        (network, latencies)
    }

    #[test]
    fn test_contract_order_simple() {
        let (network, latencies) = simple_network();
        let mut optimizer =
            WeightedBranchBound::new(&network, None, 20., latencies, CostType::Flops);
        optimizer.find_path();

        assert_eq!(optimizer.get_best_flops(), 640.);
        assert_eq!(optimizer.get_best_size(), 538.);
        assert_eq!(optimizer.get_best_path(), &path![(1, 0), (2, 3)]);
        assert_eq!(optimizer.get_best_replace_path(), path![(1, 0), (2, 1)]);
    }

    #[test]
    fn test_contract_order_complex() {
        let (network, latencies) = complex_network();
        let mut optimizer =
            WeightedBranchBound::new(&network, None, 20., latencies, CostType::Flops);
        optimizer.find_path();

        assert_eq!(optimizer.get_best_flops(), 265230.);
        assert_eq!(optimizer.get_best_size(), 89478.);
        assert_eq!(
            optimizer.get_best_path(),
            &path![(3, 4), (2, 6), (1, 5), (0, 8), (7, 9)]
        );
        assert_eq!(
            optimizer.get_best_replace_path(),
            path![(3, 4), (2, 3), (1, 5), (0, 1), (2, 0)]
        );
    }
}