1use rand::Rng;
2use rustc_hash::FxHashMap;
3
4use crate::contractionpath::contraction_tree::balancing::{find_rebalance_node, PartitionData};
5use crate::contractionpath::contraction_tree::{
6 populate_leaf_node_tensor_map, populate_subtree_tensor_map, ContractionTree,
7};
8use crate::tensornetwork::tensor::Tensor;
9
/// Strategy used to decide which leaves are shifted between partitions when rebalancing.
///
/// NOTE(review): the code that dispatches on this enum is not part of this file. The strategy
/// function named on each variant below is inferred from the function names in this module and
/// should be confirmed against the caller.
#[derive(Debug, Clone, Copy)]
pub enum BalancingScheme {
    /// Presumably handled by `best_worst`: one shift from the last partition to the first one.
    BestWorst,

    /// Presumably handled by `best_tensor`: one shift from the last partition to whichever other
    /// partition yields the highest objective.
    Tensor,

    /// Presumably handled by `best_tensors`: the `Tensor` shift followed by a second shift into
    /// the first partition.
    Tensors,

    /// Presumably alternates between `tensors_odd` and `tensors_even` — TODO confirm how the
    /// caller decides between the two.
    AlternatingTensors,

    /// Presumably handled by `best_intermediate_tensors`.
    IntermediateTensors {
        /// Presumably forwarded as the last argument of `populate_subtree_tensor_map` when the
        /// donor subtree is enumerated; `None` is passed through unchanged.
        height_limit: Option<usize>,
    },

    /// Presumably alternates between `intermediate_tensors_odd` and `intermediate_tensors_even`
    /// — TODO confirm how the caller decides between the two.
    AlternatingIntermediateTensors {
        /// Presumably forwarded as the last argument of `populate_subtree_tensor_map` when the
        /// donor subtree is enumerated; `None` is passed through unchanged.
        height_limit: Option<usize>,
    },

    /// Presumably alternates between `tree_tensors_odd` and `tree_tensors_even` — TODO confirm
    /// how the caller decides between the two.
    AlternatingTreeTensors {
        /// Mandatory here (not an `Option`); the `tree_tensors_*` functions wrap it in `Some`
        /// before handing it to `populate_subtree_tensor_map`.
        height_limit: usize,
    },
}
69
/// A proposed move of leaf tensors from one partition (subtree) to another.
///
/// Every balancing strategy in this module returns its proposals as a `Vec<Shift>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct Shift {
    /// Id of the partition the leaves are taken from (`PartitionData::id` of the donor).
    pub from_subtree_id: usize,
    /// Id of the partition the leaves are moved into (`PartitionData::id` of the receiver).
    pub to_subtree_id: usize,
    /// Leaf ids to move, as returned by `ContractionTree::leaf_ids` for the selected node.
    pub moved_leaf_ids: Vec<usize>,
}
80
81pub(super) fn best_worst<R>(
84 partition_data: &[PartitionData],
85 contraction_tree: &ContractionTree,
86 random_balance: &mut Option<(usize, R)>,
87 objective_function: fn(&Tensor, &Tensor) -> f64,
88 tensor: &Tensor,
89) -> Vec<Shift>
90where
91 R: Rng,
92{
93 let larger_subtree_id = partition_data.last().unwrap().id;
95 let smaller_subtree_id = partition_data.first().unwrap().id;
96
97 let larger_subtree_leaf_nodes =
98 populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
99
100 let smaller_subtree_leaf_nodes =
101 populate_leaf_node_tensor_map(contraction_tree, smaller_subtree_id, tensor);
102
103 let (rebalanced_node, _) = find_rebalance_node(
104 random_balance,
105 &larger_subtree_leaf_nodes,
106 &smaller_subtree_leaf_nodes,
107 objective_function,
108 );
109 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
110 vec![Shift {
111 from_subtree_id: larger_subtree_id,
112 to_subtree_id: smaller_subtree_id,
113 moved_leaf_ids: rebalanced_leaf_ids,
114 }]
115}
116
117pub(super) fn best_tensor<R>(
120 partition_data: &[PartitionData],
121 contraction_tree: &ContractionTree,
122 random_balance: &mut Option<(usize, R)>,
123 objective_function: fn(&Tensor, &Tensor) -> f64,
124 tensor: &Tensor,
125) -> Vec<Shift>
126where
127 R: Rng,
128{
129 let larger_subtree_id = partition_data.last().unwrap().id;
131
132 let larger_subtree_leaf_nodes =
133 populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
134 let (smaller_subtree_id, rebalanced_node, _) = partition_data
136 .iter()
137 .take(partition_data.len() - 1)
138 .map(|smaller| {
139 let smaller_subtree_nodes =
140 populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
141 let (rebalanced_node, objective) = find_rebalance_node(
142 random_balance,
143 &larger_subtree_leaf_nodes,
144 &smaller_subtree_nodes,
145 objective_function,
146 );
147 (smaller.id, rebalanced_node, objective)
148 })
149 .max_by(|a, b| a.2.total_cmp(&b.2))
150 .unwrap();
151
152 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
153 vec![Shift {
154 from_subtree_id: larger_subtree_id,
155 to_subtree_id: smaller_subtree_id,
156 moved_leaf_ids: rebalanced_leaf_ids,
157 }]
158}
159
160pub(super) fn best_tensors<R>(
163 partition_data: &[PartitionData],
164 contraction_tree: &ContractionTree,
165 random_balance: &mut Option<(usize, R)>,
166 objective_function: fn(&Tensor, &Tensor) -> f64,
167 tensor: &Tensor,
168) -> Vec<Shift>
169where
170 R: Rng,
171{
172 let larger_subtree_id = partition_data.last().unwrap().id;
174
175 let larger_subtree_leaf_nodes =
176 populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
177
178 let (smaller_subtree_id, rebalanced_node, _) = partition_data
180 .iter()
181 .take(partition_data.len() - 1)
182 .map(|smaller| {
183 let smaller_subtree_nodes =
184 populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
185 let (rebalanced_node, objective) = find_rebalance_node(
186 random_balance,
187 &larger_subtree_leaf_nodes,
188 &smaller_subtree_nodes,
189 objective_function,
190 );
191 (smaller.id, rebalanced_node, objective)
192 })
193 .max_by(|a, b| a.2.total_cmp(&b.2))
194 .unwrap();
195 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
196
197 let mut shifts = Vec::with_capacity(2);
198 shifts.push(Shift {
199 from_subtree_id: larger_subtree_id,
200 to_subtree_id: smaller_subtree_id,
201 moved_leaf_ids: rebalanced_leaf_ids,
202 });
203
204 let smaller_subtree_id = partition_data.first().unwrap().id;
205
206 let smaller_subtree_nodes =
207 populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
208
209 let (larger_subtree_id, rebalanced_node, _) = partition_data
210 .iter()
211 .skip(1)
212 .take(partition_data.len() - 2)
213 .map(|larger| {
214 let larger_subtree_nodes =
215 populate_leaf_node_tensor_map(contraction_tree, larger.id, tensor);
216 let (rebalanced_node, objective) = find_rebalance_node(
217 random_balance,
218 &larger_subtree_nodes,
219 &smaller_subtree_nodes,
220 objective_function,
221 );
222
223 (larger.id, rebalanced_node, objective)
224 })
225 .max_by(|(_, _, obj_a), (_, _, obj_b)| obj_a.total_cmp(obj_b))
226 .unwrap();
227
228 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
229 shifts.push(Shift {
230 from_subtree_id: larger_subtree_id,
231 to_subtree_id: smaller_subtree_id,
232 moved_leaf_ids: rebalanced_leaf_ids,
233 });
234 shifts
235}
236
237pub(super) fn tensors_odd<R>(
239 partition_data: &[PartitionData],
240 contraction_tree: &ContractionTree,
241 random_balance: &mut Option<(usize, R)>,
242 objective_function: fn(&Tensor, &Tensor) -> f64,
243 tensor: &Tensor,
244) -> Vec<Shift>
245where
246 R: Rng,
247{
248 let larger_subtree_id = partition_data.last().unwrap().id;
250
251 let larger_subtree_leaf_nodes =
252 populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
253
254 let (smaller_subtree_id, rebalanced_leaf_node, _) = partition_data
256 .iter()
257 .take(partition_data.len() - 1)
258 .map(|smaller| {
259 let smaller_subtree_nodes = FxHashMap::from_iter([(0, smaller.local_tensor.clone())]);
260 let (rebalanced_node, objective) = find_rebalance_node(
261 random_balance,
262 &larger_subtree_leaf_nodes,
263 &smaller_subtree_nodes,
264 objective_function,
265 );
266 (smaller.id, rebalanced_node, objective)
267 })
268 .max_by(|a, b| a.2.total_cmp(&b.2))
269 .unwrap();
270
271 let rebalanced_leaf_id = contraction_tree.leaf_ids(rebalanced_leaf_node);
272 vec![Shift {
273 from_subtree_id: larger_subtree_id,
274 to_subtree_id: smaller_subtree_id,
275 moved_leaf_ids: rebalanced_leaf_id,
276 }]
277}
278
279pub(super) fn tensors_even<R>(
281 partition_data: &[PartitionData],
282 contraction_tree: &ContractionTree,
283 random_balance: &mut Option<(usize, R)>,
284 objective_function: fn(&Tensor, &Tensor) -> f64,
285 tensor: &Tensor,
286) -> Vec<Shift>
287where
288 R: Rng,
289{
290 let smaller_subtree_id = partition_data.first().unwrap().id;
291
292 let smaller_subtree_nodes =
293 FxHashMap::from_iter([(0, partition_data.first().unwrap().local_tensor.clone())]);
294
295 let (larger_subtree_id, rebalanced_leaf_node, _) = partition_data
296 .iter()
297 .skip(1)
298 .map(|larger| {
299 let larger_subtree_leaf_nodes =
300 populate_leaf_node_tensor_map(contraction_tree, larger.id, tensor);
301 let (rebalanced_node, objective) = find_rebalance_node(
302 random_balance,
303 &larger_subtree_leaf_nodes,
304 &smaller_subtree_nodes,
305 objective_function,
306 );
307
308 (larger.id, rebalanced_node, objective)
309 })
310 .max_by(|a, b| a.2.total_cmp(&b.2))
311 .unwrap();
312
313 let rebalanced_leaf_id = contraction_tree.leaf_ids(rebalanced_leaf_node);
314 vec![Shift {
315 from_subtree_id: larger_subtree_id,
316 to_subtree_id: smaller_subtree_id,
317 moved_leaf_ids: rebalanced_leaf_id,
318 }]
319}
320
321pub(super) fn best_intermediate_tensors<R>(
324 partition_data: &[PartitionData],
325 contraction_tree: &ContractionTree,
326 random_balance: &mut Option<(usize, R)>,
327 objective_function: fn(&Tensor, &Tensor) -> f64,
328 tensor: &Tensor,
329 height_limit: Option<usize>,
330) -> Vec<Shift>
331where
332 R: Rng,
333{
334 let larger_subtree_id = partition_data.last().unwrap().id;
336
337 let mut larger_subtree_nodes =
339 populate_subtree_tensor_map(contraction_tree, larger_subtree_id, tensor, height_limit);
340 larger_subtree_nodes.remove(&larger_subtree_id);
341
342 let (smaller_subtree_id, first_rebalanced_node, _) = partition_data
344 .iter()
345 .take(partition_data.len() - 1)
346 .map(|smaller| {
347 let smaller_subtree_nodes =
348 populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
349 let (rebalanced_node, objective) = find_rebalance_node(
350 random_balance,
351 &larger_subtree_nodes,
352 &smaller_subtree_nodes,
353 objective_function,
354 );
355 (smaller.id, rebalanced_node, objective)
356 })
357 .max_by(|a, b| a.2.total_cmp(&b.2))
358 .unwrap();
359 let rebalanced_leaf_ids = contraction_tree.leaf_ids(first_rebalanced_node);
360
361 let mut shifts = Vec::with_capacity(2);
362 shifts.push(Shift {
363 from_subtree_id: larger_subtree_id,
364 to_subtree_id: smaller_subtree_id,
365 moved_leaf_ids: rebalanced_leaf_ids,
366 });
367
368 let smaller_subtree_id = partition_data.first().unwrap().id;
369
370 let smaller_subtree_nodes =
371 populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
372
373 let (larger_subtree_id, second_rebalanced_node, _) = partition_data
374 .iter()
375 .skip(1)
376 .take(partition_data.len() - 2)
377 .map(|larger| {
378 let mut larger_subtree_nodes =
379 populate_subtree_tensor_map(contraction_tree, larger.id, tensor, height_limit);
380 larger_subtree_nodes.remove(&larger.id);
381 let (rebalanced_node, objective) = find_rebalance_node(
382 random_balance,
383 &larger_subtree_nodes,
384 &smaller_subtree_nodes,
385 objective_function,
386 );
387
388 (larger.id, rebalanced_node, objective)
389 })
390 .max_by(|a, b| a.2.total_cmp(&b.2))
391 .unwrap();
392
393 let rebalanced_leaf_ids = contraction_tree.leaf_ids(second_rebalanced_node);
394 shifts.push(Shift {
395 from_subtree_id: larger_subtree_id,
396 to_subtree_id: smaller_subtree_id,
397 moved_leaf_ids: rebalanced_leaf_ids,
398 });
399 shifts
400}
401
402pub(super) fn intermediate_tensors_odd<R>(
404 partition_data: &[PartitionData],
405 contraction_tree: &ContractionTree,
406 random_balance: &mut Option<(usize, R)>,
407 objective_function: fn(&Tensor, &Tensor) -> f64,
408 tensor: &Tensor,
409 height_limit: Option<usize>,
410) -> Vec<Shift>
411where
412 R: Rng,
413{
414 let larger_subtree_id = partition_data.last().unwrap().id;
416
417 let mut larger_subtree_nodes =
419 populate_subtree_tensor_map(contraction_tree, larger_subtree_id, tensor, height_limit);
420 larger_subtree_nodes.remove(&larger_subtree_id);
421
422 let (smaller_subtree_id, rebalanced_node, _) = partition_data
424 .iter()
425 .take(partition_data.len() - 1)
426 .map(|smaller| {
427 let smaller_subtree_nodes =
428 populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
429 let (rebalanced_node, objective) = find_rebalance_node(
430 random_balance,
431 &larger_subtree_nodes,
432 &smaller_subtree_nodes,
433 objective_function,
434 );
435 (smaller.id, rebalanced_node, objective)
436 })
437 .max_by(|a, b| a.2.total_cmp(&b.2))
438 .unwrap();
439
440 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
441 vec![Shift {
442 from_subtree_id: larger_subtree_id,
443 to_subtree_id: smaller_subtree_id,
444 moved_leaf_ids: rebalanced_leaf_ids,
445 }]
446}
447
448pub(super) fn intermediate_tensors_even<R>(
450 partition_data: &[PartitionData],
451 contraction_tree: &ContractionTree,
452 random_balance: &mut Option<(usize, R)>,
453 objective_function: fn(&Tensor, &Tensor) -> f64,
454 tensor: &Tensor,
455 height_limit: Option<usize>,
456) -> Vec<Shift>
457where
458 R: Rng,
459{
460 let smaller_subtree_id = partition_data.first().unwrap().id;
461
462 let smaller_subtree_nodes =
463 populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
464
465 let (larger_subtree_id, rebalanced_node, _) = partition_data
466 .iter()
467 .skip(1)
468 .filter_map(|larger| {
469 let mut larger_subtree_nodes =
470 populate_subtree_tensor_map(contraction_tree, larger.id, tensor, height_limit);
471 if larger_subtree_nodes.len() == 1 {
472 return None;
473 }
474 larger_subtree_nodes.remove(&larger.id);
475 let (rebalanced_node, objective) = find_rebalance_node(
476 random_balance,
477 &larger_subtree_nodes,
478 &smaller_subtree_nodes,
479 objective_function,
480 );
481
482 Some((larger.id, rebalanced_node, objective))
483 })
484 .max_by(|a, b| a.2.total_cmp(&b.2))
485 .unwrap();
486
487 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
488 vec![Shift {
489 from_subtree_id: larger_subtree_id,
490 to_subtree_id: smaller_subtree_id,
491 moved_leaf_ids: rebalanced_leaf_ids,
492 }]
493}
494
495pub(super) fn tree_tensors_odd(
497 partition_data: &[PartitionData],
498 contraction_tree: &ContractionTree,
499 objective_function: fn(&Tensor, &Tensor) -> f64,
500 tensor: &Tensor,
501 height_limit: usize,
502) -> Vec<Shift> {
503 let larger_subtree_id = partition_data.last().unwrap().id;
505
506 let mut larger_subtree_nodes = populate_subtree_tensor_map(
508 contraction_tree,
509 larger_subtree_id,
510 tensor,
511 Some(height_limit),
512 );
513 larger_subtree_nodes.remove(&larger_subtree_id);
514
515 let (smaller_subtree_id, rebalanced_node, _) = partition_data
517 .iter()
518 .take(partition_data.len() - 1)
519 .map(|smaller| {
520 let PartitionData {
521 local_tensor, id, ..
522 } = smaller;
523 let mut objective = 0.;
524 let mut rebalanced_node = None;
525 for (node_id, node) in &larger_subtree_nodes {
526 let new_obj = objective_function(node, local_tensor);
527 if new_obj > objective {
528 objective = new_obj;
529 rebalanced_node = Some(*node_id);
530 }
531 }
532 (*id, rebalanced_node, objective)
533 })
534 .max_by(|a, b| a.2.total_cmp(&b.2))
535 .unwrap();
536 if let Some(rebalanced_node) = rebalanced_node {
537 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
538 vec![Shift {
539 from_subtree_id: larger_subtree_id,
540 to_subtree_id: smaller_subtree_id,
541 moved_leaf_ids: rebalanced_leaf_ids,
542 }]
543 } else {
544 Vec::new()
545 }
546}
547
548pub(super) fn tree_tensors_even(
550 partition_data: &[PartitionData],
551 contraction_tree: &ContractionTree,
552 objective_function: fn(&Tensor, &Tensor) -> f64,
553 tensor: &Tensor,
554 height_limit: usize,
555) -> Vec<Shift> {
556 let smaller_subtree_id = partition_data.first().unwrap().id;
557
558 let PartitionData {
561 local_tensor: smaller_tensor,
562 ..
563 } = partition_data.first().unwrap();
564
565 let (larger_subtree_id, rebalanced_node, _) = partition_data
566 .iter()
567 .skip(1)
568 .filter_map(|larger| {
569 let mut larger_subtree_nodes = populate_subtree_tensor_map(
570 contraction_tree,
571 larger.id,
572 tensor,
573 Some(height_limit),
574 );
575 if larger_subtree_nodes.len() == 1 {
576 return None;
577 }
578 larger_subtree_nodes.remove(&larger.id);
579 let mut objective = 0.;
580 let mut rebalanced_node = None;
581 for (node_id, node) in &larger_subtree_nodes {
582 let new_obj = objective_function(node, smaller_tensor);
583 if new_obj > objective {
584 objective = new_obj;
585 rebalanced_node = Some(*node_id);
586 }
587 }
588
589 Some((larger.id, rebalanced_node, objective))
590 })
591 .max_by(|a, b| a.2.total_cmp(&b.2))
592 .unwrap();
593
594 if let Some(rebalanced_node) = rebalanced_node {
595 let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
596 vec![Shift {
597 from_subtree_id: larger_subtree_id,
598 to_subtree_id: smaller_subtree_id,
599 moved_leaf_ids: rebalanced_leaf_ids,
600 }]
601 } else {
602 Vec::new()
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    use rand::rngs::StdRng;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_tree::{balancing::PartitionData, ContractionTree},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Shorthand for building an expected [`Shift`].
    fn shift(from_subtree_id: usize, to_subtree_id: usize, moved_leaf_ids: &[usize]) -> Shift {
        Shift {
            from_subtree_id,
            to_subtree_id,
            moved_leaf_ids: moved_leaf_ids.to_vec(),
        }
    }

    /// Partition summaries for the three subtrees (ids 2, 7 and 14) of the `setup_simple` tree,
    /// listed by ascending `flop_cost`.
    fn setup_simple_partition_data() -> Vec<PartitionData> {
        // Legs 0 through 10, each with bond dimension 2.
        let bond_dims = FxHashMap::from_iter((0..=10).map(|leg| (leg, 2)));
        let partition = |id, flop_cost, legs| PartitionData {
            id,
            flop_cost,
            mem_cost: 0.,
            contraction: Default::default(),
            local_tensor: Tensor::new_from_map(legs, &bond_dims),
        };

        vec![
            partition(2, 1., vec![7, 9, 10]),
            partition(7, 2., vec![0, 1, 5, 7]),
            partition(14, 3., vec![0, 1, 2, 5, 10]),
        ]
    }

    /// Builds a composite tensor with three partitions (2, 3 and 4 leaves) together with the
    /// contraction tree derived from a nested contraction path over it.
    fn setup_simple() -> (ContractionTree, Tensor) {
        // Legs 0 through 10, each with bond dimension 2.
        let bond_dims = FxHashMap::from_iter((0..=10).map(|leg| (leg, 2)));
        let leaf = |legs| Tensor::new_from_map(legs, &bond_dims);

        // First partition: tree nodes 0 and 1, rooted at node 2.
        let partition2 = Tensor::new_composite(vec![leaf(vec![7, 8]), leaf(vec![8, 9, 10])]);

        // Second partition: tree nodes 3, 4 and 5, rooted at node 7.
        let partition7 = Tensor::new_composite(vec![
            leaf(vec![0, 6]),
            leaf(vec![1, 6]),
            leaf(vec![5, 7]),
        ]);

        // Third partition: tree nodes 8, 9, 10 and 11, rooted at node 14.
        let partition14 = Tensor::new_composite(vec![
            leaf(vec![0, 1]),
            leaf(vec![2, 3]),
            leaf(vec![3, 4]),
            leaf(vec![4, 5, 10]),
        ]);

        let network = Tensor::new_composite(vec![partition2, partition7, partition14]);

        let contraction_path = path![
            {
                (0, [(0, 1)]),
                (1, [(0, 1), (0, 2)]),
                (2, [(0, 3), (2, 1), (0, 2)]),
            },
            (0, 1),
            (0, 2)
        ];

        let tree = ContractionTree::from_contraction_path(&network, &contraction_path);
        (tree, network)
    }

    /// Objective used by all tests: the number of legs of `a & b`.
    fn custom_cost_function(a: &Tensor, b: &Tensor) -> f64 {
        let combined = a & b;
        combined.legs().len() as f64
    }

    #[test]
    fn test_best_worst_balancing() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = best_worst::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
        );

        let expected = vec![shift(14, 2, &[11])];
        assert_eq!(shifts, expected);
    }

    #[test]
    fn test_tensor_balancing() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = best_tensor::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
        );

        let expected = vec![shift(14, 7, &[8])];
        assert_eq!(shifts, expected);
    }

    #[test]
    fn test_tensors_balancing() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = best_tensors::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
        );

        let expected = vec![shift(14, 7, &[8]), shift(7, 2, &[5])];
        assert_eq!(shifts, expected);
    }

    #[test]
    fn test_alternating_tensors_balancing_odd() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = tensors_odd::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
        );

        let expected = vec![shift(14, 7, &[8])];
        assert_eq!(shifts, expected);
    }

    #[test]
    fn test_alternating_tensors_balancing_even() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = tensors_even::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
        );

        let expected = vec![shift(14, 2, &[11])];
        assert_eq!(shifts, expected);
    }

    #[test]
    fn test_intermediate_tensors_balancing() {
        let partitions = setup_simple_partition_data();
        let (tree, network) = setup_simple();

        let shifts = best_intermediate_tensors::<StdRng>(
            &partitions,
            &tree,
            &mut None,
            custom_cost_function,
            &network,
            Some(1),
        );

        let expected = vec![shift(14, 7, &[8, 11]), shift(7, 2, &[5])];
        assert_eq!(shifts, expected);
    }
}