tnc/contractionpath/paths/
tree_tempering.rs1use itertools::Itertools;
2use rustc_hash::FxHashMap;
3use rustengra::{
4 cotengra_check, cotengra_tree_tempering, replace_to_ssa_path, tensor_legs_to_digit,
5};
6
7use crate::{
8 contractionpath::{
9 contraction_cost::contract_path_cost,
10 paths::{CostType, FindPath},
11 ssa_replace_ordering, ContractionPath,
12 },
13 tensornetwork::tensor::Tensor,
14};
15
16pub struct TreeTempering<'a> {
19 tensor: &'a Tensor,
20 numiter: Option<usize>,
21 best_flops: f64,
22 best_size: f64,
23 best_path: ContractionPath,
24 seed: Option<u64>,
25}
26
27impl<'a> TreeTempering<'a> {
28 pub fn new(
29 tensor: &'a Tensor,
30 seed: Option<u64>,
31 minimize: CostType,
32 numiter: Option<usize>,
33 ) -> Self {
34 assert!(cotengra_check().is_ok());
35 assert_eq!(
36 minimize,
37 CostType::Flops,
38 "Currently, only Flops is supported"
39 );
40 Self {
41 tensor,
42 numiter,
43 best_flops: f64::INFINITY,
44 best_size: f64::INFINITY,
45 best_path: ContractionPath::default(),
46 seed,
47 }
48 }
49}
50
51impl FindPath for TreeTempering<'_> {
52 fn find_path(&mut self) {
53 let inputs = self
55 .tensor
56 .tensors()
57 .iter()
58 .map(|tensor| tensor.legs().clone())
59 .collect_vec();
60 let outputs = self.tensor.external_tensor();
61 let size_dict = self.tensor.tensors().iter().map(Tensor::edges).fold(
62 FxHashMap::default(),
63 |mut acc, edges| {
64 acc.extend(edges);
65 acc
66 },
67 );
68
69 let (inputs, outputs, size_dict) =
70 tensor_legs_to_digit(&inputs, outputs.legs(), &size_dict);
71
72 let replace_path =
73 cotengra_tree_tempering(&inputs, outputs, self.numiter, size_dict, self.seed).unwrap();
74
75 let best_path = replace_to_ssa_path(replace_path, self.tensor.tensors().len());
76
77 self.best_path = ContractionPath::simple(best_path);
78
79 let (op_cost, mem_cost) =
80 contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
81
82 self.best_flops = op_cost;
83 self.best_size = mem_cost;
84 }
85
86 fn get_best_flops(&self) -> f64 {
87 self.best_flops
88 }
89
90 fn get_best_size(&self) -> f64 {
91 self.best_size
92 }
93
94 fn get_best_path(&self) -> &ContractionPath {
95 &self.best_path
96 }
97
98 fn get_best_replace_path(&self) -> ContractionPath {
99 ssa_replace_ordering(&self.best_path)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Small three-tensor network sharing legs 2, 3 and 4.
    fn setup_simple() -> Tensor {
        let bond_dims: FxHashMap<_, _> = [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
            .into_iter()
            .collect();
        let tensors = [vec![4, 3, 2], vec![0, 1, 3, 2], vec![4, 5, 6]]
            .into_iter()
            .map(|legs| Tensor::new_from_map(legs, &bond_dims))
            .collect::<Vec<_>>();
        Tensor::new_composite(tensors)
    }

    /// Six-tensor network with larger bond dimensions.
    fn setup_complex() -> Tensor {
        let bond_dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();
        let tensors = [
            vec![4, 3, 2],
            vec![0, 1, 3, 2],
            vec![4, 5, 6],
            vec![6, 8, 9],
            vec![10, 8, 9],
            vec![5, 1, 0],
        ]
        .into_iter()
        .map(|legs| Tensor::new_from_map(legs, &bond_dims))
        .collect::<Vec<_>>();
        Tensor::new_composite(tensors)
    }

    #[test]
    #[ignore = "flaky test due to a bug in cotengra"]
    fn test_temper_tree_contract_order_simple() {
        let network = setup_simple();
        let mut finder = TreeTempering::new(&network, Some(8), CostType::Flops, Some(100));
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 600.);
        assert_eq!(finder.get_best_size(), 538.);
        assert_eq!(finder.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(finder.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    #[ignore = "flaky test due to a bug in cotengra"]
    fn test_temper_tree_contract_order_complex() {
        let network = setup_complex();
        let mut finder = TreeTempering::new(&network, Some(8), CostType::Flops, Some(100));
        finder.find_path();

        assert_eq!(finder.get_best_flops(), 332685.);
        assert_eq!(finder.get_best_size(), 89478.);
        assert_eq!(
            finder.get_best_path(),
            &path![(1, 5), (0, 6), (2, 7), (3, 8), (4, 9)]
        );
        assert_eq!(
            finder.get_best_replace_path(),
            path![(1, 5), (0, 1), (2, 0), (3, 2), (4, 3)]
        );
    }
}