1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{cotengra_check, cotengra_sa_tree, replace_to_ssa_path, tensor_legs_to_digit};
4
5use crate::{
6 contractionpath::{
7 contraction_cost::contract_path_cost,
8 paths::{CostType, FindPath},
9 ssa_replace_ordering, ContractionPath,
10 },
11 tensornetwork::tensor::Tensor,
12};
13
14pub struct TreeAnnealing<'a> {
17 tensor: &'a Tensor,
18 temperature_steps: Option<usize>,
19 numiter: Option<usize>,
20 best_flops: f64,
21 best_size: f64,
22 best_path: ContractionPath,
23 seed: Option<u64>,
24}
25
26impl<'a> TreeAnnealing<'a> {
27 pub fn new(
28 tensor: &'a Tensor,
29 seed: Option<u64>,
30 minimize: CostType,
31 temperature_steps: Option<usize>,
32 numiter: Option<usize>,
33 ) -> Self {
34 assert!(cotengra_check().is_ok());
35 assert_eq!(
36 minimize,
37 CostType::Flops,
38 "Currently, only Flops is supported"
39 );
40 Self {
41 tensor,
42 temperature_steps,
43 numiter,
44 best_flops: f64::INFINITY,
45 best_size: f64::INFINITY,
46 best_path: ContractionPath::default(),
47 seed,
48 }
49 }
50}
51
52impl FindPath for TreeAnnealing<'_> {
53 fn find_path(&mut self) {
54 let inputs = self
56 .tensor
57 .tensors()
58 .iter()
59 .map(|tensor| tensor.legs().clone())
60 .collect_vec();
61 let outputs = self.tensor.external_tensor();
62 let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
63 FxHashMap::default(),
64 |mut acc, edges| {
65 acc.extend(edges);
66 acc
67 },
68 );
69
70 let (inputs, outputs, size_dict) =
71 tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
72
73 let replace_path = cotengra_sa_tree(
74 &inputs,
75 outputs,
76 self.temperature_steps,
77 self.numiter,
78 size_dict,
79 self.seed,
80 )
81 .unwrap();
82
83 let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
84
85 self.best_path = ContractionPath::simple(best_path);
86
87 let (op_cost, mem_cost) =
88 contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
89
90 self.best_flops = op_cost;
91 self.best_size = mem_cost;
92 }
93
94 fn get_best_flops(&self) -> f64 {
95 self.best_flops
96 }
97
98 fn get_best_size(&self) -> f64 {
99 self.best_size
100 }
101
102 fn get_best_path(&self) -> &ContractionPath {
103 &self.best_path
104 }
105
106 fn get_best_replace_path(&self) -> ContractionPath {
107 ssa_replace_ordering(&self.best_path)
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Three-tensor network over bonds `0..=6`.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);

        let first = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let second = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
        let third = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);

        Tensor::new_composite(vec![first, second, third])
    }

    /// Six-tensor network over bonds `0..=11`; its first three tensors have
    /// the same legs as the simple fixture but different bond dimensions.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);

        let first = Tensor::new_from_map(vec![4, 3, 2], &bond_dims);
        let second = Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims);
        let third = Tensor::new_from_map(vec![4, 5, 6], &bond_dims);
        let fourth = Tensor::new_from_map(vec![6, 8, 9], &bond_dims);
        let fifth = Tensor::new_from_map(vec![10, 8, 9], &bond_dims);
        let sixth = Tensor::new_from_map(vec![5, 1, 0], &bond_dims);

        Tensor::new_composite(vec![first, second, third, fourth, fifth, sixth])
    }

    /// Optimizer configuration shared by both tests: fixed seed `8`,
    /// `100` temperature steps and `50` iterations, minimizing flops.
    fn seeded_annealer(network: &Tensor) -> TreeAnnealing<'_> {
        TreeAnnealing::new(network, Some(8), CostType::Flops, Some(100), Some(50))
    }

    #[test]
    fn test_anneal_tree_contract_order_simple() {
        let network = setup_simple();
        let mut annealer = seeded_annealer(&network);
        annealer.find_path();

        assert_eq!(annealer.best_flops, 600.);
        assert_eq!(annealer.best_size, 538.);
        assert_eq!(annealer.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(annealer.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    fn test_anneal_tree_contract_order_complex() {
        let network = setup_complex();
        let mut annealer = seeded_annealer(&network);
        annealer.find_path();

        assert_eq!(annealer.best_flops, 332685.);
        assert_eq!(annealer.best_size, 89478.);
        assert_eq!(
            annealer.best_path,
            path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            annealer.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}