tnc/contractionpath/contraction_tree/balancing/balancing_schemes.rs

1use rand::Rng;
2use rustc_hash::FxHashMap;
3
4use crate::contractionpath::contraction_tree::balancing::{find_rebalance_node, PartitionData};
5use crate::contractionpath::contraction_tree::{
6    populate_leaf_node_tensor_map, populate_subtree_tensor_map, ContractionTree,
7};
8use crate::tensornetwork::tensor::Tensor;
9
/// The scheme used for greedy balancing of partitions.
///
/// Throughout, the "slowest" subtree is the most expensive partition and the "fastest"
/// subtree is the cheapest one.
#[derive(Debug, Clone, Copy)]
pub enum BalancingScheme {
    /// Moves a tensor from the slowest subtree to the fastest subtree each time.
    BestWorst,

    /// Identifies the tensor in the slowest subtree and passes it to the subtree
    /// with largest memory reduction.
    Tensor,

    /// Identifies the tensor in the slowest subtree and passes it to the subtree
    /// with largest memory reduction. Then identifies the tensor with the largest
    /// memory reduction when passed to the fastest subtree. Both slowest and fastest
    /// subtrees are updated.
    Tensors,

    /// Identifies the tensor in the slowest subtree and passes it to the subtree
    /// with largest memory reduction for odd iterations, or the tensor with the largest
    /// memory reduction when passed to the fastest subtree for even iterations.
    AlternatingTensors,

    /// Identifies the intermediate tensor in the slowest subtree and passes it to
    /// the subtree with largest memory reduction. Then identifies the intermediate
    /// tensor with the largest memory reduction when passed to the fastest subtree.
    /// Both slowest and fastest subtrees are updated.
    IntermediateTensors {
        /// The `height` up the contraction tree we look when passing intermediate
        /// tensors between partitions. A value of `Some(1)` allows intermediate tensors
        /// that are a product of at most 1 contraction process. Using the value of
        /// `Some(0)` is then equivalent to the `Tensors` method. Setting it to `None`
        /// imposes no height limit.
        height_limit: Option<usize>,
    },

    /// Identifies the intermediate tensor in the slowest subtree and passes it to
    /// the subtree with largest memory reduction for odd iterations. Identifies the
    /// intermediate tensor with the largest memory reduction when passed to the fastest
    /// subtree for even iterations.
    AlternatingIntermediateTensors {
        /// The `height` up the contraction tree we look when passing intermediate
        /// tensors between partitions. A value of `Some(1)` allows intermediate tensors
        /// that are a product of at most 1 contraction process. Using the value of
        /// `Some(0)` is then equivalent to the `Tensors` method. Setting it to `None`
        /// imposes no height limit.
        height_limit: Option<usize>,
    },

    /// Identifies the intermediate tensor in the slowest subtree and passes it to
    /// the subtree with largest memory reduction for odd iterations. Identifies the
    /// intermediate tensor with the largest memory reduction when passed to the fastest
    /// subtree for even iterations.
    AlternatingTreeTensors {
        /// The `height` up the contraction tree we look when passing intermediate
        /// tensors between partitions. A value of `1` allows intermediate tensors
        /// that are a product of at most 1 contraction process. Using the value of
        /// `0` is then equivalent to the `Tensors` method.
        height_limit: usize,
    },
}
69
70/// Shift of tensors between partitions.
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub(super) struct Shift {
73    /// Id of the source partition.
74    pub from_subtree_id: usize,
75    /// Id of the destination partition.
76    pub to_subtree_id: usize,
77    /// Ids of the leaf nodes that are moved.
78    pub moved_leaf_ids: Vec<usize>,
79}
80
81/// Balancing scheme that moves a tensor from the slowest subtree to the fastest subtree each time.
82/// Chosen tensor maximizes the `objective_function`, which is typically memory reduction.
83pub(super) fn best_worst<R>(
84    partition_data: &[PartitionData],
85    contraction_tree: &ContractionTree,
86    random_balance: &mut Option<(usize, R)>,
87    objective_function: fn(&Tensor, &Tensor) -> f64,
88    tensor: &Tensor,
89) -> Vec<Shift>
90where
91    R: Rng,
92{
93    // Obtain most expensive and cheapest partitions
94    let larger_subtree_id = partition_data.last().unwrap().id;
95    let smaller_subtree_id = partition_data.first().unwrap().id;
96
97    let larger_subtree_leaf_nodes =
98        populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
99
100    let smaller_subtree_leaf_nodes =
101        populate_leaf_node_tensor_map(contraction_tree, smaller_subtree_id, tensor);
102
103    let (rebalanced_node, _) = find_rebalance_node(
104        random_balance,
105        &larger_subtree_leaf_nodes,
106        &smaller_subtree_leaf_nodes,
107        objective_function,
108    );
109    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
110    vec![Shift {
111        from_subtree_id: larger_subtree_id,
112        to_subtree_id: smaller_subtree_id,
113        moved_leaf_ids: rebalanced_leaf_ids,
114    }]
115}
116
117/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
118/// Chosen tensor maximizes the `objective_function`, which is typically memory reduction.
119pub(super) fn best_tensor<R>(
120    partition_data: &[PartitionData],
121    contraction_tree: &ContractionTree,
122    random_balance: &mut Option<(usize, R)>,
123    objective_function: fn(&Tensor, &Tensor) -> f64,
124    tensor: &Tensor,
125) -> Vec<Shift>
126where
127    R: Rng,
128{
129    // Obtain most expensive partitions
130    let larger_subtree_id = partition_data.last().unwrap().id;
131
132    let larger_subtree_leaf_nodes =
133        populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
134    // Find the subtree shift that results in the largest memory savings
135    let (smaller_subtree_id, rebalanced_node, _) = partition_data
136        .iter()
137        .take(partition_data.len() - 1)
138        .map(|smaller| {
139            let smaller_subtree_nodes =
140                populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
141            let (rebalanced_node, objective) = find_rebalance_node(
142                random_balance,
143                &larger_subtree_leaf_nodes,
144                &smaller_subtree_nodes,
145                objective_function,
146            );
147            (smaller.id, rebalanced_node, objective)
148        })
149        .max_by(|a, b| a.2.total_cmp(&b.2))
150        .unwrap();
151
152    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
153    vec![Shift {
154        from_subtree_id: larger_subtree_id,
155        to_subtree_id: smaller_subtree_id,
156        moved_leaf_ids: rebalanced_leaf_ids,
157    }]
158}
159
160/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
161/// Then identifies the tensor with the largest memory reduction when passed to the fastest subtree. Both slowest and fastest subtrees are updated.
162pub(super) fn best_tensors<R>(
163    partition_data: &[PartitionData],
164    contraction_tree: &ContractionTree,
165    random_balance: &mut Option<(usize, R)>,
166    objective_function: fn(&Tensor, &Tensor) -> f64,
167    tensor: &Tensor,
168) -> Vec<Shift>
169where
170    R: Rng,
171{
172    // Obtain most expensive and cheapest partitions
173    let larger_subtree_id = partition_data.last().unwrap().id;
174
175    let larger_subtree_leaf_nodes =
176        populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
177
178    // Find the subtree shift that results in the largest memory savings
179    let (smaller_subtree_id, rebalanced_node, _) = partition_data
180        .iter()
181        .take(partition_data.len() - 1)
182        .map(|smaller| {
183            let smaller_subtree_nodes =
184                populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
185            let (rebalanced_node, objective) = find_rebalance_node(
186                random_balance,
187                &larger_subtree_leaf_nodes,
188                &smaller_subtree_nodes,
189                objective_function,
190            );
191            (smaller.id, rebalanced_node, objective)
192        })
193        .max_by(|a, b| a.2.total_cmp(&b.2))
194        .unwrap();
195    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
196
197    let mut shifts = Vec::with_capacity(2);
198    shifts.push(Shift {
199        from_subtree_id: larger_subtree_id,
200        to_subtree_id: smaller_subtree_id,
201        moved_leaf_ids: rebalanced_leaf_ids,
202    });
203
204    let smaller_subtree_id = partition_data.first().unwrap().id;
205
206    let smaller_subtree_nodes =
207        populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
208
209    let (larger_subtree_id, rebalanced_node, _) = partition_data
210        .iter()
211        .skip(1)
212        .take(partition_data.len() - 2)
213        .map(|larger| {
214            let larger_subtree_nodes =
215                populate_leaf_node_tensor_map(contraction_tree, larger.id, tensor);
216            let (rebalanced_node, objective) = find_rebalance_node(
217                random_balance,
218                &larger_subtree_nodes,
219                &smaller_subtree_nodes,
220                objective_function,
221            );
222
223            (larger.id, rebalanced_node, objective)
224        })
225        .max_by(|(_, _, obj_a), (_, _, obj_b)| obj_a.total_cmp(obj_b))
226        .unwrap();
227
228    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
229    shifts.push(Shift {
230        from_subtree_id: larger_subtree_id,
231        to_subtree_id: smaller_subtree_id,
232        moved_leaf_ids: rebalanced_leaf_ids,
233    });
234    shifts
235}
236
237/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
238pub(super) fn tensors_odd<R>(
239    partition_data: &[PartitionData],
240    contraction_tree: &ContractionTree,
241    random_balance: &mut Option<(usize, R)>,
242    objective_function: fn(&Tensor, &Tensor) -> f64,
243    tensor: &Tensor,
244) -> Vec<Shift>
245where
246    R: Rng,
247{
248    // Obtain most expensive partition
249    let larger_subtree_id = partition_data.last().unwrap().id;
250
251    let larger_subtree_leaf_nodes =
252        populate_leaf_node_tensor_map(contraction_tree, larger_subtree_id, tensor);
253
254    // Find the subtree shift that results in the largest memory savings
255    let (smaller_subtree_id, rebalanced_leaf_node, _) = partition_data
256        .iter()
257        .take(partition_data.len() - 1)
258        .map(|smaller| {
259            let smaller_subtree_nodes = FxHashMap::from_iter([(0, smaller.local_tensor.clone())]);
260            let (rebalanced_node, objective) = find_rebalance_node(
261                random_balance,
262                &larger_subtree_leaf_nodes,
263                &smaller_subtree_nodes,
264                objective_function,
265            );
266            (smaller.id, rebalanced_node, objective)
267        })
268        .max_by(|a, b| a.2.total_cmp(&b.2))
269        .unwrap();
270
271    let rebalanced_leaf_id = contraction_tree.leaf_ids(rebalanced_leaf_node);
272    vec![Shift {
273        from_subtree_id: larger_subtree_id,
274        to_subtree_id: smaller_subtree_id,
275        moved_leaf_ids: rebalanced_leaf_id,
276    }]
277}
278
279/// Balancing scheme that identifies the tensor with the largest memory reduction when passed to the fastest subtree.
280pub(super) fn tensors_even<R>(
281    partition_data: &[PartitionData],
282    contraction_tree: &ContractionTree,
283    random_balance: &mut Option<(usize, R)>,
284    objective_function: fn(&Tensor, &Tensor) -> f64,
285    tensor: &Tensor,
286) -> Vec<Shift>
287where
288    R: Rng,
289{
290    let smaller_subtree_id = partition_data.first().unwrap().id;
291
292    let smaller_subtree_nodes =
293        FxHashMap::from_iter([(0, partition_data.first().unwrap().local_tensor.clone())]);
294
295    let (larger_subtree_id, rebalanced_leaf_node, _) = partition_data
296        .iter()
297        .skip(1)
298        .map(|larger| {
299            let larger_subtree_leaf_nodes =
300                populate_leaf_node_tensor_map(contraction_tree, larger.id, tensor);
301            let (rebalanced_node, objective) = find_rebalance_node(
302                random_balance,
303                &larger_subtree_leaf_nodes,
304                &smaller_subtree_nodes,
305                objective_function,
306            );
307
308            (larger.id, rebalanced_node, objective)
309        })
310        .max_by(|a, b| a.2.total_cmp(&b.2))
311        .unwrap();
312
313    let rebalanced_leaf_id = contraction_tree.leaf_ids(rebalanced_leaf_node);
314    vec![Shift {
315        from_subtree_id: larger_subtree_id,
316        to_subtree_id: smaller_subtree_id,
317        moved_leaf_ids: rebalanced_leaf_id,
318    }]
319}
320
321/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
322/// Then identifies the tensor with the largest memory reduction when passed to the fastest subtree. Both slowest and fastest subtrees are updated.
323pub(super) fn best_intermediate_tensors<R>(
324    partition_data: &[PartitionData],
325    contraction_tree: &ContractionTree,
326    random_balance: &mut Option<(usize, R)>,
327    objective_function: fn(&Tensor, &Tensor) -> f64,
328    tensor: &Tensor,
329    height_limit: Option<usize>,
330) -> Vec<Shift>
331where
332    R: Rng,
333{
334    // Obtain most expensive partition
335    let larger_subtree_id = partition_data.last().unwrap().id;
336
337    // Obtain all intermediate nodes up to height `height_limit` in larger subtree
338    let mut larger_subtree_nodes =
339        populate_subtree_tensor_map(contraction_tree, larger_subtree_id, tensor, height_limit);
340    larger_subtree_nodes.remove(&larger_subtree_id);
341
342    // Find the subtree shift that results in the largest memory savings
343    let (smaller_subtree_id, first_rebalanced_node, _) = partition_data
344        .iter()
345        .take(partition_data.len() - 1)
346        .map(|smaller| {
347            let smaller_subtree_nodes =
348                populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
349            let (rebalanced_node, objective) = find_rebalance_node(
350                random_balance,
351                &larger_subtree_nodes,
352                &smaller_subtree_nodes,
353                objective_function,
354            );
355            (smaller.id, rebalanced_node, objective)
356        })
357        .max_by(|a, b| a.2.total_cmp(&b.2))
358        .unwrap();
359    let rebalanced_leaf_ids = contraction_tree.leaf_ids(first_rebalanced_node);
360
361    let mut shifts = Vec::with_capacity(2);
362    shifts.push(Shift {
363        from_subtree_id: larger_subtree_id,
364        to_subtree_id: smaller_subtree_id,
365        moved_leaf_ids: rebalanced_leaf_ids,
366    });
367
368    let smaller_subtree_id = partition_data.first().unwrap().id;
369
370    let smaller_subtree_nodes =
371        populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
372
373    let (larger_subtree_id, second_rebalanced_node, _) = partition_data
374        .iter()
375        .skip(1)
376        .take(partition_data.len() - 2)
377        .map(|larger| {
378            let mut larger_subtree_nodes =
379                populate_subtree_tensor_map(contraction_tree, larger.id, tensor, height_limit);
380            larger_subtree_nodes.remove(&larger.id);
381            let (rebalanced_node, objective) = find_rebalance_node(
382                random_balance,
383                &larger_subtree_nodes,
384                &smaller_subtree_nodes,
385                objective_function,
386            );
387
388            (larger.id, rebalanced_node, objective)
389        })
390        .max_by(|a, b| a.2.total_cmp(&b.2))
391        .unwrap();
392
393    let rebalanced_leaf_ids = contraction_tree.leaf_ids(second_rebalanced_node);
394    shifts.push(Shift {
395        from_subtree_id: larger_subtree_id,
396        to_subtree_id: smaller_subtree_id,
397        moved_leaf_ids: rebalanced_leaf_ids,
398    });
399    shifts
400}
401
402/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
403pub(super) fn intermediate_tensors_odd<R>(
404    partition_data: &[PartitionData],
405    contraction_tree: &ContractionTree,
406    random_balance: &mut Option<(usize, R)>,
407    objective_function: fn(&Tensor, &Tensor) -> f64,
408    tensor: &Tensor,
409    height_limit: Option<usize>,
410) -> Vec<Shift>
411where
412    R: Rng,
413{
414    // Obtain most expensive partition
415    let larger_subtree_id = partition_data.last().unwrap().id;
416
417    // Obtain all intermediate nodes up to height `height_limit` in larger subtree
418    let mut larger_subtree_nodes =
419        populate_subtree_tensor_map(contraction_tree, larger_subtree_id, tensor, height_limit);
420    larger_subtree_nodes.remove(&larger_subtree_id);
421
422    // Find the subtree shift that results in the largest memory savings
423    let (smaller_subtree_id, rebalanced_node, _) = partition_data
424        .iter()
425        .take(partition_data.len() - 1)
426        .map(|smaller| {
427            let smaller_subtree_nodes =
428                populate_subtree_tensor_map(contraction_tree, smaller.id, tensor, None);
429            let (rebalanced_node, objective) = find_rebalance_node(
430                random_balance,
431                &larger_subtree_nodes,
432                &smaller_subtree_nodes,
433                objective_function,
434            );
435            (smaller.id, rebalanced_node, objective)
436        })
437        .max_by(|a, b| a.2.total_cmp(&b.2))
438        .unwrap();
439
440    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
441    vec![Shift {
442        from_subtree_id: larger_subtree_id,
443        to_subtree_id: smaller_subtree_id,
444        moved_leaf_ids: rebalanced_leaf_ids,
445    }]
446}
447
448/// Balancing scheme that identifies the intermediate tensor with the largest memory reduction when passed to the fastest subtree.
449pub(super) fn intermediate_tensors_even<R>(
450    partition_data: &[PartitionData],
451    contraction_tree: &ContractionTree,
452    random_balance: &mut Option<(usize, R)>,
453    objective_function: fn(&Tensor, &Tensor) -> f64,
454    tensor: &Tensor,
455    height_limit: Option<usize>,
456) -> Vec<Shift>
457where
458    R: Rng,
459{
460    let smaller_subtree_id = partition_data.first().unwrap().id;
461
462    let smaller_subtree_nodes =
463        populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
464
465    let (larger_subtree_id, rebalanced_node, _) = partition_data
466        .iter()
467        .skip(1)
468        .filter_map(|larger| {
469            let mut larger_subtree_nodes =
470                populate_subtree_tensor_map(contraction_tree, larger.id, tensor, height_limit);
471            if larger_subtree_nodes.len() == 1 {
472                return None;
473            }
474            larger_subtree_nodes.remove(&larger.id);
475            let (rebalanced_node, objective) = find_rebalance_node(
476                random_balance,
477                &larger_subtree_nodes,
478                &smaller_subtree_nodes,
479                objective_function,
480            );
481
482            Some((larger.id, rebalanced_node, objective))
483        })
484        .max_by(|a, b| a.2.total_cmp(&b.2))
485        .unwrap();
486
487    let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
488    vec![Shift {
489        from_subtree_id: larger_subtree_id,
490        to_subtree_id: smaller_subtree_id,
491        moved_leaf_ids: rebalanced_leaf_ids,
492    }]
493}
494
495/// Balancing scheme that identifies the tensor in the slowest subtree and passes it to the subtree with largest memory reduction.
496pub(super) fn tree_tensors_odd(
497    partition_data: &[PartitionData],
498    contraction_tree: &ContractionTree,
499    objective_function: fn(&Tensor, &Tensor) -> f64,
500    tensor: &Tensor,
501    height_limit: usize,
502) -> Vec<Shift> {
503    // Obtain most expensive partition
504    let larger_subtree_id = partition_data.last().unwrap().id;
505
506    // Obtain all intermediate nodes up to height `height_limit` in larger subtree
507    let mut larger_subtree_nodes = populate_subtree_tensor_map(
508        contraction_tree,
509        larger_subtree_id,
510        tensor,
511        Some(height_limit),
512    );
513    larger_subtree_nodes.remove(&larger_subtree_id);
514
515    // Find the subtree shift that results in the largest memory savings
516    let (smaller_subtree_id, rebalanced_node, _) = partition_data
517        .iter()
518        .take(partition_data.len() - 1)
519        .map(|smaller| {
520            let PartitionData {
521                local_tensor, id, ..
522            } = smaller;
523            let mut objective = 0.;
524            let mut rebalanced_node = None;
525            for (node_id, node) in &larger_subtree_nodes {
526                let new_obj = objective_function(node, local_tensor);
527                if new_obj > objective {
528                    objective = new_obj;
529                    rebalanced_node = Some(*node_id);
530                }
531            }
532            (*id, rebalanced_node, objective)
533        })
534        .max_by(|a, b| a.2.total_cmp(&b.2))
535        .unwrap();
536    if let Some(rebalanced_node) = rebalanced_node {
537        let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
538        vec![Shift {
539            from_subtree_id: larger_subtree_id,
540            to_subtree_id: smaller_subtree_id,
541            moved_leaf_ids: rebalanced_leaf_ids,
542        }]
543    } else {
544        Vec::new()
545    }
546}
547
548/// Balancing scheme that identifies the intermediate tensor with the largest memory reduction when passed to the fastest subtree.
549pub(super) fn tree_tensors_even(
550    partition_data: &[PartitionData],
551    contraction_tree: &ContractionTree,
552    objective_function: fn(&Tensor, &Tensor) -> f64,
553    tensor: &Tensor,
554    height_limit: usize,
555) -> Vec<Shift> {
556    let smaller_subtree_id = partition_data.first().unwrap().id;
557
558    // let smaller_subtree_nodes =
559    //     populate_subtree_tensor_map(contraction_tree, smaller_subtree_id, tensor, None);
560    let PartitionData {
561        local_tensor: smaller_tensor,
562        ..
563    } = partition_data.first().unwrap();
564
565    let (larger_subtree_id, rebalanced_node, _) = partition_data
566        .iter()
567        .skip(1)
568        .filter_map(|larger| {
569            let mut larger_subtree_nodes = populate_subtree_tensor_map(
570                contraction_tree,
571                larger.id,
572                tensor,
573                Some(height_limit),
574            );
575            if larger_subtree_nodes.len() == 1 {
576                return None;
577            }
578            larger_subtree_nodes.remove(&larger.id);
579            let mut objective = 0.;
580            let mut rebalanced_node = None;
581            for (node_id, node) in &larger_subtree_nodes {
582                let new_obj = objective_function(node, smaller_tensor);
583                if new_obj > objective {
584                    objective = new_obj;
585                    rebalanced_node = Some(*node_id);
586                }
587            }
588
589            Some((larger.id, rebalanced_node, objective))
590        })
591        .max_by(|a, b| a.2.total_cmp(&b.2))
592        .unwrap();
593
594    if let Some(rebalanced_node) = rebalanced_node {
595        let rebalanced_leaf_ids = contraction_tree.leaf_ids(rebalanced_node);
596        vec![Shift {
597            from_subtree_id: larger_subtree_id,
598            to_subtree_id: smaller_subtree_id,
599            moved_leaf_ids: rebalanced_leaf_ids,
600        }]
601    } else {
602        Vec::new()
603    }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    use rand::rngs::StdRng;
    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::contraction_tree::{balancing::PartitionData, ContractionTree},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Partition data for the tree built by `setup_simple`, ordered cheapest first.
    ///
    /// Each `local_tensor` carries the open legs of the corresponding partition:
    /// partition 2 holds leaves 0-1, partition 7 holds leaves 3-5, and partition 14
    /// holds leaves 8-11.
    fn setup_simple_partition_data() -> Vec<PartitionData> {
        let bond_dims = FxHashMap::from_iter([
            (0, 2),
            (1, 2),
            (2, 2),
            (3, 2),
            (4, 2),
            (5, 2),
            (6, 2),
            (7, 2),
            (8, 2),
            (9, 2),
            (10, 2),
        ]);
        vec![
            PartitionData {
                id: 2,
                flop_cost: 1.,
                mem_cost: 0.,
                contraction: Default::default(),
                local_tensor: Tensor::new_from_map(vec![7, 9, 10], &bond_dims),
            },
            PartitionData {
                id: 7,
                flop_cost: 2.,
                mem_cost: 0.,
                contraction: Default::default(),
                local_tensor: Tensor::new_from_map(vec![0, 1, 5, 7], &bond_dims),
            },
            PartitionData {
                id: 14,
                flop_cost: 3.,
                mem_cost: 0.,
                contraction: Default::default(),
                local_tensor: Tensor::new_from_map(vec![0, 1, 2, 5, 10], &bond_dims),
            },
        ]
    }

    /// Tensor ids in contraction tree included in variable name for easy tracking
    fn setup_simple() -> (ContractionTree, Tensor) {
        let bond_dims = FxHashMap::from_iter([
            (0, 2),
            (1, 2),
            (2, 2),
            (3, 2),
            (4, 2),
            (5, 2),
            (6, 2),
            (7, 2),
            (8, 2),
            (9, 2),
            (10, 2),
        ]);

        let tensor0 = Tensor::new_from_map(vec![7, 8], &bond_dims);
        let tensor1 = Tensor::new_from_map(vec![8, 9, 10], &bond_dims);

        let tensor3 = Tensor::new_from_map(vec![0, 6], &bond_dims);
        let tensor4 = Tensor::new_from_map(vec![1, 6], &bond_dims);
        let tensor5 = Tensor::new_from_map(vec![5, 7], &bond_dims);

        let tensor8 = Tensor::new_from_map(vec![0, 1], &bond_dims);
        let tensor9 = Tensor::new_from_map(vec![2, 3], &bond_dims);
        let tensor10 = Tensor::new_from_map(vec![3, 4], &bond_dims);
        let tensor11 = Tensor::new_from_map(vec![4, 5, 10], &bond_dims);

        let intermediate_tensor2 = Tensor::new_composite(vec![tensor0, tensor1]);

        let intermediate_tensor7 = Tensor::new_composite(vec![tensor3, tensor4, tensor5]);

        let intermediate_tensor14 =
            Tensor::new_composite(vec![tensor8, tensor9, tensor10, tensor11]);

        let tensor15 = Tensor::new_composite(vec![
            intermediate_tensor2,
            intermediate_tensor7,
            intermediate_tensor14,
        ]);

        let contraction_path = path![
            {
            (0, [(0, 1)]),
            (1, [(0, 1), (0, 2)]),
            (2, [(0, 3), (2, 1), (0, 2)]),
            },
            (0, 1),
            (0, 2)
        ];

        (
            ContractionTree::from_contraction_path(&tensor15, &contraction_path),
            tensor15,
        )
    }

    /// Test objective: the number of legs of `a & b`.
    fn custom_cost_function(a: &Tensor, b: &Tensor) -> f64 {
        (a & b).legs().len() as f64
    }

    #[test]
    fn test_best_worst_balancing() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = best_worst::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
        );

        let ref_output = vec![Shift {
            from_subtree_id: 14,
            to_subtree_id: 2,
            moved_leaf_ids: vec![11],
        }];
        assert_eq!(output, ref_output);
    }

    #[test]
    fn test_tensor_balancing() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = best_tensor::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
        );

        let ref_output = vec![Shift {
            from_subtree_id: 14,
            to_subtree_id: 7,
            moved_leaf_ids: vec![8],
        }];
        assert_eq!(output, ref_output);
    }

    #[test]
    fn test_tensors_balancing() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = best_tensors::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
        );

        let ref_output = vec![
            Shift {
                from_subtree_id: 14,
                to_subtree_id: 7,
                moved_leaf_ids: vec![8],
            },
            Shift {
                from_subtree_id: 7,
                to_subtree_id: 2,
                moved_leaf_ids: vec![5],
            },
        ];
        assert_eq!(output, ref_output);
    }

    #[test]
    fn test_alternating_tensors_balancing_odd() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = tensors_odd::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
        );

        // Shift tensor8 = Tensor::new(vec![0, 1]) out of partition 14;
        // partition 7 (local tensor [0, 1, 5, 7]) shares both of its legs.
        let ref_output = vec![Shift {
            from_subtree_id: 14,
            to_subtree_id: 7,
            moved_leaf_ids: vec![8],
        }];
        assert_eq!(output, ref_output);
    }

    #[test]
    fn test_alternating_tensors_balancing_even() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = tensors_even::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
        );
        // Shift tensor11 = Tensor::new(vec![4, 5, 10]) into partition 2;
        // it shares leg 10 with tensor1 = Tensor::new(vec![8, 9, 10]) there.
        let ref_output = vec![Shift {
            from_subtree_id: 14,
            to_subtree_id: 2,
            moved_leaf_ids: vec![11],
        }];
        assert_eq!(output, ref_output);
    }

    #[test]
    fn test_intermediate_tensors_balancing() {
        let partition_data = setup_simple_partition_data();
        let (contraction_tree, tensor) = setup_simple();

        let output = best_intermediate_tensors::<StdRng>(
            &partition_data,
            &contraction_tree,
            &mut None,
            custom_cost_function,
            &tensor,
            Some(1),
        );

        let ref_output = vec![
            Shift {
                from_subtree_id: 14,
                to_subtree_id: 7,
                moved_leaf_ids: vec![8, 11],
            },
            Shift {
                from_subtree_id: 7,
                to_subtree_id: 2,
                moved_leaf_ids: vec![5],
            },
        ];
        assert_eq!(output, ref_output);
    }
}