use std::fmt;

use itertools::Itertools;
use rand::distr::Uniform;
use rand::Rng;
use rustc_hash::FxHashMap;

use crate::contractionpath::contraction_cost::communication_path_cost;
use crate::contractionpath::paths::cotengrust::{Cotengrust, OptMethod};
use crate::contractionpath::paths::weighted_branchbound::WeightedBranchBound;
use crate::contractionpath::paths::{CostType, FindPath};
use crate::contractionpath::SimplePath;
use crate::tensornetwork::partitioning::communication_partitioning;
use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
use crate::tensornetwork::tensor::Tensor;

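/// Strategy for ordering the contractions of tensors exchanged between
/// partitions (the communication step).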
#[derive(Debug, Copy, Clone)]
pub enum CommunicationScheme {
    /// Greedy path via cotengrust (`OptMethod::Greedy`).
    Greedy,
    /// Randomized greedy path via cotengrust (`OptMethod::RandomGreedy`).
    RandomGreedy,
    /// Recursive min-cut bipartitioning with a fixed imbalance of 0.03.
    Bipartition,
    /// Recursive bipartitioning swept over random imbalance values, keeping
    /// the cheapest path; requires a random number generator.
    BipartitionSweep,
    /// Branch-and-bound search weighted by per-partition latencies.
    WeightedBranchBound,
    /// Branch-and-bound search with all latencies set to zero.
    BranchBound,
}

impl fmt::Display for CommunicationScheme {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let comm_str = match self {
            Self::Greedy => "greedy",
            Self::RandomGreedy => "random_greedy",
            Self::Bipartition => "bipartition",
            Self::BipartitionSweep => "bipartition_sweep",
            Self::WeightedBranchBound => "weightedbranchbound",
            Self::BranchBound => "branchbound",
        };
        f.write_str(comm_str)
    }
}

impl CommunicationScheme {
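    /// Returns a contraction path over `children_tensors` using this scheme.
    ///
    /// `latency_map` assigns a latency weight to each partition; only the
    /// `BipartitionSweep` and `WeightedBranchBound` schemes consult it.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `BipartitionSweep` and `rng` is `None`.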
    pub(crate) fn communication_path<R>(
        &self,
        children_tensors: &[Tensor],
        latency_map: &FxHashMap<usize, f64>,
        rng: Option<&mut R>,
    ) -> SimplePath
    where
        R: Rng,
    {
        match self {
            Self::Greedy => greedy(children_tensors, latency_map),
            Self::RandomGreedy => random_greedy(children_tensors),
            Self::Bipartition => bipartition(children_tensors, latency_map),
            Self::BipartitionSweep => {
                let Some(rng) = rng else {
                    panic!("BipartitionSweep requires a random number generator")
                };
                bipartition_sweep(children_tensors, latency_map, rng)
            }
            Self::WeightedBranchBound => weighted_branchbound(children_tensors, latency_map),
            Self::BranchBound => branchbound(children_tensors),
        }
    }
}

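/// Greedy contraction order from cotengrust; the latency map is unused here
/// but kept so that all schemes share one signature.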
fn greedy(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::Greedy);
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

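/// Single recursive min-cut bipartitioning pass with a fixed imbalance of 3%.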
fn bipartition(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
    let children_tensors = children_tensors.iter().cloned().enumerate().collect_vec();
    let imbalance = 0.03;
    tensor_bipartition(&children_tensors, imbalance)
}

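/// Runs 20 bipartitioning passes with the imbalance sampled uniformly from
/// [0.01, 0.5) and returns the path with the lowest communication cost.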
fn bipartition_sweep<R>(
    children_tensors: &[Tensor],
    latency_map: &FxHashMap<usize, f64>,
    rng: &mut R,
) -> SimplePath
where
    R: Rng,
{
    let tensors = children_tensors.iter().cloned().enumerate().collect_vec();
    let mut best_flops = f64::INFINITY;
    let mut best_path = vec![];
    // Latencies ordered by partition index, as expected by the cost function.
    let partition_latencies = latency_map
        .iter()
        .sorted_by_key(|(k, _)| **k)
        .map(|(_, v)| *v)
        .collect::<Vec<_>>();
    for _ in 0..20 {
        let imbalance = rng.sample(Uniform::new(0.01, 0.5).unwrap());
        let path = tensor_bipartition(&tensors, imbalance);
        let (flops, _) = communication_path_cost(
            children_tensors,
            &path,
            true,
            true,
            Some(&partition_latencies),
        );
        if flops < best_flops {
            best_flops = flops;
            best_path = path;
        }
    }
    best_path
}

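/// Branch-and-bound search over contraction orders; the per-partition
/// latencies in `latency_map` are passed through as weights.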
fn weighted_branchbound(
    children_tensors: &[Tensor],
    latency_map: &FxHashMap<usize, f64>,
) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());

    let mut opt = WeightedBranchBound::new(
        &communication_tensors,
        Some(10),
        5.,
        latency_map.clone(),
        CostType::Flops,
    );
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

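/// Unweighted variant of [`weighted_branchbound`]: all latencies are set to
/// zero, so only the flop cost drives the search.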
fn branchbound(children_tensors: &[Tensor]) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
    let latency_map = (0..children_tensors.len()).map(|i| (i, 0.0)).collect();

    let mut opt = WeightedBranchBound::new(
        &communication_tensors,
        Some(10),
        5.,
        latency_map,
        CostType::Flops,
    );
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

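/// Recursively bipartitions `children_tensor` and contracts each half,
/// returning the id representing the result, the contracted tensor, and the
/// contraction path. At each pairwise contraction the id of the larger
/// operand is kept.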
fn tensor_bipartition_recursive(
    children_tensor: &[(usize, Tensor)],
    imbalance: f64,
) -> (usize, Tensor, SimplePath) {
    let k = 2;
    let min = true;

    // Base case: a single tensor needs no contraction.
    if children_tensor.len() == 1 {
        return (
            children_tensor[0].0,
            children_tensor[0].1.clone(),
            Vec::new(),
        );
    }

    // Base case: contract a pair directly, keeping the larger tensor's id.
    if children_tensor.len() == 2 {
        let (t1, t2) = if children_tensor[1].1.size() > children_tensor[0].1.size() {
            (children_tensor[1].0, children_tensor[0].0)
        } else {
            (children_tensor[0].0, children_tensor[1].0)
        };
        let tensor = &children_tensor[0].1 ^ &children_tensor[1].1;

        return (t1, tensor, vec![(t1, t2)]);
    }

    // Split the tensors into two halves with a min-cut partitioning.
    let partitioning = communication_partitioning(
        children_tensor,
        k,
        imbalance,
        PartitioningStrategy::MinCut,
        min,
    );

    let mut partition_iter = partitioning.iter();
    let (children_1, children_2): (Vec<_>, Vec<_>) = children_tensor
        .iter()
        .cloned()
        .partition(|_| partition_iter.next() == Some(&0));

    // Contract each half recursively, then contract the two results.
    let (id_1, t1, mut contraction_1) = tensor_bipartition_recursive(&children_1, imbalance);
    let (id_2, t2, mut contraction_2) = tensor_bipartition_recursive(&children_2, imbalance);

    let tensor = &t1 ^ &t2;

    contraction_1.append(&mut contraction_2);
    let (id_1, id_2) = if t2.size() > t1.size() {
        (id_2, id_1)
    } else {
        (id_1, id_2)
    };

    contraction_1.push((id_1, id_2));
    (id_1, tensor, contraction_1)
}

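/// Convenience wrapper around [`tensor_bipartition_recursive`] that returns
/// only the contraction path.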
fn tensor_bipartition(children_tensor: &[(usize, Tensor)], imbalance: f64) -> SimplePath {
    let (_, _, contraction_path) = tensor_bipartition_recursive(children_tensor, imbalance);
    contraction_path
}

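/// Randomized greedy contraction order via cotengrust
/// (`OptMethod::RandomGreedy(100)`).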
fn random_greedy(children_tensors: &[Tensor]) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());

    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::RandomGreedy(100));
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

#[cfg(test)]
mod tests {
    use super::*;

    use itertools::Itertools;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_cost::communication_path_cost, tensornetwork::tensor::Tensor,
    };

    fn setup_simple_partition_data() -> FxHashMap<usize, f64> {
        FxHashMap::from_iter([(0, 40.), (1, 30.), (2, 50.)])
    }

    fn setup_simple() -> Vec<Tensor> {
        // Three tensors sharing bonds 0..=6, all of dimension 2.
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]);

        let tensor0 = Tensor::new_from_map(vec![3, 4, 5], &bond_dims);
        let tensor1 = Tensor::new_from_map(vec![0, 1, 3, 4], &bond_dims);
        let tensor2 = Tensor::new_from_map(vec![0, 1, 2, 5, 6], &bond_dims);
        vec![tensor0, tensor1, tensor2]
    }

    #[test]
    fn test_greedy_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();
        let communication_scheme = greedy(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(0, 1), (0, 2)]);
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );
        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_weighted_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = weighted_branchbound(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(1, 0), (2, 1)]);
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_bi_partition_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = bipartition(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(2, 1), (2, 0)]);

        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 210.);
        assert_eq!(mem_cost, 80.);
    }
}