1use std::iter::zip;
4
5use itertools::Itertools;
6use kahypar::{partition, KaHyParContext};
7use rustc_hash::FxHashMap;
8
9use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
10use crate::tensornetwork::tensor::Tensor;
11
12pub mod partition_config;
13
/// Multiplier applied to `log2(bond_dim)` before it is truncated to an `i32` hyperedge weight,
/// so that fractional log-dimensions still produce distinguishable integer weights.
const LOG_SCALE_FACTOR: f64 = 100_000.0;
20
21pub fn find_partitioning(
34 tensor_network: &Tensor,
35 k: i32,
36 partitioning_strategy: PartitioningStrategy,
37 min: bool,
38) -> Vec<usize> {
39 if k == 1 {
40 return vec![0; tensor_network.tensors().len()];
41 }
42
43 let num_vertices = tensor_network.tensors().len() as u32;
44 let mut context = KaHyParContext::new();
45 partitioning_strategy.apply(&mut context);
46
47 let x = if min { 1 } else { -1 };
48
49 let imbalance = 0.03;
50 let mut objective = 0;
51 let mut hyperedge_weights = vec![];
52 let mut hyperedge_indices = vec![0];
53 let mut hyperedges = vec![];
54
55 let mut edge_dict = FxHashMap::default();
56 for (tensor_id, tensor) in tensor_network.tensors().iter().enumerate() {
57 for (leg, dim) in tensor.edges() {
58 edge_dict
59 .entry(leg)
60 .and_modify(|entry| {
61 hyperedges.push(*entry as u32);
62 hyperedges.push(tensor_id as u32);
63 hyperedge_indices.push(hyperedge_indices.last().unwrap() + 2);
64 let weight = LOG_SCALE_FACTOR * (*dim as f64).log2();
67 hyperedge_weights.push(x * weight as i32);
68 *entry = tensor_id;
69 })
70 .or_insert(tensor_id);
71 }
72 }
73
74 let mut partitioning = vec![-1; num_vertices as usize];
75 partition(
76 num_vertices,
77 hyperedge_weights.len() as u32,
78 imbalance,
79 k,
80 None,
81 Some(hyperedge_weights),
82 &hyperedge_indices,
83 hyperedges.as_slice(),
84 &mut objective,
85 &mut context,
86 &mut partitioning,
87 );
88 partitioning.iter().map(|e| *e as usize).collect()
89}
90
91pub fn communication_partitioning(
102 tensors: &[(usize, Tensor)],
103 k: i32,
104 imbalance: f64,
105 partitioning_strategy: PartitioningStrategy,
106 min: bool,
107) -> Vec<usize> {
108 assert!(k > 1, "Partitioning only valid for more than one process");
109 let num_vertices = tensors.len() as u32;
110 let mut context = KaHyParContext::new();
111 partitioning_strategy.apply(&mut context);
112
113 let x = if min { 1 } else { -1 };
114
115 let mut objective = 0;
116 let mut hyperedge_weights = vec![];
117
118 let mut hyperedge_indices = vec![0];
119 let mut hyperedges = vec![];
120
121 let mut edge_dict = FxHashMap::default();
126 for (tensor_id, (_, tensor)) in tensors.iter().enumerate() {
127 for (leg, dim) in tensor.edges() {
128 edge_dict
129 .entry(leg)
130 .and_modify(|entry| {
131 hyperedges.push(*entry as u32);
132 hyperedges.push(tensor_id as u32);
133 hyperedge_indices.push(hyperedge_indices.last().unwrap() + 2);
134 let weight = LOG_SCALE_FACTOR * (*dim as f64).log2();
137 hyperedge_weights.push(x * weight as i32);
138 *entry = tensor_id;
139 })
140 .or_insert(tensor_id);
141 }
142 }
143
144 let mut partitioning = vec![-1; num_vertices as usize];
145 partition(
146 num_vertices,
147 hyperedge_weights.len() as u32,
148 imbalance,
149 k,
150 None,
151 Some(hyperedge_weights),
152 &hyperedge_indices,
153 hyperedges.as_slice(),
154 &mut objective,
155 &mut context,
156 &mut partitioning,
157 );
158
159 partitioning.iter().map(|e| *e as usize).collect()
161}
162
163pub fn partition_tensor_network(tn: Tensor, partitioning: &[usize]) -> Tensor {
166 let partition_ids = partitioning.iter().unique().copied().collect_vec();
167 let partition_dict =
168 zip(partition_ids.iter().copied(), 0..partition_ids.len()).collect::<FxHashMap<_, _>>();
169
170 let mut partitions = vec![Tensor::default(); partition_ids.len()];
171 for (partition_id, tensor) in zip(partitioning, tn.tensors) {
172 partitions[partition_dict[partition_id]].push_tensor(tensor);
173 }
174 Tensor::new_composite(partitions)
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use float_cmp::assert_approx_eq;
    use rustc_hash::FxHashMap;

    use crate::tensornetwork::partitioning::partition_config::PartitioningStrategy;
    use crate::tensornetwork::tensor::{EdgeIndex, Tensor};

    /// Builds a composite tensor with one child per leg list, looking up each leg's bond
    /// dimension in `dims`.
    fn composite_of(leg_sets: Vec<Vec<EdgeIndex>>, dims: &FxHashMap<EdgeIndex, u64>) -> Tensor {
        let children = leg_sets
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(children)
    }

    /// Six-tensor fixture network over legs 0..=11, returned with its bond dimensions.
    fn setup_complex() -> (Tensor, FxHashMap<EdgeIndex, u64>) {
        let dims: FxHashMap<EdgeIndex, u64> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();

        let network = composite_of(
            vec![
                vec![4, 3, 2],
                vec![0, 1, 3, 2],
                vec![4, 5, 6],
                vec![6, 8, 9],
                vec![10, 8, 9],
                vec![5, 1, 0],
            ],
            &dims,
        );
        (network, dims)
    }

    #[test]
    fn test_simple_partitioning() {
        let (tn, dims) = setup_complex();

        let partitioning = find_partitioning(&tn, 3, PartitioningStrategy::MinCut, true);
        assert_eq!(partitioning, [2, 1, 2, 0, 0, 1]);

        let partitioned_tn = partition_tensor_network(tn, &partitioning);
        assert_eq!(partitioned_tn.tensors().len(), 3);

        // Partitions are ordered by first appearance of their id (2, 1, 0), so slot 0 holds
        // block 2, slot 1 holds block 1 and slot 2 holds block 0.
        let expected = [
            (0, composite_of(vec![vec![4, 3, 2], vec![4, 5, 6]], &dims)),
            (1, composite_of(vec![vec![0, 1, 3, 2], vec![5, 1, 0]], &dims)),
            (2, composite_of(vec![vec![6, 8, 9], vec![10, 8, 9]], &dims)),
        ];
        for (slot, reference) in expected {
            assert_approx_eq!(&Tensor, partitioned_tn.tensor(slot), &reference);
        }
    }

    #[test]
    fn test_single_partition() {
        let (tn, _) = setup_complex();
        let partitioning = find_partitioning(&tn, 1, PartitioningStrategy::MinCut, true);
        assert_eq!(partitioning, [0; 6]);
    }
}