1use core::f64;
4use std::rc::Rc;
5
6use itertools::Itertools;
7use log::debug;
8use rand::seq::IndexedRandom;
9use rand::{rngs::StdRng, Rng};
10use rustc_hash::FxHashMap;
11
12use crate::contractionpath::communication_schemes::CommunicationScheme;
13use crate::contractionpath::contraction_cost::communication_path_op_costs;
14use crate::contractionpath::contraction_tree::{
15 utils::{characterize_partition, subtree_contraction_path},
16 ContractionTree,
17};
18use crate::contractionpath::paths::validate_path;
19use crate::contractionpath::{ContractionPath, SimplePath};
20use crate::tensornetwork::tensor::{Tensor, TensorIndex};
21
22mod balancing_schemes;
23
24pub use balancing_schemes::BalancingScheme;
25
26#[derive(Debug, Clone, Copy)]
27pub struct BalanceSettings<R>
28where
29 R: Rng,
30{
31 random_balance: Option<(usize, R)>,
34 rebalance_depth: usize,
35 iterations: usize,
36 objective_function: fn(&Tensor, &Tensor) -> f64,
37 communication_scheme: CommunicationScheme,
38 balancing_scheme: BalancingScheme,
39 memory_limit: Option<f64>,
40}
41
42impl BalanceSettings<StdRng> {
43 pub fn new(
44 rebalance_depth: usize,
45 iterations: usize,
46 objective_function: fn(&Tensor, &Tensor) -> f64,
47 communication_scheme: CommunicationScheme,
48 balancing_scheme: BalancingScheme,
49 memory_limit: Option<f64>,
50 ) -> Self {
51 BalanceSettings::<StdRng> {
52 random_balance: None,
53 rebalance_depth,
54 iterations,
55 objective_function,
56 communication_scheme,
57 balancing_scheme,
58 memory_limit,
59 }
60 }
61}
62
63impl<R> BalanceSettings<R>
64where
65 R: Rng,
66{
67 pub fn new_random(
68 random_balance: Option<(usize, R)>,
69 rebalance_depth: usize,
70 iterations: usize,
71 objective_function: fn(&Tensor, &Tensor) -> f64,
72 communication_scheme: CommunicationScheme,
73 balancing_scheme: BalancingScheme,
74 memory_limit: Option<f64>,
75 ) -> Self {
76 BalanceSettings::<R> {
77 random_balance,
78 rebalance_depth,
79 iterations,
80 objective_function,
81 communication_scheme,
82 balancing_scheme,
83 memory_limit,
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
89pub(crate) struct PartitionData {
90 pub id: usize,
91 pub flop_cost: f64,
92 pub mem_cost: f64,
93 pub contraction: SimplePath,
94 pub local_tensor: Tensor,
95}
96
97pub fn balance_partitions_iter<R>(
99 tensor_network: &Tensor,
100 path: &ContractionPath,
101 mut balance_settings: BalanceSettings<R>,
102 rng: &mut R,
103) -> (usize, Tensor, ContractionPath, Vec<(f64, f64)>)
104where
105 R: Rng,
106{
107 let mut contraction_tree = ContractionTree::from_contraction_path(tensor_network, path);
108
109 let BalanceSettings {
110 rebalance_depth,
111 iterations,
112 balancing_scheme,
113 communication_scheme,
114 memory_limit,
115 ..
116 } = balance_settings;
117 let mut partition_data =
118 characterize_partition(&contraction_tree, rebalance_depth, tensor_network);
119
120 assert!(partition_data.len() > 1);
121
122 let partition_number = partition_data.len();
123
124 let (partition_tensors, partition_costs): (Vec<_>, Vec<_>) = partition_data
125 .iter()
126 .map(
127 |PartitionData {
128 local_tensor,
129 flop_cost,
130 ..
131 }| (local_tensor.clone(), *flop_cost),
132 )
133 .collect();
134
135 let ((mut best_cost, sum_cost), _) = communication_path_op_costs(
136 &partition_tensors,
137 &path.toplevel,
138 true,
139 Some(&partition_costs),
140 );
141
142 let mut max_costs = Vec::with_capacity(iterations + 1);
143 max_costs.push((best_cost, sum_cost));
144
145 let mut best_iteration = 0;
146 let mut best_contraction_path = path.to_owned();
147 let mut best_tn = tensor_network.clone();
148
149 for iteration in 1..=iterations {
150 debug!("Balancing iteration {iteration} with balancing scheme {balancing_scheme:?}, communication scheme {communication_scheme:?}");
151
152 let (nested_paths, new_tensor_network) = balance_partitions(
154 &mut partition_data,
155 &mut contraction_tree,
156 tensor_network,
157 &mut balance_settings,
158 iteration,
159 );
160
161 assert_eq!(nested_paths.len(), partition_number, "Tensors lost!");
162
163 let communication_path = communicate_partitions(
166 &partition_data,
167 &mut contraction_tree,
168 &new_tensor_network,
169 &balance_settings,
170 Some(rng),
171 );
172
173 let (partition_tensors, partition_costs): (Vec<_>, Vec<_>) = partition_data
174 .iter()
175 .map(
176 |PartitionData {
177 local_tensor,
178 flop_cost,
179 ..
180 }| (local_tensor.clone(), *flop_cost),
181 )
182 .collect();
183
184 let ((flop_cost, sum_cost), mem_cost) = communication_path_op_costs(
185 &partition_tensors,
186 &communication_path,
187 true,
188 Some(&partition_costs),
189 );
190
191 let new_path = ContractionPath {
192 nested: nested_paths,
193 toplevel: path.toplevel.clone(),
194 };
195 validate_path(&new_path);
196
197 max_costs.push((flop_cost, sum_cost));
198 if memory_limit.is_some_and(|limit| mem_cost > limit) {
199 break;
200 }
201 if flop_cost < best_cost {
202 best_cost = flop_cost;
203 best_iteration = iteration;
204 best_tn = new_tensor_network;
205 best_contraction_path = new_path;
206 }
207 }
208
209 (best_iteration, best_tn, best_contraction_path, max_costs)
210}
211
212fn communicate_partitions<R>(
213 partition_data: &[PartitionData],
214 contraction_tree: &mut ContractionTree,
215 tensor_network: &Tensor,
216 balance_settings: &BalanceSettings<R>,
217 rng: Option<&mut R>,
218) -> SimplePath
219where
220 R: Rng,
221{
222 let communication_scheme = balance_settings.communication_scheme;
223 let children_tensors = tensor_network
224 .tensors()
225 .iter()
226 .map(Tensor::external_tensor)
227 .collect_vec();
228 let latency_map = partition_data
229 .iter()
230 .enumerate()
231 .map(|(i, partition)| (i, partition.flop_cost))
232 .collect::<FxHashMap<_, _>>();
233
234 let partition_ids = partition_data
235 .iter()
236 .map(|partition| partition.id)
237 .collect_vec();
238 let communication_path =
239 communication_scheme.communication_path(&children_tensors, &latency_map, rng);
240
241 contraction_tree.replace_communication_path(partition_ids, &communication_path);
242
243 communication_path
244}
245
246fn balance_partitions<R>(
247 partition_data: &mut [PartitionData],
248 contraction_tree: &mut ContractionTree,
249 tensor_network: &Tensor,
250 balance_settings: &mut BalanceSettings<R>,
251 iteration: usize,
252) -> (FxHashMap<TensorIndex, ContractionPath>, Tensor)
253where
254 R: Rng,
255{
256 let BalanceSettings {
257 ref mut random_balance,
258 rebalance_depth,
259 objective_function,
260 balancing_scheme,
261 ..
262 } = balance_settings;
263 if tensor_network.total_num_tensors() < 3 {
265 panic!("No rebalancing undertaken, as tn is too small (< 3 tensors)");
267 }
268 assert!(partition_data.len() > 1);
270
271 partition_data.sort_unstable_by(|a, b| a.flop_cost.total_cmp(&b.flop_cost));
272
273 let shifted_nodes = match balancing_scheme {
274 BalancingScheme::BestWorst => balancing_schemes::best_worst(
275 partition_data,
276 contraction_tree,
277 random_balance,
278 *objective_function,
279 tensor_network,
280 ),
281 BalancingScheme::Tensor => balancing_schemes::best_tensor(
282 partition_data,
283 contraction_tree,
284 random_balance,
285 *objective_function,
286 tensor_network,
287 ),
288 BalancingScheme::Tensors => balancing_schemes::best_tensors(
289 partition_data,
290 contraction_tree,
291 random_balance,
292 *objective_function,
293 tensor_network,
294 ),
295 BalancingScheme::AlternatingTensors => {
296 if iteration % 2 == 1 {
297 balancing_schemes::tensors_odd(
298 partition_data,
299 contraction_tree,
300 random_balance,
301 *objective_function,
302 tensor_network,
303 )
304 } else {
305 balancing_schemes::tensors_even(
306 partition_data,
307 contraction_tree,
308 random_balance,
309 *objective_function,
310 tensor_network,
311 )
312 }
313 }
314 BalancingScheme::IntermediateTensors { height_limit } => {
315 balancing_schemes::best_intermediate_tensors(
316 partition_data,
317 contraction_tree,
318 random_balance,
319 *objective_function,
320 tensor_network,
321 *height_limit,
322 )
323 }
324 BalancingScheme::AlternatingIntermediateTensors { height_limit } => {
325 if iteration % 2 == 1 {
326 balancing_schemes::intermediate_tensors_odd(
327 partition_data,
328 contraction_tree,
329 random_balance,
330 *objective_function,
331 tensor_network,
332 *height_limit,
333 )
334 } else {
335 balancing_schemes::intermediate_tensors_even(
336 partition_data,
337 contraction_tree,
338 random_balance,
339 *objective_function,
340 tensor_network,
341 *height_limit,
342 )
343 }
344 }
345 BalancingScheme::AlternatingTreeTensors { height_limit } => {
346 if iteration % 2 == 1 {
347 balancing_schemes::tree_tensors_odd(
348 partition_data,
349 contraction_tree,
350 *objective_function,
351 tensor_network,
352 *height_limit,
353 )
354 } else {
355 balancing_schemes::tree_tensors_even(
356 partition_data,
357 contraction_tree,
358 *objective_function,
359 tensor_network,
360 *height_limit,
361 )
362 }
363 }
364 };
365 let mut shifted_indices = FxHashMap::default();
366 for shift in shifted_nodes {
367 let shifted_from_id = *shifted_indices
368 .get(&shift.from_subtree_id)
369 .unwrap_or(&shift.from_subtree_id);
370
371 let shifted_to_id = *shifted_indices
372 .get(&shift.to_subtree_id)
373 .unwrap_or(&shift.to_subtree_id);
374
375 let (
376 larger_id,
377 larger_contraction,
378 larger_subtree_flop_cost,
379 larger_subtree_mem_cost,
380 smaller_id,
381 smaller_contraction,
382 smaller_subtree_flop_cost,
383 smaller_subtree_mem_cost,
384 ) = shift_node_between_subtrees(
385 contraction_tree,
386 *rebalance_depth,
387 shifted_from_id,
388 shifted_to_id,
389 shift.moved_leaf_ids,
390 tensor_network,
391 );
392 shifted_indices.insert(shift.from_subtree_id, larger_id);
393 shifted_indices.insert(shift.to_subtree_id, smaller_id);
394
395 let larger_tensor = contraction_tree.tensor(larger_id, tensor_network);
396 let smaller_tensor = contraction_tree.tensor(smaller_id, tensor_network);
397
398 for PartitionData {
400 id,
401 flop_cost,
402 mem_cost,
403 contraction: subtree_contraction,
404 local_tensor,
405 } in partition_data.iter_mut()
406 {
407 if *id == shifted_from_id {
408 *id = larger_id;
409 subtree_contraction.clone_from(&larger_contraction);
410 *flop_cost = larger_subtree_flop_cost;
411 *mem_cost = larger_subtree_mem_cost;
412 *local_tensor = larger_tensor.clone();
413 } else if *id == shifted_to_id {
414 *id = smaller_id;
415 subtree_contraction.clone_from(&smaller_contraction);
416 *flop_cost = smaller_subtree_flop_cost;
417 *mem_cost = smaller_subtree_mem_cost;
418 *local_tensor = smaller_tensor.clone();
419 }
420 }
421 }
422
423 partition_data.sort_unstable_by(
424 |PartitionData {
425 flop_cost: cost_a, ..
426 },
427 PartitionData {
428 flop_cost: cost_b, ..
429 }| { cost_a.total_cmp(cost_b) },
430 );
431
432 let mut rebalanced_paths = FxHashMap::default();
433 let (partition_tensors, partition_ids): (Vec<_>, Vec<_>) = partition_data
434 .iter()
435 .enumerate()
436 .map(
437 |(
438 i,
439 PartitionData {
440 id,
441 contraction: subtree_contraction,
442 ..
443 },
444 )| {
445 rebalanced_paths.insert(i, ContractionPath::simple(subtree_contraction.clone()));
446 let mut child_tensor = Tensor::default();
447 let leaf_ids = contraction_tree.leaf_ids(*id);
448 let leaf_tensors = leaf_ids
449 .iter()
450 .map(|node_id| {
451 let nested_indices = contraction_tree
452 .node(*node_id)
453 .tensor_index()
454 .cloned()
455 .unwrap();
456 tensor_network.nested_tensor(&nested_indices).clone()
457 })
458 .collect_vec();
459 child_tensor.push_tensors(leaf_tensors);
460 (child_tensor, *id)
461 },
462 )
463 .collect();
464
465 contraction_tree
466 .partitions
467 .insert(*rebalance_depth, partition_ids);
468
469 let mut updated_tn = Tensor::default();
470 updated_tn.push_tensors(partition_tensors);
471 (rebalanced_paths, updated_tn)
472}
473
474fn find_rebalance_node<R>(
482 random_balance: &mut Option<(usize, R)>,
483 larger_subtree_nodes: &FxHashMap<usize, Tensor>,
484 smaller_subtree_nodes: &FxHashMap<usize, Tensor>,
485 objective_function: fn(&Tensor, &Tensor) -> f64,
486) -> (usize, f64)
487where
488 R: Rng,
489{
490 let node_comparison = larger_subtree_nodes
491 .iter()
492 .cartesian_product(smaller_subtree_nodes.iter())
493 .map(|((larger_node_id, larger_tensor), (_, smaller_tensor))| {
494 (
495 *larger_node_id,
496 objective_function(larger_tensor, smaller_tensor),
497 )
498 });
499 if let Some((options_considered, ref mut rng)) = random_balance {
500 let node_options = node_comparison
501 .sorted_unstable_by(|a, b| b.1.total_cmp(&a.1))
502 .take(*options_considered)
503 .collect_vec();
504 let max = node_options.first().unwrap().1;
505 *node_options
507 .choose_weighted(rng, |node_option| node_option.1 / max)
508 .unwrap()
509 } else {
510 node_comparison.max_by(|a, b| a.1.total_cmp(&b.1)).unwrap()
511 }
512}
513
514fn shift_node_between_subtrees(
517 contraction_tree: &mut ContractionTree,
518 rebalance_depth: usize,
519 larger_subtree_id: usize,
520 smaller_subtree_id: usize,
521 rebalanced_nodes: Vec<usize>,
522 tensor_network: &Tensor,
523) -> (usize, SimplePath, f64, f64, usize, SimplePath, f64, f64) {
524 let larger_subtree_parent_id = contraction_tree
526 .node(larger_subtree_id)
527 .parent_id()
528 .unwrap();
529 let smaller_subtree_parent_id = contraction_tree
530 .node(smaller_subtree_id)
531 .parent_id()
532 .unwrap();
533
534 let mut larger_subtree_leaf_nodes = contraction_tree.leaf_ids(larger_subtree_id);
535 let mut smaller_subtree_leaf_nodes = contraction_tree.leaf_ids(smaller_subtree_id);
536
537 assert!(rebalanced_nodes
539 .iter()
540 .all(|node| !smaller_subtree_leaf_nodes.contains(node)));
541 assert!(rebalanced_nodes
542 .iter()
543 .all(|node| larger_subtree_leaf_nodes.contains(node)));
544
545 larger_subtree_leaf_nodes.retain(|leaf| !rebalanced_nodes.contains(leaf));
547 smaller_subtree_leaf_nodes.extend(rebalanced_nodes);
548
549 let (updated_larger_path, local_larger_path, larger_flop_cost, larger_mem_cost) =
551 subtree_contraction_path(&larger_subtree_leaf_nodes, contraction_tree, tensor_network);
552
553 let (updated_smaller_path, local_smaller_path, smaller_flop_cost, smaller_mem_cost) =
554 subtree_contraction_path(
555 &smaller_subtree_leaf_nodes,
556 contraction_tree,
557 tensor_network,
558 );
559
560 contraction_tree.remove_subtree(larger_subtree_id);
562
563 let new_larger_subtree_id = if updated_larger_path.is_empty() {
564 contraction_tree.nodes[&larger_subtree_leaf_nodes[0]]
566 .borrow_mut()
567 .set_parent(Rc::downgrade(
568 &contraction_tree.nodes[&larger_subtree_parent_id],
569 ));
570 contraction_tree.nodes[&larger_subtree_parent_id]
571 .borrow_mut()
572 .add_child(Rc::downgrade(
573 &contraction_tree.nodes[&larger_subtree_leaf_nodes[0]],
574 ));
575 larger_subtree_leaf_nodes[0]
576 } else {
577 contraction_tree.add_path_as_subtree(
578 &ContractionPath::simple(updated_larger_path),
579 larger_subtree_parent_id,
580 &larger_subtree_leaf_nodes,
581 )
582 };
583
584 contraction_tree.remove_subtree(smaller_subtree_id);
586 let new_smaller_subtree_id = contraction_tree.add_path_as_subtree(
588 &ContractionPath::simple(updated_smaller_path),
589 smaller_subtree_parent_id,
590 &smaller_subtree_leaf_nodes,
591 );
592
593 let partition = contraction_tree
595 .partitions
596 .get_mut(&rebalance_depth)
597 .unwrap();
598 partition.retain(|&e| e != smaller_subtree_id && e != larger_subtree_id);
599 partition.push(new_larger_subtree_id);
600 partition.push(new_smaller_subtree_id);
601
602 (
603 new_larger_subtree_id,
604 local_larger_path,
605 larger_flop_cost,
606 larger_mem_cost,
607 new_smaller_subtree_id,
608 local_smaller_path,
609 smaller_flop_cost,
610 smaller_mem_cost,
611 )
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    use std::rc::Rc;

    use rand::{rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use crate::contractionpath::contraction_tree::{
        balancing::find_rebalance_node,
        node::{child_node, parent_node},
        ContractionTree,
    };
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Six-tensor network and the contraction tree of its fixed path.
    ///
    /// The shift tests below rely on subtree 9 containing leaf 3 (and leaf 2) and on
    /// subtree 7 containing neither.
    fn setup_complex() -> (ContractionTree, Tensor) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
        ]);
        let network = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ]);
        let contraction_path = path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)];
        let tree = ContractionTree::from_contraction_path(&network, &contraction_path);
        (tree, network)
    }

    /// [`setup_complex`] with subtrees 9 and 7 registered as the partitions at depth 1.
    fn setup_partitioned() -> (ContractionTree, Tensor) {
        let (mut tree, network) = setup_complex();
        tree.partitions.entry(1).or_insert_with(|| vec![9, 7]);
        (tree, network)
    }

    #[test]
    fn test_shift_leaf_node_between_subtrees() {
        let (mut tree, network) = setup_partitioned();
        shift_node_between_subtrees(&mut tree, 1, 9, 7, vec![3], &network);
        let ContractionTree { nodes, root, .. } = tree;

        let leaf0 = child_node(0, vec![0]);
        let leaf1 = child_node(1, vec![1]);
        let leaf2 = child_node(2, vec![2]);
        let leaf3 = child_node(3, vec![3]);
        let leaf4 = child_node(4, vec![4]);
        let leaf5 = child_node(5, vec![5]);

        let inner6 = parent_node(6, &leaf1, &leaf5);
        let inner7 = parent_node(7, &inner6, &leaf0);
        let inner8 = parent_node(8, &leaf2, &leaf4);
        let inner9 = parent_node(9, &inner7, &leaf3);
        let top10 = parent_node(10, &inner9, &inner8);

        let expected = [
            leaf0, leaf1, leaf2, leaf3, leaf4, leaf5, inner6, inner7, inner8, inner9, top10,
        ];
        expected
            .iter()
            .enumerate()
            .for_each(|(id, reference)| assert_eq!(&nodes[&id], reference));
        assert_eq!(root.upgrade().unwrap(), expected[10]);
    }

    #[test]
    fn test_shift_subtree_between_subtrees() {
        let (mut tree, network) = setup_partitioned();
        shift_node_between_subtrees(&mut tree, 1, 9, 7, vec![2, 3], &network);
        let ContractionTree { nodes, root, .. } = tree;

        let leaf0 = child_node(0, vec![0]);
        let leaf1 = child_node(1, vec![1]);
        let leaf2 = child_node(2, vec![2]);
        let leaf3 = child_node(3, vec![3]);
        let leaf4 = child_node(4, vec![4]);
        let leaf5 = child_node(5, vec![5]);

        let inner6 = parent_node(6, &leaf1, &leaf5);
        let inner7 = parent_node(7, &leaf2, &leaf3);
        let inner8 = parent_node(8, &inner6, &leaf0);
        let inner9 = parent_node(9, &inner8, &inner7);
        let top10 = parent_node(10, &inner9, &leaf4);

        let expected = [
            leaf0, leaf1, leaf2, leaf3, leaf4, leaf5, inner6, inner7, inner8, inner9, top10,
        ];
        expected
            .iter()
            .enumerate()
            .for_each(|(id, reference)| assert_eq!(&nodes[&id], reference));
        assert_eq!(root.upgrade().unwrap(), expected[10]);
    }

    /// Scores a tensor pair by the number of legs they share.
    fn custom_weight_function(a: &Tensor, b: &Tensor) -> f64 {
        (a & b).legs().len() as f64
    }

    #[test]
    fn test_find_rebalance_node() {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 1), (2, 3), (3, 5), (4, 3), (5, 8), (6, 7)]);
        let larger = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![0, 1, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![1, 2, 3], &bond_dims)),
            (2, Tensor::new_from_map(vec![3, 4, 5], &bond_dims)),
        ]);
        let smaller =
            FxHashMap::from_iter([(3, Tensor::new_from_map(vec![4, 5, 6], &bond_dims))]);

        let expected_node = 2;
        let (node_id, cost) = find_rebalance_node::<StdRng>(
            &mut None,
            &larger,
            &smaller,
            custom_weight_function,
        );
        assert_eq!(cost, 2.);
        assert_eq!(node_id, expected_node);
    }

    #[test]
    fn test_find_random_rebalance_node() {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 1), (2, 3), (3, 5), (4, 3), (5, 8), (6, 7)]);
        let larger = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![0, 1, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![1, 2, 6], &bond_dims)),
            (2, Tensor::new_from_map(vec![3, 4, 5], &bond_dims)),
        ]);
        let smaller =
            FxHashMap::from_iter([(3, Tensor::new_from_map(vec![4, 5, 6], &bond_dims))]);

        let expected_node = 1;
        let (node_id, cost) = find_rebalance_node(
            &mut Some((2, StdRng::seed_from_u64(1))),
            &larger,
            &smaller,
            custom_weight_function,
        );
        assert_eq!(cost, 1.);
        assert_eq!(node_id, expected_node);
    }
}