1use num_complex::Complex64;
5
6use crate::{
7 contractionpath::{ContractionPath, SimplePathRef},
8 tensornetwork::tensor::{EdgeIndex, Tensor},
9};
10
11pub fn contract_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
27 let final_dims = t_1 ^ t_2;
28 let shared_dims = t_1 & t_2;
29
30 let single_loop_cost = shared_dims.size();
31 (single_loop_cost - 1f64).mul_add(2f64, single_loop_cost * 6f64) * final_dims.size()
32}
33
34#[inline]
50pub fn contract_op_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
51 let all_dims = t_1 | t_2;
52 all_dims.size()
53}
54
55#[inline]
71pub fn contract_size_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
72 let diff = t_1 ^ t_2;
73 diff.size() + t_1.size() + t_2.size()
74}
75
76pub fn contract_size_tensors_exact(i: &Tensor, j: &Tensor) -> f64 {
96 #[inline]
98 fn is_prefix(prefix: &[EdgeIndex], list: &[EdgeIndex]) -> bool {
99 if prefix.len() > list.len() {
100 return false;
101 }
102 list.iter().zip(prefix.iter()).all(|(a, b)| a == b)
103 }
104
105 #[inline]
107 fn is_suffix(suffix: &[EdgeIndex], list: &[EdgeIndex]) -> bool {
108 if suffix.len() > list.len() {
109 return false;
110 }
111 list.iter()
112 .rev()
113 .zip(suffix.iter().rev())
114 .all(|(a, b)| a == b)
115 }
116
117 let ij = i ^ j;
118 let contracted_legs = i & j;
119 let i_needs_transpose = !is_suffix(contracted_legs.legs(), i.legs());
120 let j_needs_transpose = !is_prefix(contracted_legs.legs(), j.legs());
121
122 let i_size = i.size();
123 let j_size = j.size();
124 let ij_size = ij.size();
125
126 let elements = match (i_needs_transpose, j_needs_transpose) {
127 (true, true) => (2.0 * i_size + j_size)
128 .max(i_size + 2.0 * j_size)
129 .max(i_size + j_size + ij_size),
130 (true, false) => (2.0 * i_size + j_size).max(i_size + j_size + ij_size),
131 (false, true) => (i_size + 2.0 * j_size).max(i_size + j_size + ij_size),
132 (false, false) => i_size + j_size + ij_size,
133 };
134
135 elements * std::mem::size_of::<Complex64>() as f64
136}
137
138#[inline]
146pub fn contract_path_cost(
147 inputs: &[Tensor],
148 contract_path: &ContractionPath,
149 only_count_ops: bool,
150) -> (f64, f64) {
151 let cost_function = if only_count_ops {
152 contract_op_cost_tensors
153 } else {
154 contract_cost_tensors
155 };
156 contract_path_custom_cost(inputs, contract_path, cost_function, contract_size_tensors)
157}
158
159fn contract_path_custom_cost(
167 inputs: &[Tensor],
168 contract_path: &ContractionPath,
169 cost_function: fn(&Tensor, &Tensor) -> f64,
170 size_function: fn(&Tensor, &Tensor) -> f64,
171) -> (f64, f64) {
172 let mut op_cost = 0f64;
173 let mut mem_cost = 0f64;
174 let mut inputs = inputs.to_vec();
175
176 for (i, path) in &contract_path.nested {
177 let costs =
178 contract_path_custom_cost(inputs[*i].tensors(), path, cost_function, size_function);
179 op_cost += costs.0;
180 mem_cost = mem_cost.max(costs.1);
181 inputs[*i] = inputs[*i].external_tensor();
182 }
183
184 for &(i, j) in &contract_path.toplevel {
185 op_cost += cost_function(&inputs[i], &inputs[j]);
186 let ij = &inputs[i] ^ &inputs[j];
187 let new_mem_cost = size_function(&inputs[i], &inputs[j]);
188 mem_cost = mem_cost.max(new_mem_cost);
189 inputs[i] = ij;
190 }
191
192 (op_cost, mem_cost)
193}
194
195#[inline]
198pub fn communication_path_op_costs(
199 inputs: &[Tensor],
200 contract_path: SimplePathRef,
201 only_count_ops: bool,
202 tensor_cost: Option<&[f64]>,
203) -> ((f64, f64), f64) {
204 let (parallel_cost, _) =
205 communication_path_cost(inputs, contract_path, only_count_ops, true, tensor_cost);
206 let (serial_cost, mem_cost) =
207 communication_path_cost(inputs, contract_path, only_count_ops, false, tensor_cost);
208 ((parallel_cost, serial_cost), mem_cost)
209}
210
211pub fn communication_path_cost(
221 inputs: &[Tensor],
222 contract_path: SimplePathRef,
223 only_count_ops: bool,
224 only_critical_path: bool,
225 tensor_cost: Option<&[f64]>,
226) -> (f64, f64) {
227 let cost_function = if only_count_ops {
228 contract_op_cost_tensors
229 } else {
230 contract_cost_tensors
231 };
232 let tensor_cost = if let Some(tensor_cost) = tensor_cost {
233 assert_eq!(inputs.len(), tensor_cost.len());
234 tensor_cost
235 } else {
236 &vec![0f64; inputs.len()]
237 };
238 if inputs.len() == 1 {
239 return (tensor_cost[0], tensor_cost[0]);
240 }
241
242 communication_path_custom_cost(
243 inputs,
244 contract_path,
245 cost_function,
246 only_critical_path,
247 tensor_cost,
248 )
249}
250
251fn communication_path_custom_cost(
260 inputs: &[Tensor],
261 contract_path: SimplePathRef,
262 cost_function: fn(&Tensor, &Tensor) -> f64,
263 only_critical_path: bool,
264 tensor_cost: &[f64],
265) -> (f64, f64) {
266 let mut op_cost = 0f64;
267 let mut mem_cost = 0f64;
268 let mut inputs = inputs.to_vec();
269 let mut tensor_cost = tensor_cost.to_vec();
270
271 for &(i, j) in contract_path {
272 let ij = &inputs[i] ^ &inputs[j];
273 let new_mem_cost = contract_size_tensors(&inputs[i], &inputs[j]);
274 mem_cost = mem_cost.max(new_mem_cost);
275
276 op_cost = if only_critical_path {
277 cost_function(&inputs[i], &inputs[j]) + tensor_cost[i].max(tensor_cost[j])
278 } else {
279 cost_function(&inputs[i], &inputs[j]) + tensor_cost[i] + tensor_cost[j]
280 };
281 tensor_cost[i] = op_cost;
282 inputs[i] = ij;
283 }
284
285 (op_cost, mem_cost)
286}
287
288#[inline]
296pub fn compute_memory_requirements(
297 inputs: &[Tensor],
298 contract_path: &ContractionPath,
299 memory_estimator: fn(&Tensor, &Tensor) -> f64,
300) -> f64 {
301 fn id(_: &Tensor, _: &Tensor) -> f64 {
302 0.0
303 }
304 let (_, mem) = contract_path_custom_cost(inputs, contract_path, id, memory_estimator);
305 mem
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Three tensors: 0 and 1 share legs 3 and 2; 0 and 2 share leg 4.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ])
    }

    /// Two composites: the `setup_simple` network and a two-tensor network over legs 5..=9.
    /// The halves are connected through legs 5 and 6.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 5),
            (1, 2),
            (2, 6),
            (3, 8),
            (4, 1),
            (5, 3),
            (6, 4),
            (7, 3),
            (8, 2),
            (9, 2),
        ]);
        let t1_tensors = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        let t1 = Tensor::new_composite(t1_tensors);

        let t2_tensors = vec![
            Tensor::new_from_map(vec![5, 6, 8], &bond_dims),
            Tensor::new_from_map(vec![7, 8, 9], &bond_dims),
        ];
        let t2 = Tensor::new_composite(t2_tensors);
        Tensor::new_composite(vec![t1, t2])
    }

    /// The `setup_simple` network plus a fourth tensor over legs 5 and 6. This lets (0, 1)
    /// and (2, 3) be contracted independently before a final (0, 2) step.
    fn setup_parallel() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![5, 6], &bond_dims),
        ])
    }

    #[test]
    fn test_contract_path_cost() {
        let tn = setup_simple();
        // (0, 1): 10 outputs x (6*48 + 2*47) = 3820; (0, 2): 120 outputs x 6 = 720.
        // Peak memory is the first step: 10 + 48 + 480 elements.
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], false);
        assert_eq!(op_cost, 4540.);
        assert_eq!(mem_cost, 538.);
        // (0, 2): 576 outputs x 6 = 3456; (0, 1): 120 outputs x (6*48 + 2*47) = 45840.
        // Peak memory is the second step: 120 + 576 + 480 elements.
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], false);
        assert_eq!(op_cost, 49296.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_complex_path_cost() {
        let tn = setup_complex();
        // Nested halves cost 4540 and 1008; the final top-level step costs 60 x 94 = 5640.
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            false,
        );
        assert_eq!(op_cost, 11188.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_contract_path_cost_only_ops() {
        let tn = setup_simple();
        // Union-of-legs sizes: 480 + 120.
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], true);
        assert_eq!(op_cost, 600.);
        assert_eq!(mem_cost, 538.);
        // Union-of-legs sizes: 576 + 5760.
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], true);
        assert_eq!(op_cost, 6336.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_path_complex_cost_only_ops() {
        let tn = setup_complex();
        // Union-of-legs sizes: (480 + 120) + 144 + 720.
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            true,
        );
        assert_eq!(op_cost, 1464.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops() {
        let tn = setup_parallel();
        // Branches cost 480 and 12; the final step adds 10 on top of the larger: 10 + 480.
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], true, true, None);
        assert_eq!(op_cost, 490.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost() {
        let tn = setup_parallel();
        // Branches (0, 1) and (2, 3) cost 3820 and 94; the final (0, 2) step produces 10
        // outputs over a single contracted entry (60) on top of the larger branch: 60 + 3820.
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], false, true, None);
        assert_eq!(op_cost, 3880.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        // Branches: 480 + max(20, 30) = 510 and 12 + max(80, 10) = 92; final: 10 + 510.
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            true,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(op_cost, 520.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        // Branches: 3820 + max(20, 30) = 3850 and 94 + max(80, 10) = 174; final: 60 + 3850.
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            false,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(op_cost, 3910.);
        assert_eq!(mem_cost, 538.);
    }
}