1use std::fmt;
2
3use itertools::Itertools;
4use rand::distr::Uniform;
5use rand::Rng;
6use rustc_hash::FxHashMap;
7
8use crate::contractionpath::contraction_cost::communication_path_cost;
9use crate::contractionpath::paths::cotengrust::{Cotengrust, OptMethod};
10use crate::contractionpath::paths::weighted_branchbound::WeightedBranchBound;
11use crate::contractionpath::paths::{CostType, FindPath};
12use crate::contractionpath::SimplePath;
13use crate::tensornetwork::partitioning::{communication_partitioning, PartitioningStrategy};
14use crate::tensornetwork::tensor::Tensor;
15
/// Strategy used to find a contraction path for the partition-level
/// (communication) tensors.
#[derive(Debug, Copy, Clone)]
pub enum CommunicationScheme {
    /// Deterministic greedy search via `Cotengrust`; ignores latencies.
    Greedy,
    /// Randomized greedy search via `Cotengrust` (100 repetitions);
    /// ignores latencies.
    RandomGreedy,
    /// Single recursive min-cut bipartition with a fixed 3% imbalance.
    Bipartition,
    /// Repeated recursive bipartitions with randomly sampled imbalance
    /// factors, keeping the cheapest path; requires an RNG.
    BipartitionSweep,
    /// Branch-and-bound search weighted by the per-partition latency map.
    WeightedBranchBound,
    /// Branch-and-bound search with all latencies fixed to zero.
    BranchBound,
}
33
34impl fmt::Display for CommunicationScheme {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 let comm_str = match self {
37 Self::Greedy => "greedy",
38 Self::RandomGreedy => "random_greedy",
39 Self::Bipartition => "bipartition",
40 Self::BipartitionSweep => "bipartition_sweep",
41 Self::WeightedBranchBound => "weightedbranchbound",
42 Self::BranchBound => "branchbound",
43 };
44 f.write_str(comm_str)
45 }
46}
47
impl CommunicationScheme {
    /// Computes a contraction path over `children_tensors` using the
    /// scheme selected by `self`.
    ///
    /// * `children_tensors` - the partition-level tensors to contract.
    /// * `latency_map` - latency weight per tensor index; only consulted
    ///   by the latency-aware schemes (`Bipartition*`, `WeightedBranchBound`).
    /// * `rng` - randomness source; required only by `BipartitionSweep`.
    ///
    /// # Panics
    /// Panics if `self` is `BipartitionSweep` and `rng` is `None`.
    pub(crate) fn communication_path<R>(
        &self,
        children_tensors: &[Tensor],
        latency_map: &FxHashMap<usize, f64>,
        rng: Option<&mut R>,
    ) -> SimplePath
    where
        R: Rng,
    {
        match self {
            // `greedy` and `bipartition` accept the latency map for a
            // uniform signature but ignore it internally.
            Self::Greedy => greedy(children_tensors, latency_map),
            Self::RandomGreedy => random_greedy(children_tensors),
            Self::Bipartition => bipartition(children_tensors, latency_map),
            Self::BipartitionSweep => {
                let Some(rng) = rng else {
                    panic!("BipartitionSweep requires a random number generator")
                };
                bipartition_sweep(children_tensors, latency_map, rng)
            }

            Self::WeightedBranchBound => weighted_branchbound(children_tensors, latency_map),
            Self::BranchBound => branchbound(children_tensors),
        }
    }
}
74
75fn greedy(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
76 let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
77 let mut opt = Cotengrust::new(&communication_tensors, OptMethod::Greedy);
78 opt.find_path();
79 opt.get_best_replace_path().into_simple()
80}
81
82fn bipartition(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
83 let children_tensors = children_tensors.iter().cloned().enumerate().collect_vec();
84 let imbalance = 0.03;
85 tensor_bipartition(&children_tensors, imbalance)
86}
87
88fn bipartition_sweep<R>(
89 children_tensors: &[Tensor],
90 latency_map: &FxHashMap<usize, f64>,
91 rng: &mut R,
92) -> SimplePath
93where
94 R: Rng,
95{
96 let tensors = children_tensors.iter().cloned().enumerate().collect_vec();
97 let mut best_flops = f64::INFINITY;
98 let mut best_path = vec![];
99 let partition_latencies = latency_map
100 .iter()
101 .sorted_by_key(|(k, _)| **k)
102 .map(|(_, v)| *v)
103 .collect::<Vec<_>>();
104 for _ in 0..20 {
105 let imbalance = rng.sample(Uniform::new(0.01, 0.5).unwrap());
106 let path = tensor_bipartition(&tensors, imbalance);
107 let (flops, _) = communication_path_cost(
108 children_tensors,
109 &path,
110 true,
111 true,
112 Some(&partition_latencies),
113 );
114 if flops < best_flops {
115 best_flops = flops;
116 best_path = path;
117 }
118 }
119 best_path
120}
121
122fn weighted_branchbound(
123 children_tensors: &[Tensor],
124 latency_map: &FxHashMap<usize, f64>,
125) -> SimplePath {
126 let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
127
128 let mut opt = WeightedBranchBound::new(
129 &communication_tensors,
130 Some(10),
131 5.,
132 latency_map.clone(),
133 CostType::Flops,
134 );
135 opt.find_path();
136 opt.get_best_replace_path().into_simple()
137}
138
139fn branchbound(children_tensors: &[Tensor]) -> SimplePath {
140 let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
141 let latency_map = (0..children_tensors.len()).map(|i| (i, 0.0)).collect();
142
143 let mut opt = WeightedBranchBound::new(
144 &communication_tensors,
145 Some(10),
146 5.,
147 latency_map,
148 CostType::Flops,
149 );
150 opt.find_path();
151 opt.get_best_replace_path().into_simple()
152}
153
/// Recursively bipartitions `children_tensor` (pairs of original tensor
/// index and tensor) and builds a contraction path bottom-up.
///
/// Returns `(id, tensor, path)` where `tensor` is the contraction of the
/// whole group, `path` lists the pairwise contractions performed, and `id`
/// is the index labeling the result: at every merge the operand with the
/// larger `size()` contributes its id, so intermediates stay addressed by
/// their biggest constituent.
fn tensor_bipartition_recursive(
    children_tensor: &[(usize, Tensor)],
    imbalance: f64,
) -> (usize, Tensor, SimplePath) {
    // Always split into two parts, taking the minimum cut.
    let k = 2;
    let min = true;

    // Base case: a single tensor needs no contraction.
    if children_tensor.len() == 1 {
        return (
            children_tensor[0].0,
            children_tensor[0].1.clone(),
            Vec::new(),
        );
    }

    // Base case: exactly two tensors contract directly; the larger one's
    // id labels the result (ties keep the first tensor's id).
    if children_tensor.len() == 2 {
        let (t1, t2) = if children_tensor[1].1.size() > children_tensor[0].1.size() {
            (children_tensor[1].0, children_tensor[0].0)
        } else {
            (children_tensor[0].0, children_tensor[1].0)
        };
        // `^` contracts the two tensors.
        let tensor = &children_tensor[0].1 ^ &children_tensor[1].1;

        return (t1, tensor, vec![(t1, t2)]);
    }

    // Split the group in two with a min-cut partitioner, allowing the
    // requested size imbalance between the halves.
    let partitioning = communication_partitioning(
        children_tensor,
        k,
        imbalance,
        PartitioningStrategy::MinCut,
        min,
    );

    // NOTE(review): this pairs each tensor with its partition label by
    // advancing `partition_iter` inside the predicate, relying on
    // `partition` calling the predicate once per element, in order —
    // confirm that invariant if this iterator chain is ever changed.
    let mut partition_iter = partitioning.iter();
    let (children_1, children_2): (Vec<_>, Vec<_>) = children_tensor
        .iter()
        .cloned()
        .partition(|_| partition_iter.next() == Some(&0));

    // Recurse into both halves, collecting their sub-paths.
    let (id_1, t1, mut contraction_1) = tensor_bipartition_recursive(&children_1, imbalance);

    let (id_2, t2, mut contraction_2) = tensor_bipartition_recursive(&children_2, imbalance);

    let tensor = &t1 ^ &t2;

    // Concatenate the halves' paths, then append the final merge; keep the
    // larger intermediate's id as the label of the merged result.
    contraction_1.append(&mut contraction_2);
    let (id_1, id_2) = if t2.size() > t1.size() {
        (id_2, id_1)
    } else {
        (id_1, id_2)
    };

    contraction_1.push((id_1, id_2));
    (id_1, tensor, contraction_1)
}
215
216fn tensor_bipartition(children_tensor: &[(usize, Tensor)], imbalance: f64) -> SimplePath {
219 let (_, _, contraction_path) = tensor_bipartition_recursive(children_tensor, imbalance);
220 contraction_path
221}
222
223fn random_greedy(children_tensors: &[Tensor]) -> SimplePath {
224 let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
225
226 let mut opt = Cotengrust::new(&communication_tensors, OptMethod::RandomGreedy(100));
227 opt.find_path();
228 opt.get_best_replace_path().into_simple()
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    use itertools::Itertools;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_cost::communication_path_cost, tensornetwork::tensor::Tensor,
    };

    // Latency per tensor index, used to weight communication costs in the
    // assertions below.
    fn setup_simple_partition_data() -> FxHashMap<usize, f64> {
        FxHashMap::from_iter([(0, 40.), (1, 30.), (2, 50.)])
    }

    // Three tensors over bonds 0..=6, every bond of dimension 2, with
    // overlapping legs so all pairwise contractions are meaningful.
    fn setup_simple() -> Vec<Tensor> {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]);

        let tensor0 = Tensor::new_from_map(vec![3, 4, 5], &bond_dims);
        let tensor1 = Tensor::new_from_map(vec![0, 1, 3, 4], &bond_dims);
        let tensor2 = Tensor::new_from_map(vec![0, 1, 2, 5, 6], &bond_dims);
        vec![tensor0, tensor1, tensor2]
    }

    #[test]
    fn test_greedy_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();
        let communication_scheme = greedy(&tensors, &latency_map);

        // Greedy contracts tensor 0 with 1, then the result with 2.
        assert_eq!(&communication_scheme, &[(0, 1), (0, 2)]);
        // Flatten the latency map into index order for the cost function.
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );
        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_weighted_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = weighted_branchbound(&tensors, &latency_map);

        // Different pair order than greedy, but the same cost totals below.
        assert_eq!(&communication_scheme, &[(1, 0), (2, 1)]);
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_bi_partition_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = bipartition(&tensors, &latency_map);

        // The min-cut bipartition yields a costlier path than greedy here.
        assert_eq!(&communication_scheme, &[(2, 1), (2, 0)]);

        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 210.);
        assert_eq!(mem_cost, 80.);
    }
}