// tnc/contractionpath/contraction_tree/balancing.rs

1//! Functionality for greedy balancing of tensor network partitions.
2
3use core::f64;
4use std::rc::Rc;
5
6use itertools::Itertools;
7use log::debug;
8use rand::seq::IndexedRandom;
9use rand::{rngs::StdRng, Rng};
10use rustc_hash::FxHashMap;
11
12use crate::contractionpath::communication_schemes::CommunicationScheme;
13use crate::contractionpath::contraction_cost::communication_path_op_costs;
14use crate::contractionpath::contraction_tree::{
15    utils::{characterize_partition, subtree_contraction_path},
16    ContractionTree,
17};
18use crate::contractionpath::paths::validate_path;
19use crate::contractionpath::{ContractionPath, SimplePath};
20use crate::tensornetwork::tensor::{Tensor, TensorIndex};
21
22mod balancing_schemes;
23
24pub use balancing_schemes::BalancingScheme;
25
/// Configuration for the greedy partition-balancing procedure.
#[derive(Debug, Clone, Copy)]
pub struct BalanceSettings<R>
where
    R: Rng,
{
    /// If not None, randomly chooses from top `usize` options. Random choice is
    /// weighted by objective outcome.
    random_balance: Option<(usize, R)>,
    /// Depth in the contraction tree at which partitions are characterized and
    /// rebalanced (key into `ContractionTree::partitions`).
    rebalance_depth: usize,
    /// Number of balancing rounds to run.
    iterations: usize,
    /// Greedy objective: scores a candidate pairing of two tensors; higher
    /// scores are preferred when selecting nodes to shift.
    objective_function: fn(&Tensor, &Tensor) -> f64,
    /// Scheme used to derive the inter-partition communication path.
    communication_scheme: CommunicationScheme,
    /// Scheme used to select which nodes are shifted between partitions.
    balancing_scheme: BalancingScheme,
    /// If set, balancing stops once an iteration's memory cost exceeds this.
    memory_limit: Option<f64>,
}
41
42impl BalanceSettings<StdRng> {
43    pub fn new(
44        rebalance_depth: usize,
45        iterations: usize,
46        objective_function: fn(&Tensor, &Tensor) -> f64,
47        communication_scheme: CommunicationScheme,
48        balancing_scheme: BalancingScheme,
49        memory_limit: Option<f64>,
50    ) -> Self {
51        BalanceSettings::<StdRng> {
52            random_balance: None,
53            rebalance_depth,
54            iterations,
55            objective_function,
56            communication_scheme,
57            balancing_scheme,
58            memory_limit,
59        }
60    }
61}
62
63impl<R> BalanceSettings<R>
64where
65    R: Rng,
66{
67    pub fn new_random(
68        random_balance: Option<(usize, R)>,
69        rebalance_depth: usize,
70        iterations: usize,
71        objective_function: fn(&Tensor, &Tensor) -> f64,
72        communication_scheme: CommunicationScheme,
73        balancing_scheme: BalancingScheme,
74        memory_limit: Option<f64>,
75    ) -> Self {
76        BalanceSettings::<R> {
77            random_balance,
78            rebalance_depth,
79            iterations,
80            objective_function,
81            communication_scheme,
82            balancing_scheme,
83            memory_limit,
84        }
85    }
86}
87
/// Per-partition state cached across balancing iterations.
#[derive(Debug, Clone)]
pub(crate) struct PartitionData {
    /// Id of the partition's root node in the `ContractionTree`.
    pub id: usize,
    /// Flop cost of contracting this partition locally.
    pub flop_cost: f64,
    /// Memory cost of contracting this partition locally.
    pub mem_cost: f64,
    /// Contraction order within the partition.
    pub contraction: SimplePath,
    /// Tensor associated with the partition's root (via `ContractionTree::tensor`).
    pub local_tensor: Tensor,
}
96
/// Balances a partitioned tensor network to greedily optimize for a given objective.
///
/// Runs `balance_settings.iterations` rounds; each round shifts tensors
/// between partitions via `balance_partitions`, derives a fresh
/// inter-partition communication path via `communicate_partitions`, and
/// records the resulting `(flop, sum)` cost pair. The cheapest configuration
/// seen (by flop cost) is kept and returned.
///
/// # Returns
/// `(best_iteration, best_tensor_network, best_contraction_path, costs)`,
/// where `costs[0]` is the cost of the unmodified input and `best_iteration`
/// is `0` if no round improved on it.
///
/// # Panics
/// Panics if fewer than two partitions exist at
/// `balance_settings.rebalance_depth`.
pub fn balance_partitions_iter<R>(
    tensor_network: &Tensor,
    path: &ContractionPath,
    mut balance_settings: BalanceSettings<R>,
    rng: &mut R,
) -> (usize, Tensor, ContractionPath, Vec<(f64, f64)>)
where
    R: Rng,
{
    let mut contraction_tree = ContractionTree::from_contraction_path(tensor_network, path);

    let BalanceSettings {
        rebalance_depth,
        iterations,
        balancing_scheme,
        communication_scheme,
        memory_limit,
        ..
    } = balance_settings;
    // Per-partition ids, costs, local paths and tensors at `rebalance_depth`.
    let mut partition_data =
        characterize_partition(&contraction_tree, rebalance_depth, tensor_network);

    // Balancing needs at least two partitions to shift nodes between.
    assert!(partition_data.len() > 1);

    let partition_number = partition_data.len();

    let (partition_tensors, partition_costs): (Vec<_>, Vec<_>) = partition_data
        .iter()
        .map(
            |PartitionData {
                 local_tensor,
                 flop_cost,
                 ..
             }| (local_tensor.clone(), *flop_cost),
        )
        .collect();

    // Baseline cost of the input configuration (memory cost ignored here).
    let ((mut best_cost, sum_cost), _) = communication_path_op_costs(
        &partition_tensors,
        &path.toplevel,
        true,
        Some(&partition_costs),
    );

    let mut max_costs = Vec::with_capacity(iterations + 1);
    max_costs.push((best_cost, sum_cost));

    let mut best_iteration = 0;
    let mut best_contraction_path = path.to_owned();
    let mut best_tn = tensor_network.clone();

    for iteration in 1..=iterations {
        debug!("Balancing iteration {iteration} with balancing scheme {balancing_scheme:?}, communication scheme {communication_scheme:?}");

        // Balances and updates partitions
        let (nested_paths, new_tensor_network) = balance_partitions(
            &mut partition_data,
            &mut contraction_tree,
            tensor_network,
            &mut balance_settings,
            iteration,
        );

        assert_eq!(nested_paths.len(), partition_number, "Tensors lost!");

        // Ensures that children tensors are mapped to their respective partition costs
        // Communication costs include intermediate costs
        let communication_path = communicate_partitions(
            &partition_data,
            &mut contraction_tree,
            &new_tensor_network,
            &balance_settings,
            Some(rng),
        );

        let (partition_tensors, partition_costs): (Vec<_>, Vec<_>) = partition_data
            .iter()
            .map(
                |PartitionData {
                     local_tensor,
                     flop_cost,
                     ..
                 }| (local_tensor.clone(), *flop_cost),
            )
            .collect();

        let ((flop_cost, sum_cost), mem_cost) = communication_path_op_costs(
            &partition_tensors,
            &communication_path,
            true,
            Some(&partition_costs),
        );

        // NOTE(review): costs above are evaluated on the freshly derived
        // `communication_path`, yet the stored path reuses `path.toplevel` —
        // confirm the two stay consistent after rebalancing.
        let new_path = ContractionPath {
            nested: nested_paths,
            toplevel: path.toplevel.clone(),
        };
        validate_path(&new_path);

        max_costs.push((flop_cost, sum_cost));
        // Stop balancing entirely once the memory limit is exceeded.
        if memory_limit.is_some_and(|limit| mem_cost > limit) {
            break;
        }
        if flop_cost < best_cost {
            best_cost = flop_cost;
            best_iteration = iteration;
            best_tn = new_tensor_network;
            best_contraction_path = new_path;
        }
    }

    (best_iteration, best_tn, best_contraction_path, max_costs)
}
211
212fn communicate_partitions<R>(
213    partition_data: &[PartitionData],
214    contraction_tree: &mut ContractionTree,
215    tensor_network: &Tensor,
216    balance_settings: &BalanceSettings<R>,
217    rng: Option<&mut R>,
218) -> SimplePath
219where
220    R: Rng,
221{
222    let communication_scheme = balance_settings.communication_scheme;
223    let children_tensors = tensor_network
224        .tensors()
225        .iter()
226        .map(Tensor::external_tensor)
227        .collect_vec();
228    let latency_map = partition_data
229        .iter()
230        .enumerate()
231        .map(|(i, partition)| (i, partition.flop_cost))
232        .collect::<FxHashMap<_, _>>();
233
234    let partition_ids = partition_data
235        .iter()
236        .map(|partition| partition.id)
237        .collect_vec();
238    let communication_path =
239        communication_scheme.communication_path(&children_tensors, &latency_map, rng);
240
241    contraction_tree.replace_communication_path(partition_ids, &communication_path);
242
243    communication_path
244}
245
/// Performs a single balancing round over the partitions.
///
/// Selects node shifts according to `balance_settings.balancing_scheme`,
/// applies each shift via `shift_node_between_subtrees`, refreshes the
/// affected [`PartitionData`] entries, and re-registers the partition roots at
/// `rebalance_depth`.
///
/// # Returns
/// The per-partition contraction paths (keyed by the partition's rank after
/// sorting by flop cost) together with a new tensor network whose children
/// are the rebalanced partitions.
///
/// # Panics
/// Panics if the tensor network holds fewer than 3 tensors or if fewer than
/// two partitions are supplied.
fn balance_partitions<R>(
    partition_data: &mut [PartitionData],
    contraction_tree: &mut ContractionTree,
    tensor_network: &Tensor,
    balance_settings: &mut BalanceSettings<R>,
    iteration: usize,
) -> (FxHashMap<TensorIndex, ContractionPath>, Tensor)
where
    R: Rng,
{
    let BalanceSettings {
        ref mut random_balance,
        rebalance_depth,
        objective_function,
        balancing_scheme,
        ..
    } = balance_settings;
    // If there are fewer than 3 tensors in the tn, rebalancing will not make sense.
    if tensor_network.total_num_tensors() < 3 {
        // TODO: should not panic, but handle gracefully
        panic!("No rebalancing undertaken, as tn is too small (< 3 tensors)");
    }
    // Will cause strange errors (picking of same partition multiple times) if
    // this is not true. Better to panic here.
    assert!(partition_data.len() > 1);

    // Schemes rely on partitions being ordered by ascending flop cost.
    partition_data.sort_unstable_by(|a, b| a.flop_cost.total_cmp(&b.flop_cost));

    // Dispatch to the configured scheme; each returns the list of node shifts
    // (source subtree, destination subtree, leaves to move) to apply.
    let shifted_nodes = match balancing_scheme {
        BalancingScheme::BestWorst => balancing_schemes::best_worst(
            partition_data,
            contraction_tree,
            random_balance,
            *objective_function,
            tensor_network,
        ),
        BalancingScheme::Tensor => balancing_schemes::best_tensor(
            partition_data,
            contraction_tree,
            random_balance,
            *objective_function,
            tensor_network,
        ),
        BalancingScheme::Tensors => balancing_schemes::best_tensors(
            partition_data,
            contraction_tree,
            random_balance,
            *objective_function,
            tensor_network,
        ),
        // Alternating schemes switch strategy on odd/even iterations.
        BalancingScheme::AlternatingTensors => {
            if iteration % 2 == 1 {
                balancing_schemes::tensors_odd(
                    partition_data,
                    contraction_tree,
                    random_balance,
                    *objective_function,
                    tensor_network,
                )
            } else {
                balancing_schemes::tensors_even(
                    partition_data,
                    contraction_tree,
                    random_balance,
                    *objective_function,
                    tensor_network,
                )
            }
        }
        BalancingScheme::IntermediateTensors { height_limit } => {
            balancing_schemes::best_intermediate_tensors(
                partition_data,
                contraction_tree,
                random_balance,
                *objective_function,
                tensor_network,
                *height_limit,
            )
        }
        BalancingScheme::AlternatingIntermediateTensors { height_limit } => {
            if iteration % 2 == 1 {
                balancing_schemes::intermediate_tensors_odd(
                    partition_data,
                    contraction_tree,
                    random_balance,
                    *objective_function,
                    tensor_network,
                    *height_limit,
                )
            } else {
                balancing_schemes::intermediate_tensors_even(
                    partition_data,
                    contraction_tree,
                    random_balance,
                    *objective_function,
                    tensor_network,
                    *height_limit,
                )
            }
        }
        BalancingScheme::AlternatingTreeTensors { height_limit } => {
            if iteration % 2 == 1 {
                balancing_schemes::tree_tensors_odd(
                    partition_data,
                    contraction_tree,
                    *objective_function,
                    tensor_network,
                    *height_limit,
                )
            } else {
                balancing_schemes::tree_tensors_even(
                    partition_data,
                    contraction_tree,
                    *objective_function,
                    tensor_network,
                    *height_limit,
                )
            }
        }
    };
    // Maps a subtree's pre-shift id to its current id: an earlier shift in
    // this round may have renumbered a subtree referenced by a later shift.
    let mut shifted_indices = FxHashMap::default();
    for shift in shifted_nodes {
        let shifted_from_id = *shifted_indices
            .get(&shift.from_subtree_id)
            .unwrap_or(&shift.from_subtree_id);

        let shifted_to_id = *shifted_indices
            .get(&shift.to_subtree_id)
            .unwrap_or(&shift.to_subtree_id);

        let (
            larger_id,
            larger_contraction,
            larger_subtree_flop_cost,
            larger_subtree_mem_cost,
            smaller_id,
            smaller_contraction,
            smaller_subtree_flop_cost,
            smaller_subtree_mem_cost,
        ) = shift_node_between_subtrees(
            contraction_tree,
            *rebalance_depth,
            shifted_from_id,
            shifted_to_id,
            shift.moved_leaf_ids,
            tensor_network,
        );
        shifted_indices.insert(shift.from_subtree_id, larger_id);
        shifted_indices.insert(shift.to_subtree_id, smaller_id);

        let larger_tensor = contraction_tree.tensor(larger_id, tensor_network);
        let smaller_tensor = contraction_tree.tensor(smaller_id, tensor_network);

        // Update partition data based on shift
        for PartitionData {
            id,
            flop_cost,
            mem_cost,
            contraction: subtree_contraction,
            local_tensor,
        } in partition_data.iter_mut()
        {
            if *id == shifted_from_id {
                *id = larger_id;
                subtree_contraction.clone_from(&larger_contraction);
                *flop_cost = larger_subtree_flop_cost;
                *mem_cost = larger_subtree_mem_cost;
                *local_tensor = larger_tensor.clone();
            } else if *id == shifted_to_id {
                *id = smaller_id;
                subtree_contraction.clone_from(&smaller_contraction);
                *flop_cost = smaller_subtree_flop_cost;
                *mem_cost = smaller_subtree_mem_cost;
                *local_tensor = smaller_tensor.clone();
            }
        }
    }

    // Re-sort so the returned paths are keyed by ascending-flop-cost rank.
    partition_data.sort_unstable_by(
        |PartitionData {
             flop_cost: cost_a, ..
         },
         PartitionData {
             flop_cost: cost_b, ..
         }| { cost_a.total_cmp(cost_b) },
    );

    let mut rebalanced_paths = FxHashMap::default();
    // Rebuild each partition's composite tensor from its current leaves and
    // collect the updated root ids.
    let (partition_tensors, partition_ids): (Vec<_>, Vec<_>) = partition_data
        .iter()
        .enumerate()
        .map(
            |(
                i,
                PartitionData {
                    id,
                    contraction: subtree_contraction,
                    ..
                },
            )| {
                rebalanced_paths.insert(i, ContractionPath::simple(subtree_contraction.clone()));
                let mut child_tensor = Tensor::default();
                let leaf_ids = contraction_tree.leaf_ids(*id);
                let leaf_tensors = leaf_ids
                    .iter()
                    .map(|node_id| {
                        let nested_indices = contraction_tree
                            .node(*node_id)
                            .tensor_index()
                            .cloned()
                            .unwrap();
                        tensor_network.nested_tensor(&nested_indices).clone()
                    })
                    .collect_vec();
                child_tensor.push_tensors(leaf_tensors);
                (child_tensor, *id)
            },
        )
        .collect();

    // Replace the partition roots registered at this depth wholesale.
    contraction_tree
        .partitions
        .insert(*rebalance_depth, partition_ids);

    let mut updated_tn = Tensor::default();
    updated_tn.push_tensors(partition_tensors);
    (rebalanced_paths, updated_tn)
}
473
474/// Takes two hashmaps that contain node information. Identifies which pair of nodes from larger and smaller hashmaps maximizes the greedy cost function and returns the node from the `larger_subtree_nodes`.
475///
476/// # Arguments
477/// * `random_balance` - Allows for random selection of balanced node. If not None, identifies the best `usize` options and randomly selects one by weighted choice.
478/// * `larger_subtree_nodes` - A set of nodes used in comparison. Only the id from the larger subtree is returned.
479/// * `smaller_subtree_nodes` - A set of nodes used in comparison.
480/// * `objective_function` - Cost function that takes in two tensors and returns an f64 cost.
481fn find_rebalance_node<R>(
482    random_balance: &mut Option<(usize, R)>,
483    larger_subtree_nodes: &FxHashMap<usize, Tensor>,
484    smaller_subtree_nodes: &FxHashMap<usize, Tensor>,
485    objective_function: fn(&Tensor, &Tensor) -> f64,
486) -> (usize, f64)
487where
488    R: Rng,
489{
490    let node_comparison = larger_subtree_nodes
491        .iter()
492        .cartesian_product(smaller_subtree_nodes.iter())
493        .map(|((larger_node_id, larger_tensor), (_, smaller_tensor))| {
494            (
495                *larger_node_id,
496                objective_function(larger_tensor, smaller_tensor),
497            )
498        });
499    if let Some((options_considered, ref mut rng)) = random_balance {
500        let node_options = node_comparison
501            .sorted_unstable_by(|a, b| b.1.total_cmp(&a.1))
502            .take(*options_considered)
503            .collect_vec();
504        let max = node_options.first().unwrap().1;
505        // Initial division done here as sum of weights can cause overflow before normalization.
506        *node_options
507            .choose_weighted(rng, |node_option| node_option.1 / max)
508            .unwrap()
509    } else {
510        node_comparison.max_by(|a, b| a.1.total_cmp(&b.1)).unwrap()
511    }
512}
513
/// Shifts `rebalanced_nodes` from the larger subtree to the smaller subtree.
/// Updates partition tensor ids after subtrees are updated and a new contraction order is found.
///
/// Both subtrees are torn down and rebuilt from their updated leaf sets, and
/// the partition list at `rebalance_depth` is patched to the new root ids.
///
/// # Returns
/// `(new_larger_id, larger_local_path, larger_flop_cost, larger_mem_cost,
///   new_smaller_id, smaller_local_path, smaller_flop_cost, smaller_mem_cost)`.
///
/// # Panics
/// Panics if any rebalanced node is missing from the larger subtree or is
/// already present in the smaller one, or if either subtree lacks a parent.
fn shift_node_between_subtrees(
    contraction_tree: &mut ContractionTree,
    rebalance_depth: usize,
    larger_subtree_id: usize,
    smaller_subtree_id: usize,
    rebalanced_nodes: Vec<usize>,
    tensor_network: &Tensor,
) -> (usize, SimplePath, f64, f64, usize, SimplePath, f64, f64) {
    // Obtain parents of the two subtrees that are being updated.
    let larger_subtree_parent_id = contraction_tree
        .node(larger_subtree_id)
        .parent_id()
        .unwrap();
    let smaller_subtree_parent_id = contraction_tree
        .node(smaller_subtree_id)
        .parent_id()
        .unwrap();

    let mut larger_subtree_leaf_nodes = contraction_tree.leaf_ids(larger_subtree_id);
    let mut smaller_subtree_leaf_nodes = contraction_tree.leaf_ids(smaller_subtree_id);

    // Always check that a node can be moved over.
    assert!(rebalanced_nodes
        .iter()
        .all(|node| !smaller_subtree_leaf_nodes.contains(node)));
    assert!(rebalanced_nodes
        .iter()
        .all(|node| larger_subtree_leaf_nodes.contains(node)));

    // Remove selected tensors from bigger subtree. Add it to the smaller subtree
    larger_subtree_leaf_nodes.retain(|leaf| !rebalanced_nodes.contains(leaf));
    smaller_subtree_leaf_nodes.extend(rebalanced_nodes);

    // Run Greedy on the two updated subtrees
    let (updated_larger_path, local_larger_path, larger_flop_cost, larger_mem_cost) =
        subtree_contraction_path(&larger_subtree_leaf_nodes, contraction_tree, tensor_network);

    let (updated_smaller_path, local_smaller_path, smaller_flop_cost, smaller_mem_cost) =
        subtree_contraction_path(
            &smaller_subtree_leaf_nodes,
            contraction_tree,
            tensor_network,
        );

    // Remove larger subtree and add new subtree, keep track of updated root id
    contraction_tree.remove_subtree(larger_subtree_id);

    let new_larger_subtree_id = if updated_larger_path.is_empty() {
        // In this case, there is only one node left: reattach that single
        // leaf directly to the old subtree's parent.
        contraction_tree.nodes[&larger_subtree_leaf_nodes[0]]
            .borrow_mut()
            .set_parent(Rc::downgrade(
                &contraction_tree.nodes[&larger_subtree_parent_id],
            ));
        contraction_tree.nodes[&larger_subtree_parent_id]
            .borrow_mut()
            .add_child(Rc::downgrade(
                &contraction_tree.nodes[&larger_subtree_leaf_nodes[0]],
            ));
        larger_subtree_leaf_nodes[0]
    } else {
        contraction_tree.add_path_as_subtree(
            &ContractionPath::simple(updated_larger_path),
            larger_subtree_parent_id,
            &larger_subtree_leaf_nodes,
        )
    };

    // Remove smaller subtree
    contraction_tree.remove_subtree(smaller_subtree_id);
    // Add new subtree, keep track of updated root id
    // (the smaller side gained nodes, so it always has >= 2 leaves here).
    let new_smaller_subtree_id = contraction_tree.add_path_as_subtree(
        &ContractionPath::simple(updated_smaller_path),
        smaller_subtree_parent_id,
        &smaller_subtree_leaf_nodes,
    );

    // Remove the old partition ids from the `ContractionTree` partitions member as the intermediate tensor id will be updated and then add the updated partition numbers.`
    let partition = contraction_tree
        .partitions
        .get_mut(&rebalance_depth)
        .unwrap();
    partition.retain(|&e| e != smaller_subtree_id && e != larger_subtree_id);
    partition.push(new_larger_subtree_id);
    partition.push(new_smaller_subtree_id);

    (
        new_larger_subtree_id,
        local_larger_path,
        larger_flop_cost,
        larger_mem_cost,
        new_smaller_subtree_id,
        local_smaller_path,
        smaller_flop_cost,
        smaller_mem_cost,
    )
}
613
#[cfg(test)]
mod tests {
    use super::*;

    use std::rc::Rc;

    use rand::{rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use crate::contractionpath::contraction_tree::{
        balancing::find_rebalance_node,
        node::{child_node, parent_node},
        ContractionTree,
    };
    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Builds a six-tensor network with a fixed contraction path and the
    /// matching contraction tree used by the shift tests below.
    fn setup_complex() -> (ContractionTree, Tensor) {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
        ]);
        let (tensor, contraction_path) = (
            Tensor::new_composite(vec![
                Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
                Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
                Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
                Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
            ]),
            path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)],
        );
        (
            ContractionTree::from_contraction_path(&tensor, &contraction_path),
            tensor,
        )
    }

    /// Moving a single leaf between subtrees must produce the expected
    /// rebuilt tree structure.
    #[test]
    fn test_shift_leaf_node_between_subtrees() {
        let (mut tree, tensor) = setup_complex();
        tree.partitions.entry(1).or_insert_with(|| vec![9, 7]);
        // Move leaf 3 from the subtree rooted at 9 into the one rooted at 7.
        shift_node_between_subtrees(&mut tree, 1, 9, 7, vec![3], &tensor);

        let ContractionTree { nodes, root, .. } = tree;

        // Expected tree after the shift, built bottom-up.
        let node0 = child_node(0, vec![0]);
        let node1 = child_node(1, vec![1]);
        let node2 = child_node(2, vec![2]);
        let node3 = child_node(3, vec![3]);
        let node4 = child_node(4, vec![4]);
        let node5 = child_node(5, vec![5]);

        let node6 = parent_node(6, &node1, &node5);
        let node7 = parent_node(7, &node6, &node0);
        let node8 = parent_node(8, &node2, &node4);
        let node9 = parent_node(9, &node7, &node3);
        let node10 = parent_node(10, &node9, &node8);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }
        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    /// Moving two leaves at once (a whole subtree's worth) must likewise
    /// rebuild both subtrees correctly.
    #[test]
    fn test_shift_subtree_between_subtrees() {
        let (mut tree, tensor) = setup_complex();
        tree.partitions.entry(1).or_insert_with(|| vec![9, 7]);
        // Move leaves 2 and 3 from the subtree rooted at 9 into subtree 7.
        shift_node_between_subtrees(&mut tree, 1, 9, 7, vec![2, 3], &tensor);

        let ContractionTree { nodes, root, .. } = tree;

        // Expected tree after the shift, built bottom-up.
        let node0 = child_node(0, vec![0]);
        let node1 = child_node(1, vec![1]);
        let node2 = child_node(2, vec![2]);
        let node3 = child_node(3, vec![3]);
        let node4 = child_node(4, vec![4]);
        let node5 = child_node(5, vec![5]);

        let node6 = parent_node(6, &node1, &node5);
        let node7 = parent_node(7, &node2, &node3);
        let node8 = parent_node(8, &node6, &node0);
        let node9 = parent_node(9, &node8, &node7);
        let node10 = parent_node(10, &node9, &node4);

        let ref_root = Rc::clone(&node10);
        let ref_nodes = [
            node0, node1, node2, node3, node4, node5, node6, node7, node8, node9, node10,
        ];

        for (key, ref_node) in ref_nodes.iter().enumerate() {
            let node = &nodes[&key];
            assert_eq!(node, ref_node);
        }

        assert_eq!(root.upgrade().unwrap(), ref_root);
    }

    /// Test objective: the number of legs of `a & b` — presumably the legs
    /// shared between the two tensors; verify against `Tensor`'s `BitAnd`.
    fn custom_weight_function(a: &Tensor, b: &Tensor) -> f64 {
        (a & b).legs().len() as f64
    }

    /// Deterministic selection: the candidate maximizing the objective wins.
    #[test]
    fn test_find_rebalance_node() {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 1), (2, 3), (3, 5), (4, 3), (5, 8), (6, 7)]);
        let larger_hash = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![0, 1, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![1, 2, 3], &bond_dims)),
            (2, Tensor::new_from_map(vec![3, 4, 5], &bond_dims)),
        ]);

        let smaller_hash =
            FxHashMap::from_iter([(3, Tensor::new_from_map(vec![4, 5, 6], &bond_dims))]);

        let ref_balanced_node = 2;
        let (node_id, cost) = find_rebalance_node::<StdRng>(
            &mut None,
            &larger_hash,
            &smaller_hash,
            custom_weight_function,
        );
        assert_eq!(cost, 2.);
        assert_eq!(node_id, ref_balanced_node);
    }

    /// Randomized selection: with a seeded RNG and top-2 weighted choice, a
    /// non-maximal candidate can be picked reproducibly.
    #[test]
    fn test_find_random_rebalance_node() {
        let bond_dims =
            FxHashMap::from_iter([(0, 2), (1, 1), (2, 3), (3, 5), (4, 3), (5, 8), (6, 7)]);
        let larger_hash = FxHashMap::from_iter([
            (0, Tensor::new_from_map(vec![0, 1, 2], &bond_dims)),
            (1, Tensor::new_from_map(vec![1, 2, 6], &bond_dims)),
            (2, Tensor::new_from_map(vec![3, 4, 5], &bond_dims)),
        ]);

        let smaller_hash =
            FxHashMap::from_iter([(3, Tensor::new_from_map(vec![4, 5, 6], &bond_dims))]);

        let ref_balanced_node = 1;
        let (node_id, cost) = find_rebalance_node(
            &mut Some((2, StdRng::seed_from_u64(1))),
            &larger_hash,
            &smaller_hash,
            custom_weight_function,
        );
        assert_eq!(cost, 1.);
        assert_eq!(node_id, ref_balanced_node);
    }
}