// tnc/contractionpath/communication_schemes.rs

1use std::fmt;
2
3use itertools::Itertools;
4use rand::distr::Uniform;
5use rand::Rng;
6use rustc_hash::FxHashMap;
7
8use crate::contractionpath::contraction_cost::communication_path_cost;
9use crate::contractionpath::paths::cotengrust::{Cotengrust, OptMethod};
10use crate::contractionpath::paths::weighted_branchbound::WeightedBranchBound;
11use crate::contractionpath::paths::{CostType, FindPath};
12use crate::contractionpath::SimplePath;
13use crate::tensornetwork::partitioning::{communication_partitioning, PartitioningStrategy};
14use crate::tensornetwork::tensor::Tensor;
15
/// The scheme used to find a contraction path for the final fan-in of tensors
/// between MPI ranks.
#[derive(Debug, Copy, Clone)]
pub enum CommunicationScheme {
    /// Uses Greedy scheme to find contraction path for communication
    Greedy,
    /// Uses a randomized greedy approach
    RandomGreedy,
    /// Uses repeated bipartitioning to identify communication path
    Bipartition,
    /// Uses repeated bipartitioning over several randomly sampled imbalance
    /// factors, keeping the cheapest path found
    BipartitionSweep,
    /// Uses a filtered search that considers time to intermediate tensor
    WeightedBranchBound,
    /// Uses a filtered search (same search as `WeightedBranchBound`, but with
    /// all per-tensor latencies set to zero)
    BranchBound,
}
33
34impl fmt::Display for CommunicationScheme {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        let comm_str = match self {
37            Self::Greedy => "greedy",
38            Self::RandomGreedy => "random_greedy",
39            Self::Bipartition => "bipartition",
40            Self::BipartitionSweep => "bipartition_sweep",
41            Self::WeightedBranchBound => "weightedbranchbound",
42            Self::BranchBound => "branchbound",
43        };
44        f.write_str(comm_str)
45    }
46}
47
48impl CommunicationScheme {
49    pub(crate) fn communication_path<R>(
50        &self,
51        children_tensors: &[Tensor],
52        latency_map: &FxHashMap<usize, f64>,
53        rng: Option<&mut R>,
54    ) -> SimplePath
55    where
56        R: Rng,
57    {
58        match self {
59            Self::Greedy => greedy(children_tensors, latency_map),
60            Self::RandomGreedy => random_greedy(children_tensors),
61            Self::Bipartition => bipartition(children_tensors, latency_map),
62            Self::BipartitionSweep => {
63                let Some(rng) = rng else {
64                    panic!("BipartitionSweep requires a random number generator")
65                };
66                bipartition_sweep(children_tensors, latency_map, rng)
67            }
68
69            Self::WeightedBranchBound => weighted_branchbound(children_tensors, latency_map),
70            Self::BranchBound => branchbound(children_tensors),
71        }
72    }
73}
74
75fn greedy(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
76    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
77    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::Greedy);
78    opt.find_path();
79    opt.get_best_replace_path().into_simple()
80}
81
82fn bipartition(children_tensors: &[Tensor], _latency_map: &FxHashMap<usize, f64>) -> SimplePath {
83    let children_tensors = children_tensors.iter().cloned().enumerate().collect_vec();
84    let imbalance = 0.03;
85    tensor_bipartition(&children_tensors, imbalance)
86}
87
88fn bipartition_sweep<R>(
89    children_tensors: &[Tensor],
90    latency_map: &FxHashMap<usize, f64>,
91    rng: &mut R,
92) -> SimplePath
93where
94    R: Rng,
95{
96    let tensors = children_tensors.iter().cloned().enumerate().collect_vec();
97    let mut best_flops = f64::INFINITY;
98    let mut best_path = vec![];
99    let partition_latencies = latency_map
100        .iter()
101        .sorted_by_key(|(k, _)| **k)
102        .map(|(_, v)| *v)
103        .collect::<Vec<_>>();
104    for _ in 0..20 {
105        let imbalance = rng.sample(Uniform::new(0.01, 0.5).unwrap());
106        let path = tensor_bipartition(&tensors, imbalance);
107        let (flops, _) = communication_path_cost(
108            children_tensors,
109            &path,
110            true,
111            true,
112            Some(&partition_latencies),
113        );
114        if flops < best_flops {
115            best_flops = flops;
116            best_path = path;
117        }
118    }
119    best_path
120}
121
122fn weighted_branchbound(
123    children_tensors: &[Tensor],
124    latency_map: &FxHashMap<usize, f64>,
125) -> SimplePath {
126    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
127
128    let mut opt = WeightedBranchBound::new(
129        &communication_tensors,
130        Some(10),
131        5.,
132        latency_map.clone(),
133        CostType::Flops,
134    );
135    opt.find_path();
136    opt.get_best_replace_path().into_simple()
137}
138
139fn branchbound(children_tensors: &[Tensor]) -> SimplePath {
140    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
141    let latency_map = (0..children_tensors.len()).map(|i| (i, 0.0)).collect();
142
143    let mut opt = WeightedBranchBound::new(
144        &communication_tensors,
145        Some(10),
146        5.,
147        latency_map,
148        CostType::Flops,
149    );
150    opt.find_path();
151    opt.get_best_replace_path().into_simple()
152}
153
/// Uses recursive bipartitioning to identify a communication scheme for final tensors.
///
/// Returns the id chosen as the root of this subtree, the tensor resulting
/// from contracting the whole subtree, and the contraction pairs performed
/// so far (empty for a single-tensor leaf).
///
/// NOTE(review): an earlier version of this doc comment also promised a
/// "parallel contraction cost as f64" in the return value; no such value is
/// returned.
fn tensor_bipartition_recursive(
    children_tensor: &[(usize, Tensor)],
    imbalance: f64,
) -> (usize, Tensor, SimplePath) {
    // Always split into two parts; `min` is forwarded to the partitioner
    // (semantics defined by `communication_partitioning` — confirm there).
    let k = 2;
    let min = true;

    // Composite tensor contracts with a single leaf tensor
    if children_tensor.len() == 1 {
        return (
            children_tensor[0].0,
            children_tensor[0].1.clone(),
            Vec::new(),
        );
    }

    // Only occurs when there is a subset of 2 tensors
    if children_tensor.len() == 2 {
        // Always ensure that the larger tensor size is on the left.
        let (t1, t2) = if children_tensor[1].1.size() > children_tensor[0].1.size() {
            (children_tensor[1].0, children_tensor[0].0)
        } else {
            (children_tensor[0].0, children_tensor[1].0)
        };
        let tensor = &children_tensor[0].1 ^ &children_tensor[1].1;

        // The larger tensor's id (t1) labels the contracted result.
        return (t1, tensor, vec![(t1, t2)]);
    }

    let partitioning = communication_partitioning(
        children_tensor,
        k,
        imbalance,
        PartitioningStrategy::MinCut,
        min,
    );

    // Split tensors by partition label. This relies on `partition` calling
    // the predicate once per element in order, keeping `partition_iter`
    // aligned with `children_tensor`.
    let mut partition_iter = partitioning.iter();
    let (children_1, children_2): (Vec<_>, Vec<_>) = children_tensor
        .iter()
        .cloned()
        .partition(|_| partition_iter.next() == Some(&0));

    // Recurse into both halves independently.
    let (id_1, t1, mut contraction_1) = tensor_bipartition_recursive(&children_1, imbalance);

    let (id_2, t2, mut contraction_2) = tensor_bipartition_recursive(&children_2, imbalance);

    let tensor = &t1 ^ &t2;

    // Concatenate both subtrees' paths, then append the contraction joining
    // them, keeping the larger intermediate on the left (as in the 2-tensor
    // base case).
    contraction_1.append(&mut contraction_2);
    let (id_1, id_2) = if t2.size() > t1.size() {
        (id_2, id_1)
    } else {
        (id_1, id_2)
    };

    contraction_1.push((id_1, id_2));
    (id_1, tensor, contraction_1)
}
215
216/// Repeatedly bipartitions tensor network to obtain communication scheme
217/// Assumes that all tensors contracted do so in parallel
218fn tensor_bipartition(children_tensor: &[(usize, Tensor)], imbalance: f64) -> SimplePath {
219    let (_, _, contraction_path) = tensor_bipartition_recursive(children_tensor, imbalance);
220    contraction_path
221}
222
223fn random_greedy(children_tensors: &[Tensor]) -> SimplePath {
224    let communication_tensors = Tensor::new_composite(children_tensors.to_vec());
225
226    let mut opt = Cotengrust::new(&communication_tensors, OptMethod::RandomGreedy(100));
227    opt.find_path();
228    opt.get_best_replace_path().into_simple()
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    use itertools::Itertools;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_cost::communication_path_cost, tensornetwork::tensor::Tensor,
    };

    /// Per-tensor latencies keyed by tensor id.
    fn setup_simple_partition_data() -> FxHashMap<usize, f64> {
        FxHashMap::from_iter([(0, 40.), (1, 30.), (2, 50.)])
    }

    /// Three-tensor fixture; tensor ids appear in the variable names for
    /// easy tracking.
    ///
    /// The greedy cost function prefers contracting tensor1 & tensor2 first,
    /// but tensor2 carries a high partition cost, making it more attractive
    /// to contract later and reduce wait time.
    fn setup_simple() -> Vec<Tensor> {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]);

        let tensor0 = Tensor::new_from_map(vec![3, 4, 5], &bond_dims);
        let tensor1 = Tensor::new_from_map(vec![0, 1, 3, 4], &bond_dims);
        let tensor2 = Tensor::new_from_map(vec![0, 1, 2, 5, 6], &bond_dims);
        vec![tensor0, tensor1, tensor2]
    }

    /// Shared helper: (flops, memory) cost of `path` under the fixture latencies.
    fn evaluate(
        tensors: &[Tensor],
        path: &SimplePath,
        latencies: &FxHashMap<usize, f64>,
    ) -> (f64, f64) {
        let tensor_costs = (0..tensors.len()).map(|i| latencies[&i]).collect_vec();
        communication_path_cost(tensors, path, true, true, Some(&tensor_costs))
    }

    #[test]
    fn test_greedy_communication() {
        let tensors = setup_simple();
        let latencies = setup_simple_partition_data();

        let path = greedy(&tensors, &latencies);
        assert_eq!(&path, &[(0, 1), (0, 2)]);

        let (flop_cost, mem_cost) = evaluate(&tensors, &path, &latencies);
        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_weighted_communication() {
        let tensors = setup_simple();
        let latencies = setup_simple_partition_data();

        let path = weighted_branchbound(&tensors, &latencies);
        assert_eq!(&path, &[(1, 0), (2, 1)]);

        // Flop cost: (1, 0) = 32, tensor cost 40 => 72;
        //            (2, 1) = 32, tensor cost 50 => max(72, 50) + 32 = 104.
        // Mem cost:  (2, 1) = 2^3 + 2^5 + 2^2 = 44.
        let (flop_cost, mem_cost) = evaluate(&tensors, &path, &latencies);
        assert_eq!(flop_cost, 104.);
        assert_eq!(mem_cost, 44.);
    }

    #[test]
    fn test_bi_partition_communication() {
        let tensors = setup_simple();
        let latencies = setup_simple_partition_data();

        let path = bipartition(&tensors, &latencies);
        assert_eq!(&path, &[(2, 1), (2, 0)]);

        // Flop cost: (2, 1) = 128, tensor cost 50 => 178;
        //            (2, 0) = 32, tensor cost 40 => max(178, 40) + 32 = 210.
        // Mem cost:  (2, 1) = 2^4 + 2^5 + 2^5 = 80.
        let (flop_cost, mem_cost) = evaluate(&tensors, &path, &latencies);
        assert_eq!(flop_cost, 210.);
        assert_eq!(mem_cost, 80.);
    }
}