tnc/tensornetwork/partitioning.rs

//! Functionality to partition composite tensors (i.e., tensor networks) into
//! multiple tensor networks.
use std::iter::zip;

use itertools::Itertools;
use kahypar::{partition, KaHyParContext};
use rustc_hash::FxHashMap;

use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
use crate::tensornetwork::tensor::Tensor;

pub mod partition_config;

/// The scale factor for the log weights used in the partitioning.
/// This is used to convert the log weights to integer weights for KaHyPar.
/// The current value allows for bond dimensions up to 144497 before two consecutive
/// bond dimensions are rounded to the same weight. At the same time, each such weight
/// stays around 1.7e6, so on the order of a thousand of them can be summed before an
/// i32 overflow occurs.
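///
/// A quick numeric sanity check of the rounding bound (a plain-Rust doctest sketch,
/// independent of KaHyPar):
///
/// ```
/// let scaled = |d: f64| 1e5 * d.log2();
/// // Near the documented limit, two consecutive bond dimensions differ by less
/// // than one integer step after scaling, so truncation can merge their weights.
/// assert!(scaled(144_498.0) - scaled(144_497.0) < 1.0);
/// ```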
const LOG_SCALE_FACTOR: f64 = 1e5;

/// Partitions the input tensor network using the `KaHyPar` library.
///
/// Returns a `Vec<usize>` of length equal to the number of input tensors storing the final
/// partitioning: the `usize` associated with each tensor indicates the partition it is assigned to.
///
/// # Arguments
///
/// * `tensor_network` - [`Tensor`] to be partitioned
/// * `k` - number of partitions
/// * `partitioning_strategy` - The strategy to pass to `KaHyPar`
/// * `min` - if `true`, performs a min-cut partitioning of the tensor network; if `false`, a max-cut
///
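/// # Example
///
/// A minimal usage sketch (assumes the crate is named `tnc`, per the file path, and that a
/// KaHyPar installation is available; the three-tensor ring below is made up for illustration):
///
/// ```ignore
/// use rustc_hash::FxHashMap;
/// use tnc::tensornetwork::partitioning::find_partitioning;
/// use tnc::tensornetwork::partitioning::partition_config::PartitioningStrategy;
/// use tnc::tensornetwork::tensor::Tensor;
///
/// let bond_dims = FxHashMap::from_iter([(0, 2), (1, 4), (2, 8)]);
/// let tn = Tensor::new_composite(vec![
///     Tensor::new_from_map(vec![0, 1], &bond_dims),
///     Tensor::new_from_map(vec![1, 2], &bond_dims),
///     Tensor::new_from_map(vec![2, 0], &bond_dims),
/// ]);
/// // One partition id per input tensor.
/// let partitioning = find_partitioning(&tn, 2, PartitioningStrategy::MinCut, true);
/// assert_eq!(partitioning.len(), 3);
/// ```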
pub fn find_partitioning(
    tensor_network: &Tensor,
    k: i32,
    partitioning_strategy: PartitioningStrategy,
    min: bool,
) -> Vec<usize> {
    if k == 1 {
        return vec![0; tensor_network.tensors().len()];
    }

    let num_vertices = tensor_network.tensors().len() as u32;
    let mut context = KaHyParContext::new();
    partitioning_strategy.apply(&mut context);

    // KaHyPar minimizes the cut; negating the weights turns min-cut into max-cut.
    let sign = if min { 1 } else { -1 };

    // Fixed imbalance tolerance passed to KaHyPar.
    let imbalance = 0.03;
    let mut objective = 0;
    let mut hyperedge_weights = vec![];
    let mut hyperedge_indices = vec![0];
    let mut hyperedges = vec![];

    // Track, for every leg, the last tensor seen holding it. Each time a leg reappears
    // in another tensor, emit a two-pin hyperedge connecting it to the previous holder,
    // so shared legs become (chains of) weighted hyperedges.
    let mut edge_dict = FxHashMap::default();
    for (tensor_id, tensor) in tensor_network.tensors().iter().enumerate() {
        for (leg, dim) in tensor.edges() {
            edge_dict
                .entry(leg)
                .and_modify(|entry| {
                    hyperedges.push(*entry as u32);
                    hyperedges.push(tensor_id as u32);
                    hyperedge_indices.push(hyperedge_indices.last().unwrap() + 2);
                    // Use log weights, because KaHyPar minimizes the sum of weights while we need the product.
                    // Since it accepts only integer weights, we scale the log values before rounding.
                    let weight = LOG_SCALE_FACTOR * (*dim as f64).log2();
                    hyperedge_weights.push(sign * weight as i32);
                    *entry = tensor_id;
                })
                .or_insert(tensor_id);
        }
    }

    let mut partitioning = vec![-1; num_vertices as usize];
    partition(
        num_vertices,
        hyperedge_weights.len() as u32,
        imbalance,
        k,
        None,
        Some(hyperedge_weights),
        &hyperedge_indices,
        hyperedges.as_slice(),
        &mut objective,
        &mut context,
        &mut partitioning,
    );
    partitioning.iter().map(|e| *e as usize).collect()
}

/// Repeatedly partitions a tensor network to identify a communication scheme.
///
/// Returns a `Vec<usize>` of length equal to the number of input tensors, which acts as the
/// communication scheme.
///
/// # Arguments
///
/// * `tensors` - &[(usize, `Tensor`)] to be partitioned. Each tuple contains the intermediate contraction cost and intermediate tensor for communication.
/// * `k` - number of partitions
/// * `imbalance` - imbalance tolerance to pass to `KaHyPar`
/// * `partitioning_strategy` - The strategy to pass to `KaHyPar`
/// * `min` - if `true`, performs a min-cut partitioning of the tensor network; if `false`, a max-cut
///
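/// # Example
///
/// A minimal usage sketch (assumes the crate is named `tnc`, per the file path; the cost
/// values and the three-tensor ring are made up for illustration):
///
/// ```ignore
/// use rustc_hash::FxHashMap;
/// use tnc::tensornetwork::partitioning::communication_partitioning;
/// use tnc::tensornetwork::partitioning::partition_config::PartitioningStrategy;
/// use tnc::tensornetwork::tensor::Tensor;
///
/// let bond_dims = FxHashMap::from_iter([(0, 2), (1, 4), (2, 8)]);
/// // Pair each intermediate tensor with its intermediate contraction cost.
/// let tensors = vec![
///     (100, Tensor::new_from_map(vec![0, 1], &bond_dims)),
///     (200, Tensor::new_from_map(vec![1, 2], &bond_dims)),
///     (300, Tensor::new_from_map(vec![2, 0], &bond_dims)),
/// ];
/// let scheme = communication_partitioning(&tensors, 2, 0.03, PartitioningStrategy::MinCut, true);
/// assert_eq!(scheme.len(), tensors.len());
/// ```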
pub fn communication_partitioning(
    tensors: &[(usize, Tensor)],
    k: i32,
    imbalance: f64,
    partitioning_strategy: PartitioningStrategy,
    min: bool,
) -> Vec<usize> {
    assert!(k > 1, "Partitioning only valid for more than one process");
    let num_vertices = tensors.len() as u32;
    let mut context = KaHyParContext::new();
    partitioning_strategy.apply(&mut context);

    // KaHyPar minimizes the cut; negating the weights turns min-cut into max-cut.
    let sign = if min { 1 } else { -1 };

    let mut objective = 0;
    let mut hyperedge_weights = vec![];
    let mut hyperedge_indices = vec![0];
    let mut hyperedges = vec![];

    // Track, for every leg, the last tensor seen holding it. Each time a leg reappears
    // in another tensor, emit a two-pin hyperedge connecting it to the previous holder,
    // so shared legs become (chains of) weighted hyperedges.
    let mut edge_dict = FxHashMap::default();
    for (tensor_id, (_, tensor)) in tensors.iter().enumerate() {
        for (leg, dim) in tensor.edges() {
            edge_dict
                .entry(leg)
                .and_modify(|entry| {
                    hyperedges.push(*entry as u32);
                    hyperedges.push(tensor_id as u32);
                    hyperedge_indices.push(hyperedge_indices.last().unwrap() + 2);
                    // Use log weights, because KaHyPar minimizes the sum of weights while we need the product.
                    // Since it accepts only integer weights, we scale the log values before rounding.
                    let weight = LOG_SCALE_FACTOR * (*dim as f64).log2();
                    hyperedge_weights.push(sign * weight as i32);
                    *entry = tensor_id;
                })
                .or_insert(tensor_id);
        }
    }

    let mut partitioning = vec![-1; num_vertices as usize];
    partition(
        num_vertices,
        hyperedge_weights.len() as u32,
        imbalance,
        k,
        None,
        Some(hyperedge_weights),
        &hyperedge_indices,
        hyperedges.as_slice(),
        &mut objective,
        &mut context,
        &mut partitioning,
    );

    partitioning.iter().map(|e| *e as usize).collect()
}

/// Partitions the tensor network based on the `partitioning` vector that assigns
/// each tensor to a partition.
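///
/// # Example
///
/// A minimal sketch, given some composite tensor `tn` with three constituent tensors
/// (the setup is hypothetical and elided):
///
/// ```ignore
/// // Group the first and third tensors together; the second gets its own partition.
/// let partitioned = partition_tensor_network(tn, &[0, 1, 0]);
/// assert_eq!(partitioned.tensors().len(), 2);
/// ```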
pub fn partition_tensor_network(tn: Tensor, partitioning: &[usize]) -> Tensor {
    // Map the (possibly sparse) partition ids to dense indices 0..n, in order of
    // first appearance.
    let partition_ids = partitioning.iter().unique().copied().collect_vec();
    let partition_dict =
        zip(partition_ids.iter().copied(), 0..partition_ids.len()).collect::<FxHashMap<_, _>>();

    let mut partitions = vec![Tensor::default(); partition_ids.len()];
    for (partition_id, tensor) in zip(partitioning, tn.tensors) {
        partitions[partition_dict[partition_id]].push_tensor(tensor);
    }
    Tensor::new_composite(partitions)
}

#[cfg(test)]
mod tests {
    use super::*;

    use float_cmp::assert_approx_eq;
    use rustc_hash::FxHashMap;

    use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
    use crate::tensornetwork::tensor::{EdgeIndex, Tensor};

    fn setup_complex() -> (Tensor, FxHashMap<EdgeIndex, u64>) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        (
            Tensor::new_composite(vec![
                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
            ]),
            bond_dims,
        )
    }

    #[test]
    fn test_simple_partitioning() {
        let (tn, bond_dims) = setup_complex();
        let ref_tensor_1 = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
        ]);
        let ref_tensor_2 = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ]);
        let ref_tensor_3 = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ]);
        let partitioning = find_partitioning(&tn, 3, PartitioningStrategy::MinCut, true);
        assert_eq!(partitioning, [2, 1, 2, 0, 0, 1]);
        let partitioned_tn = partition_tensor_network(tn, partitioning.as_slice());
        assert_eq!(partitioned_tn.tensors().len(), 3);

        assert_approx_eq!(&Tensor, partitioned_tn.tensor(2), &ref_tensor_1);
        assert_approx_eq!(&Tensor, partitioned_tn.tensor(1), &ref_tensor_2);
        assert_approx_eq!(&Tensor, partitioned_tn.tensor(0), &ref_tensor_3);
    }
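
    // An extra check (sketch, no KaHyPar involved): hand-written partition ids should
    // regroup the tensors exactly as requested. Assumes composite sub-tensors expose
    // their members via `tensors()`, as the top-level composite does.
    #[test]
    fn test_manual_partitioning() {
        let (tn, _) = setup_complex();
        // First three tensors go to partition 0, the remaining three to partition 1.
        let partitioned_tn = partition_tensor_network(tn, &[0, 0, 0, 1, 1, 1]);
        assert_eq!(partitioned_tn.tensors().len(), 2);
        assert_eq!(partitioned_tn.tensor(0).tensors().len(), 3);
        assert_eq!(partitioned_tn.tensor(1).tensors().len(), 3);
    }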

    #[test]
    fn test_single_partition() {
        let (tn, _) = setup_complex();
        let partitioning = find_partitioning(&tn, 1, PartitioningStrategy::MinCut, true);
        assert_eq!(partitioning, [0, 0, 0, 0, 0, 0]);
    }
}
246}