1use std::collections::BinaryHeap;
2
3use itertools::Itertools;
4use rustc_hash::FxHashMap;
5
6use crate::{
7 contractionpath::{
8 candidates::Candidate,
9 contraction_cost::{contract_op_cost_tensors, contract_size_tensors},
10 paths::{CostType, FindPath},
11 ssa_ordering, ssa_replace_ordering, ContractionPath,
12 },
13 tensornetwork::tensor::Tensor,
14 utils::traits::HashMapInsertNew,
15};
16
17pub struct WeightedBranchBound<'a> {
19 tn: &'a Tensor,
20 nbranch: Option<usize>,
21 cutoff_flops_factor: f64,
22 minimize: CostType,
23 best_flops: f64,
24 best_size: f64,
25 best_path: ContractionPath,
26 best_progress: FxHashMap<usize, f64>,
27 largest_latency: f64,
28 result_cache: FxHashMap<(usize, usize), (usize, f64, f64)>,
29 comm_cache: FxHashMap<usize, f64>,
30 tensor_cache: FxHashMap<usize, Tensor>,
31}
32
33impl<'a> WeightedBranchBound<'a> {
34 pub fn new(
35 tn: &'a Tensor,
36 nbranch: Option<usize>,
37 cutoff_flops_factor: f64,
38 latency_map: FxHashMap<usize, f64>,
39 minimize: CostType,
40 ) -> Self {
41 Self {
42 tn,
43 nbranch,
44 cutoff_flops_factor,
45 minimize,
46 best_flops: f64::INFINITY,
47 best_size: f64::INFINITY,
48 best_path: ContractionPath::default(),
49 best_progress: FxHashMap::default(),
50 largest_latency: Default::default(),
51 result_cache: FxHashMap::default(),
52 comm_cache: latency_map,
53 tensor_cache: FxHashMap::default(),
54 }
55 }
56
57 fn assess_candidate(
58 &mut self,
59 mut i: usize,
60 mut j: usize,
61 size: f64,
62 remaining_len: usize,
63 ) -> Option<Candidate> {
64 if self.tensor_cache[&j].size() > self.tensor_cache[&i].size() {
65 (i, j) = (j, i);
66 }
67
68 let &mut (k12, flops_12, size_12) = self.result_cache.entry((i, j)).or_insert_with(|| {
69 let k12 = self.tensor_cache.len();
70 let flops_12 = contract_op_cost_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
71 let size_12 = contract_size_tensors(&self.tensor_cache[&i], &self.tensor_cache[&j]);
72 let k12_tensor = &self.tensor_cache[&i] ^ &self.tensor_cache[&j];
73 self.tensor_cache.insert_new(k12, k12_tensor);
74 (k12, flops_12, size_12)
75 });
76
77 let current_flops = if let Some(total_flops) = self.comm_cache.get(&k12) {
78 *total_flops
79 } else {
80 let total_flops = flops_12 + self.comm_cache[&i].max(self.comm_cache[&j]);
81 self.comm_cache.insert(k12, total_flops);
82 total_flops
83 };
84 let current_size = size.max(size_12);
85
86 if current_flops > self.best_flops && current_size > self.best_size {
87 return None;
88 }
89 let best_flops = *self
90 .best_progress
91 .entry(remaining_len)
92 .or_insert(current_flops);
93
94 if current_flops < best_flops {
95 self.best_progress.insert(remaining_len, current_flops);
96 } else if current_flops > (self.cutoff_flops_factor * best_flops + self.largest_latency) {
97 return None;
98 }
99
100 Some(Candidate {
101 flop_cost: current_flops,
102 size_cost: current_size,
103 parent_ids: (i, j),
104 child_id: k12,
105 })
106 }
107
108 fn branch_iterate(
113 &mut self,
114 path: &[(usize, usize, usize)],
115 remaining: &[usize],
116 flops: f64,
117 size: f64,
118 ) {
119 if remaining.len() == 1 {
120 match self.minimize {
121 CostType::Flops => {
122 if self.best_flops > flops {
123 self.best_flops = flops;
124 self.best_size = size;
125 self.best_path = ssa_ordering(path, self.tn.tensors().len());
126 }
127 }
128 CostType::Size => {
129 if self.best_size > size {
130 self.best_flops = flops;
131 self.best_size = size;
132 self.best_path = ssa_ordering(path, self.tn.tensors().len());
133 }
134 }
135 }
136 return;
137 }
138
139 let mut candidates = BinaryHeap::with_capacity(remaining.len() * (remaining.len() - 1) / 2);
140 for pair in remaining.iter().copied().combinations(2) {
141 let candidate = self.assess_candidate(pair[0], pair[1], size, remaining.len());
142 if let Some(new_candidate) = candidate {
143 candidates.push(new_candidate);
144 }
145 }
146 let mut candidates = candidates.into_sorted_vec();
147 if let Some(limit) = self.nbranch {
148 candidates.truncate(limit);
149 }
150
151 let mut new_path = Vec::with_capacity(path.len() + 1);
152 new_path.extend_from_slice(path);
153
154 for candidate in candidates.into_iter().rev() {
155 let Candidate {
156 flop_cost,
157 size_cost,
158 parent_ids,
159 child_id,
160 } = candidate;
161 let mut new_remaining = remaining.to_vec();
162 new_remaining.retain(|e| *e != parent_ids.0 && *e != parent_ids.1);
163 new_remaining.push(child_id);
164 new_path.push((parent_ids.0, parent_ids.1, child_id));
165 self.branch_iterate(&new_path, &new_remaining, flop_cost, size_cost);
166 new_path.pop();
167 }
168 }
169}
170
171impl FindPath for WeightedBranchBound<'_> {
172 fn find_path(&mut self) {
173 if self.tn.is_leaf() {
174 return;
175 }
176 let tensors = self.tn.tensors().clone();
177 self.result_cache.clear();
178 self.tensor_cache.clear();
179 self.largest_latency = *self
180 .comm_cache
181 .iter()
182 .max_by(|a, b| a.1.partial_cmp(b.1).expect("Tried to compare NaN"))
183 .unwrap()
184 .1;
185 let mut nested_paths = FxHashMap::default();
186 for (index, mut tensor) in tensors.into_iter().enumerate() {
188 if tensor.is_composite() && tensor.legs().is_empty() {
190 let mut bb = WeightedBranchBound::new(
191 &tensor,
192 self.nbranch,
193 self.cutoff_flops_factor,
194 self.comm_cache.clone(),
195 self.minimize,
196 );
197 bb.find_path();
198 nested_paths.insert(index, bb.get_best_path().clone());
199 tensor = tensor.external_tensor();
200 }
201 self.tensor_cache.insert_new(index, tensor);
202 }
203 let remaining = (0..self.tn.tensors().len()).collect_vec();
204 self.branch_iterate(&[], &remaining, 0f64, 0f64);
205 self.best_path = ContractionPath {
206 nested: nested_paths,
207 toplevel: std::mem::take(&mut self.best_path).into_simple(),
208 };
209 }
210
211 fn get_best_flops(&self) -> f64 {
212 self.best_flops
213 }
214
215 fn get_best_size(&self) -> f64 {
216 self.best_size
217 }
218
219 fn get_best_path(&self) -> &ContractionPath {
220 &self.best_path
221 }
222
223 fn get_best_replace_path(&self) -> ContractionPath {
224 ssa_replace_ordering(&self.best_path)
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::contractionpath::paths::CostType;
    use crate::contractionpath::paths::FindPath;
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Three-tensor network together with one latency value per tensor.
    fn setup_simple() -> (Tensor, FxHashMap<usize, f64>) {
        let bond_dims: FxHashMap<_, _> = [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
            .into_iter()
            .collect();
        let network = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ]);
        let latencies: FxHashMap<usize, f64> =
            [(0, 20.), (1, 40.), (2, 85.)].into_iter().collect();
        (network, latencies)
    }

    /// Six-tensor network together with one latency value per tensor.
    fn setup_complex() -> (Tensor, FxHashMap<usize, f64>) {
        let bond_dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();
        let network = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ]);
        let latencies: FxHashMap<usize, f64> =
            [(0, 120.), (1, 0.), (2, 15.), (3, 15.), (4, 85.), (5, 15.)]
                .into_iter()
                .collect();
        (network, latencies)
    }

    #[test]
    fn test_contract_order_simple() {
        let (network, latencies) = setup_simple();
        let mut search = WeightedBranchBound::new(&network, None, 20., latencies, CostType::Flops);
        search.find_path();

        assert_eq!(search.best_flops, 640.);
        assert_eq!(search.best_size, 538.);
        assert_eq!(search.get_best_path(), &path![(1, 0), (2, 3)]);
        assert_eq!(search.get_best_replace_path(), path![(1, 0), (2, 1)]);
    }

    #[test]
    fn test_contract_order_complex() {
        let (network, latencies) = setup_complex();
        let mut search = WeightedBranchBound::new(&network, None, 20., latencies, CostType::Flops);
        search.find_path();

        assert_eq!(search.best_flops, 265230.);
        assert_eq!(search.best_size, 89478.);
        assert_eq!(
            search.best_path,
            path![(3, 4), (2, 6), (1, 5), (0, 8), (7, 9)]
        );
        assert_eq!(
            search.get_best_replace_path(),
            path![(3, 4), (2, 3), (1, 5), (0, 1), (2, 0)]
        );
    }
}