tnc/contractionpath/paths/
tree_reconfiguration.rs1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{
4 cotengra_check, cotengra_optimized_greedy, replace_to_ssa_path, tensor_legs_to_digit,
5};
6
7use crate::{
8 contractionpath::{
9 contraction_cost::contract_path_cost,
10 paths::{CostType, FindPath},
11 ssa_replace_ordering, ContractionPath,
12 },
13 tensornetwork::tensor::Tensor,
14};
15
16pub struct TreeReconfigure<'a> {
19 tensor: &'a Tensor,
20 subtree_size: usize,
21 best_flops: f64,
22 best_size: f64,
23 best_path: ContractionPath,
24}
25
26impl<'a> TreeReconfigure<'a> {
27 pub fn new(tensor: &'a Tensor, subtree_size: usize, minimize: CostType) -> Self {
31 assert!(cotengra_check().is_ok());
32 assert_eq!(
33 minimize,
34 CostType::Flops,
35 "Currently, only Flops is supported"
36 );
37 Self {
38 tensor,
39 subtree_size,
40 best_flops: f64::INFINITY,
41 best_size: f64::INFINITY,
42 best_path: ContractionPath::default(),
43 }
44 }
45}
46
47impl FindPath for TreeReconfigure<'_> {
48 fn find_path(&mut self) {
49 let inputs = self
51 .tensor
52 .tensors()
53 .iter()
54 .map(|tensor| tensor.legs().clone())
55 .collect_vec();
56 let outputs = self.tensor.external_tensor();
57 let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
58 FxHashMap::default(),
59 |mut acc, edges| {
60 acc.extend(edges);
61 acc
62 },
63 );
64
65 let (inputs, outputs, size_dict) =
66 tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
67
68 let replace_path =
69 cotengra_optimized_greedy(&inputs, outputs, size_dict, self.subtree_size).unwrap();
70
71 let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
72
73 self.best_path = ContractionPath::simple(best_path);
74
75 let (op_cost, mem_cost) =
76 contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
77
78 self.best_flops = op_cost;
79 self.best_size = mem_cost;
80 }
81
82 fn get_best_flops(&self) -> f64 {
83 self.best_flops
84 }
85
86 fn get_best_size(&self) -> f64 {
87 self.best_size
88 }
89
90 fn get_best_path(&self) -> &ContractionPath {
91 &self.best_path
92 }
93
94 fn get_best_replace_path(&self) -> ContractionPath {
95 ssa_replace_ordering(&self.best_path)
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Builds a composite of three tensors over seven bonds.
    fn setup_simple() -> Tensor {
        let dims = FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);

        let t0 = Tensor::new_from_map(vec![4, 3, 2], &dims);
        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &dims);
        let t2 = Tensor::new_from_map(vec![4, 5, 6], &dims);

        Tensor::new_composite(vec![t0, t1, t2])
    }

    /// Builds a composite of six tensors over twelve bonds.
    fn setup_complex() -> Tensor {
        let dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);

        let t0 = Tensor::new_from_map(vec![4, 3, 2], &dims);
        let t1 = Tensor::new_from_map(vec![0, 1, 3, 2], &dims);
        let t2 = Tensor::new_from_map(vec![4, 5, 6], &dims);
        let t3 = Tensor::new_from_map(vec![6, 8, 9], &dims);
        let t4 = Tensor::new_from_map(vec![10, 8, 9], &dims);
        let t5 = Tensor::new_from_map(vec![5, 1, 0], &dims);

        Tensor::new_composite(vec![t0, t1, t2, t3, t4, t5])
    }

    /// Checks costs and both path forms for the three-tensor network.
    #[test]
    fn test_tree_contract_order_simple() {
        let network = setup_simple();
        let mut finder = TreeReconfigure::new(&network, 8, CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 600.);
        assert_eq!(finder.get_best_size(), 538.);
        assert_eq!(finder.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    /// Checks costs and both path forms for the six-tensor network.
    #[test]
    fn test_tree_contract_order_complex() {
        let network = setup_complex();
        let mut finder = TreeReconfigure::new(&network, 8, CostType::Flops);
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 332685.);
        assert_eq!(finder.get_best_size(), 89478.);
        assert_eq!(
            finder.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}