tnc/contractionpath/communication_schemes.rs

use std::fmt;

use itertools::Itertools;
use rand::distr::Uniform;
use rand::Rng;
use rustc_hash::FxHashMap;

use crate::contractionpath::contraction_cost::communication_path_cost;
use crate::contractionpath::paths::cotengrust::{Cotengrust, OptMethod};
use crate::contractionpath::paths::weighted_branchbound::WeightedBranchBound;
use crate::contractionpath::paths::{CostType, FindPath};
use crate::contractionpath::SimplePath;
use crate::tensornetwork::partitioning::communication_partitioning;
use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
use crate::tensornetwork::tensor::Tensor;

/// The scheme used to find a contraction path for the final fan-in of tensors
/// between MPI ranks.
#[derive(Debug, Copy, Clone)]
pub enum CommunicationScheme {
    /// Uses a greedy scheme to find the contraction path for communication
    Greedy,
    /// Uses a randomized greedy approach
    RandomGreedy,
    /// Uses repeated bipartitioning to identify the communication path
    Bipartition,
    /// Uses repeated bipartitioning, sweeping over random imbalance values and
    /// keeping the cheapest path found
    BipartitionSweep,
    /// Uses a filtered branch-and-bound search that considers the time to reach
    /// each intermediate tensor
    WeightedBranchBound,
    /// Uses a filtered branch-and-bound search with all latencies set to zero
    BranchBound,
}

impl fmt::Display for CommunicationScheme {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let comm_str = match self {
            Self::Greedy => "greedy",
            Self::RandomGreedy => "random_greedy",
            Self::Bipartition => "bipartition",
            Self::BipartitionSweep => "bipartition_sweep",
            Self::WeightedBranchBound => "weightedbranchbound",
            Self::BranchBound => "branchbound",
        };
        f.write_str(comm_str)
    }
}

impl CommunicationScheme {
    pub(crate) fn communication_path<R>(
        &self,
        children_tensors: &[Tensor],
        latency_map: &FxHashMap<usize, f64>,
        rng: Option<&mut R>,
    ) -> SimplePath
    where
        R: Rng,
    {
        match self {
            Self::Greedy => greedy(children_tensors, latency_map),
            Self::RandomGreedy => random_greedy(children_tensors),
            Self::Bipartition => bipartition(children_tensors, latency_map),
            Self::BipartitionSweep => {
                let Some(rng) = rng else {
                    panic!("BipartitionSweep requires a random number generator")
                };
                bipartition_sweep(children_tensors, latency_map, rng)
            }
            Self::WeightedBranchBound => weighted_branchbound(children_tensors, latency_map),
            Self::BranchBound => branchbound(children_tensors),
        }
    }
}
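
// Illustrative usage sketch (not exercised elsewhere in this module): a caller holding
// the per-rank result tensors and a latency map would dispatch through the enum, passing
// a seeded RNG for the randomized schemes. `tensors`, `latencies`, and the seed below are
// placeholders.
//
//     use rand::SeedableRng;
//     let scheme = CommunicationScheme::BipartitionSweep;
//     let mut rng = rand::rngs::StdRng::seed_from_u64(0);
//     let path = scheme.communication_path(&tensors, &latencies, Some(&mut rng));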

fn greedy(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::Greedy);
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

fn bipartition(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
    let children_tensors = children_tensors.iter().cloned().enumerate().collect_vec();
    let imbalance = 0.03;
    tensor_bipartition(&children_tensors, imbalance)
}

fn bipartition_sweep<R>(
    children_tensors: &[Tensor],
    latency_map: &FxHashMap<usize, f64>,
    rng: &mut R,
) -> SimplePath
where
    R: Rng,
{
    let tensors = children_tensors.iter().cloned().enumerate().collect_vec();
    let mut best_flops = f64::INFINITY;
    let mut best_path = vec![];
    let partition_latencies = latency_map
        .iter()
        .sorted_by_key(|(k, _)| **k)
        .map(|(_, v)| *v)
        .collect::<Vec<_>>();
    for _ in 0..20 {
        let imbalance = rng.sample(Uniform::new(0.01, 0.5).unwrap());
        let path = tensor_bipartition(&tensors, imbalance);
        let (flops, _) = communication_path_cost(
            children_tensors,
            &path,
            true,
            true,
            Some(&partition_latencies),
        );
        if flops < best_flops {
            best_flops = flops;
            best_path = path;
        }
    }
    best_path
}
122
fn weighted_branchbound(
    children_tensors: &[Tensor],
    latency_map: &FxHashMap<usize, f64>,
) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());

    let mut opt = WeightedBranchBound::new(
        &communication_tensors,
        Some(10),
        5.,
        latency_map.clone(),
        CostType::Flops,
    );
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

fn branchbound(children_tensors: &[Tensor]) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
    let latency_map = (0..children_tensors.len()).map(|i| (i, 0.0)).collect();

    let mut opt = WeightedBranchBound::new(
        &communication_tensors,
        Some(10),
        5.,
        latency_map,
        CostType::Flops,
    );
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

/// Uses recursive bipartitioning to identify a communication scheme for the final tensors.
/// Returns the root id of the subtree, the resultant tensor, and the contraction sequence so far.
fn tensor_bipartition_recursive(
    children_tensor: &[(usize, Tensor)],
    imbalance: f64,
) -> (usize, Tensor, SimplePath) {
    let k = 2;
    let min = true;

    // Base case: a single leaf tensor is returned as-is and contracted by the parent call
    if children_tensor.len() == 1 {
        return (
            children_tensor[0].0,
            children_tensor[0].1.clone(),
            Vec::new(),
        );
    }

    // Base case: exactly two tensors remain, so contract them directly
    if children_tensor.len() == 2 {
        // Always ensure that the larger tensor size is on the left.
        let (t1, t2) = if children_tensor[1].1.size() > children_tensor[0].1.size() {
            (children_tensor[1].0, children_tensor[0].0)
        } else {
            (children_tensor[0].0, children_tensor[1].0)
        };
        let tensor = &children_tensor[0].1 ^ &children_tensor[1].1;

        return (t1, tensor, vec![(t1, t2)]);
    }

    let partitioning = communication_partitioning(
        children_tensor,
        k,
        imbalance,
        PartitioningStrategy::MinCut,
        min,
    );

    let mut partition_iter = partitioning.iter();
    let (children_1, children_2): (Vec<_>, Vec<_>) = children_tensor
        .iter()
        .cloned()
        .partition(|_| partition_iter.next() == Some(&0));

    let (id_1, t1, mut contraction_1) = tensor_bipartition_recursive(&children_1, imbalance);

    let (id_2, t2, mut contraction_2) = tensor_bipartition_recursive(&children_2, imbalance);

    let tensor = &t1 ^ &t2;

    contraction_1.append(&mut contraction_2);
    let (id_1, id_2) = if t2.size() > t1.size() {
        (id_2, id_1)
    } else {
        (id_1, id_2)
    };

    contraction_1.push((id_1, id_2));
    (id_1, tensor, contraction_1)
}

/// Repeatedly bipartitions the tensor network to obtain a communication scheme.
/// Assumes that independent contractions are performed in parallel.
fn tensor_bipartition(children_tensor: &[(usize, Tensor)], imbalance: f64) -> SimplePath {
    let (_, _, contraction_path) = tensor_bipartition_recursive(children_tensor, imbalance);
    contraction_path
}
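
// Worked example (mirrors `test_bi_partition_communication` below): for three tensors
// with ids {0, 1, 2}, one bipartition step may split them into {1, 2} and {0}. The
// {1, 2} subtree contracts as (2, 1) with the larger tensor 2 keeping the root id, and
// the two subtree roots are then contracted as (2, 0), giving the path [(2, 1), (2, 0)].
// The exact split depends on the partitioner and the chosen imbalance.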

fn random_greedy(children_tensors: &[Tensor]) -> SimplePath {
    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());

    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::RandomGreedy(100));
    opt.find_path();
    opt.get_best_replace_path().into_simple()
}

#[cfg(test)]
mod tests {
    use super::*;

    use itertools::Itertools;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_cost::communication_path_cost, tensornetwork::tensor::Tensor,
    };

    fn setup_simple_partition_data() -> FxHashMap<usize, f64> {
        FxHashMap::from_iter([(0, 40.), (1, 30.), (2, 50.)])
    }

    /// Tensor ids in the contraction tree are included in the variable names for easy tracking.
    /// This example prioritizes contracting tensor1 and tensor2 under the greedy cost function.
    /// However, the partition cost of tensor2 is very high, which makes contracting it later
    /// more attractive by reducing wait time.
    fn setup_simple() -> Vec<Tensor> {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]);

        let tensor0 = Tensor::new_from_map(vec![3, 4, 5], &bond_dims);
        let tensor1 = Tensor::new_from_map(vec![0, 1, 3, 4], &bond_dims);
        let tensor2 = Tensor::new_from_map(vec![0, 1, 2, 5, 6], &bond_dims);
        vec![tensor0, tensor1, tensor2]
    }

    #[test]
    fn test_greedy_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();
        let communication_scheme = greedy(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(0, 1), (0, 2)]);
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );
        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_weighted_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = weighted_branchbound(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(1, 0), (2, 1)]);
        // Flop Cost: (1, 0) = 32, Tensor cost = 40, Total = 72
        // Flop Cost: (2, 1) = 32, Tensor cost = 50
        // max(72, 50) + 32 = 104
        // Mem Cost: (2, 1) = 2^3 + 2^5 + 2^2 = 44
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }
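
    #[test]
    fn test_branchbound_communication() {
        // Hedged sketch of a test for the zero-latency variant: `branchbound` ignores
        // partition latencies, so the exact path it picks is not asserted here. We only
        // check structural properties: a fan-in of n tensors needs n - 1 pairwise
        // contractions, and the path has a finite, positive flop cost under zero tensor costs.
        let tensors = setup_simple();

        let communication_scheme = branchbound(&tensors);

        assert_eq!(communication_scheme.len(), tensors.len() - 1);
        let tensor_costs = vec![0.0; tensors.len()];
        let (flop_cost, _mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );
        assert!(flop_cost.is_finite() && flop_cost > 0.);
    }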

    #[test]
    fn test_bi_partition_communication() {
        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();

        let communication_scheme = bipartition(&tensors, &latency_map);

        assert_eq!(&communication_scheme, &[(2, 1), (2, 0)]);

        // Flop Cost: (2, 1) = 128, Tensor cost = 50, Total = 178
        // Flop Cost: (2, 0) = 32, Tensor cost = 40
        // max(178, 40) + 32 = 210
        // Mem Cost: (2, 1) = 2^4 + 2^5 + 2^5 = 80
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );

        assert_eq!(flop_cost, 210.);
        assert_eq!(mem_cost, 80.);
    }
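
    #[test]
    fn test_random_greedy_communication() {
        // Hedged sketch: `random_greedy` is randomized internally, so rather than
        // asserting a specific path we only check that it returns a full fan-in
        // (n - 1 pairwise contractions) over valid tensor ids.
        let tensors = setup_simple();

        let communication_scheme = random_greedy(&tensors);

        assert_eq!(communication_scheme.len(), tensors.len() - 1);
        assert!(communication_scheme
            .iter()
            .all(|(a, b)| *a < tensors.len() && *b < tensors.len()));
    }

    #[test]
    fn test_bipartition_sweep_communication() {
        // Hedged sketch: the sweep samples random imbalances, so the chosen path may vary
        // with the seed. Assuming the partitioner behaves as in
        // `test_bi_partition_communication` for this small example, we only assert
        // structural properties and a finite, positive cost rather than a specific path.
        use rand::{rngs::StdRng, SeedableRng};

        let tensors = setup_simple();
        let latency_map = setup_simple_partition_data();
        let mut rng = StdRng::seed_from_u64(42);

        let communication_scheme = bipartition_sweep(&tensors, &latency_map, &mut rng);

        assert_eq!(communication_scheme.len(), tensors.len() - 1);
        let tensor_costs = (0..tensors.len()).map(|i| latency_map[&i]).collect_vec();
        let (flop_cost, _mem_cost) = communication_path_cost(
            &tensors,
            &communication_scheme,
            true,
            true,
            Some(&tensor_costs),
        );
        assert!(flop_cost.is_finite() && flop_cost > 0.);
    }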
}