1use std::collections::BinaryHeap;
2
3use itertools::Itertools;
4use rustc_hash::FxHashMap;
5
6use crate::{
7 contractionpath::{
8 candidates::Candidate,
9 contraction_cost::{contract_cost_tensors, contract_size_tensors},
10 paths::{CostType, FindPath},
11 ssa_ordering, ssa_replace_ordering, ContractionPath,
12 },
13 tensornetwork::tensor::Tensor,
14 utils::traits::HashMapInsertNew,
15};
16
17pub struct BranchBound<'a> {
19 tn: &'a Tensor,
20 nbranch: Option<usize>,
21 cutoff_flops_factor: f64,
22 minimize: CostType,
23 best_flops: f64,
24 best_size: f64,
25 best_path: ContractionPath,
26 best_progress: FxHashMap<usize, f64>,
27 result_cache: FxHashMap<(usize, usize), usize>,
28 flop_cache: FxHashMap<usize, f64>,
29 size_cache: FxHashMap<usize, f64>,
30 tensor_cache: FxHashMap<usize, Tensor>,
31}
32
33impl<'a> BranchBound<'a> {
34 pub fn new(
35 tn: &'a Tensor,
36 nbranch: Option<usize>,
37 cutoff_flops_factor: f64,
38 minimize: CostType,
39 ) -> Self {
40 Self {
41 tn,
42 nbranch,
43 cutoff_flops_factor,
44 minimize,
45 best_flops: f64::INFINITY,
46 best_size: f64::INFINITY,
47 best_path: ContractionPath::default(),
48 best_progress: FxHashMap::default(),
49 result_cache: FxHashMap::default(),
50 flop_cache: FxHashMap::default(),
51 size_cache: FxHashMap::default(),
52 tensor_cache: FxHashMap::default(),
53 }
54 }
55
56 fn assess_candidate(
57 &mut self,
58 mut i: usize,
59 mut j: usize,
60 flops: f64,
61 size: f64,
62 remaining_len: usize,
63 ) -> Option<Candidate> {
64 let flops_12: f64;
65 let size_12: f64;
66 let k12: usize;
67 let k12_tensor: Tensor;
68 let mut current_flops = flops;
69 let mut current_size = size;
70 if self.tensor_cache[&j].size() > self.tensor_cache[&i].size() {
72 (i, j) = (j, i);
73 }
74
75 if self.result_cache.contains_key(&(i, j)) {
76 k12 = self.result_cache[&(i, j)];
77 flops_12 = self.flop_cache[&k12];
78 size_12 = self.size_cache[&k12];
79 } else {
80 k12 = self.tensor_cache.len();
81 flops_12 = contract_cost_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
82 size_12 = contract_size_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
83 k12_tensor = &self.tensor_cache[&i] ^ &self.tensor_cache[&j];
84
85 self.result_cache.entry((i, j)).or_insert_with(|| k12);
86 self.flop_cache.entry(k12).or_insert_with(|| flops_12);
87 self.size_cache.entry(k12).or_insert_with(|| size_12);
88 self.tensor_cache.insert_new(k12, k12_tensor);
89 }
90 current_flops += flops_12;
91 current_size = current_size.max(size_12);
92
93 if current_flops > self.best_flops && current_size > self.best_size {
94 return None;
95 }
96 let best_flops = *self
97 .best_progress
98 .entry(remaining_len)
99 .or_insert(current_flops);
100
101 if current_flops < best_flops {
102 self.best_progress.insert(remaining_len, current_flops);
103 } else if current_flops > self.cutoff_flops_factor * best_flops {
104 return None;
105 }
106
107 Some(Candidate {
108 flop_cost: current_flops,
109 size_cost: current_size,
110 parent_ids: (i, j),
111 child_id: k12,
112 })
113 }
114
115 fn branch_iterate(
120 &mut self,
121 path: &[(usize, usize, usize)],
122 remaining: &[usize],
123 flops: f64,
124 size: f64,
125 ) {
126 if remaining.len() == 1 {
127 match self.minimize {
128 CostType::Flops => {
129 if self.best_flops > flops {
130 self.best_flops = flops;
131 self.best_size = size;
132 self.best_path = ssa_ordering(path, self.tn.tensors().len());
133 }
134 }
135 CostType::Size => {
136 if self.best_size > size {
137 self.best_flops = flops;
138 self.best_size = size;
139 self.best_path = ssa_ordering(path, self.tn.tensors().len());
140 }
141 }
142 }
143 return;
144 }
145
146 let mut candidates = BinaryHeap::<Candidate>::new();
147 for i in remaining.iter().copied().combinations(2) {
148 let candidate = self.assess_candidate(i[0], i[1], flops, size, remaining.len());
149 if let Some(new_candidate) = candidate {
150 candidates.push(new_candidate);
151 }
152 }
153 let mut new_remaining;
154 let mut new_path: Vec<(usize, usize, usize)>;
155 let mut bi = 0;
156 while self.nbranch.is_none() || bi < self.nbranch.unwrap() {
157 bi += 1;
158 let Some(Candidate {
159 flop_cost,
160 size_cost,
161 parent_ids,
162 child_id,
163 }) = candidates.pop()
164 else {
165 break;
166 };
167 new_remaining = remaining.to_vec();
168 new_remaining.retain(|e| *e != parent_ids.0 && *e != parent_ids.1);
169 new_remaining.push(child_id);
170 new_path = path.to_vec();
171 new_path.push((parent_ids.0, parent_ids.1, child_id));
172 self.branch_iterate(&new_path, &new_remaining, flop_cost, size_cost);
173 }
174 }
175}
176
177impl FindPath for BranchBound<'_> {
178 fn find_path(&mut self) {
179 if self.tn.is_leaf() {
180 return;
181 }
182 let tensors = self.tn.tensors().clone();
183 self.flop_cache.clear();
184 self.size_cache.clear();
185 self.result_cache.clear();
186 self.tensor_cache.clear();
187 let mut nested_paths = FxHashMap::default();
188 for (index, mut tensor) in tensors.into_iter().enumerate() {
190 if !tensor.tensors().is_empty() && tensor.legs().is_empty() {
192 let mut bb = BranchBound::new(
193 &tensor,
194 self.nbranch,
195 self.cutoff_flops_factor,
196 self.minimize,
197 );
198 bb.find_path();
199 nested_paths.insert(index, bb.get_best_path().clone());
200 tensor = tensor.external_tensor();
201 }
202 self.size_cache
203 .entry(index)
204 .or_insert_with(|| tensor.size());
205
206 self.tensor_cache.insert_new(index, tensor);
207 }
208 let remaining = (0..self.tn.tensors().len()).collect_vec();
209 self.branch_iterate(&[], &remaining, 0f64, 0f64);
210 self.best_path = ContractionPath {
211 nested: nested_paths,
212 toplevel: std::mem::take(&mut self.best_path).into_simple(),
213 };
214 }
215
216 fn get_best_flops(&self) -> f64 {
217 self.best_flops
218 }
219
220 fn get_best_size(&self) -> f64 {
221 self.best_size
222 }
223
224 fn get_best_path(&self) -> &ContractionPath {
225 &self.best_path
226 }
227
228 fn get_best_replace_path(&self) -> ContractionPath {
229 ssa_replace_ordering(&self.best_path)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::contractionpath::paths::{CostType, FindPath};
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Three-leaf network with small bond dimensions.
    fn setup_simple() -> Tensor {
        let bond_dims: FxHashMap<_, _> = [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
            .iter()
            .copied()
            .collect();
        let first = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let second = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
        let third = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
        Tensor::new_composite(vec![first, second, third])
    }

    /// Six-leaf network with larger, mostly distinct bond dimensions.
    fn setup_complex() -> Tensor {
        let bond_dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .iter()
        .copied()
        .collect();
        let leaves = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ];
        Tensor::new_composite(leaves)
    }

    /// Runs an exhaustive (`nbranch = None`) flop-minimizing search on `tn`.
    fn optimized(tn: &Tensor) -> BranchBound<'_> {
        let mut finder = BranchBound::new(tn, None, 20., CostType::Flops);
        finder.find_path();
        finder
    }

    #[test]
    fn test_contract_order_simple() {
        let network = setup_simple();
        let finder = optimized(&network);

        assert_eq!(finder.best_flops, 4540.);
        assert_eq!(finder.best_size, 538.);
        assert_eq!(finder.get_best_path(), &path![(1, 0), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(1, 0), (2, 1)]);
    }

    #[test]
    fn test_contract_order_complex() {
        let network = setup_complex();
        let finder = optimized(&network);

        assert_eq!(finder.best_flops, 2654474.);
        assert_eq!(finder.best_size, 89478.);
        assert_eq!(finder.best_path, path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]);
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}
304}