1use std::iter::zip;
2
3use cotengrust::{optimize_greedy_rust, optimize_optimal_rust, optimize_random_greedy_rust};
4use itertools::Itertools;
5use rustc_hash::FxHashMap;
6
7use crate::contractionpath::contraction_cost::contract_path_cost;
8use crate::contractionpath::paths::FindPath;
9use crate::contractionpath::{ssa_replace_ordering, ContractionPath, SimplePath};
10use crate::tensornetwork::tensor::Tensor;
11
/// Path-finding strategy that [`Cotengrust`] dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptMethod {
    /// Uses cotengrust's `optimize_optimal_rust`.
    Optimal,
    /// Uses cotengrust's `optimize_greedy_rust`, called with the fixed seed `Some(42)`.
    Greedy,
    /// Uses cotengrust's `optimize_random_greedy_rust`, called with the fixed seed
    /// `Some(42)`; the wrapped value is forwarded as its trial-count argument.
    RandomGreedy(usize),
}
22
/// Contraction-path finder that delegates the search to the cotengrust optimizers.
///
/// Composite children of `tensor` are optimized recursively; see `find_path`.
#[derive(Debug, Clone)]
pub struct Cotengrust<'a> {
    /// Composite tensor whose children are to be contracted.
    tensor: &'a Tensor,
    /// Best path found by `find_path` (the default, empty path before that).
    best_path: ContractionPath,
    /// Operation cost of `best_path` as reported by `contract_path_cost`;
    /// `f64::INFINITY` until `find_path` runs.
    best_flops: f64,
    /// Memory cost of `best_path` as reported by `contract_path_cost`;
    /// `f64::INFINITY` until `find_path` runs.
    best_size: f64,
    /// Strategy used for this tensor and, recursively, its composite children.
    opt_method: OptMethod,
}
32
33impl<'a> Cotengrust<'a> {
34 pub fn new(tensor: &'a Tensor, opt_method: OptMethod) -> Self {
36 Self {
37 tensor,
38 opt_method,
39 best_path: ContractionPath::default(),
40 best_flops: f64::INFINITY,
41 best_size: f64::INFINITY,
42 }
43 }
44
45 fn optimize_single(&self, inputs: &[Tensor], output: &Tensor) -> SimplePath {
48 if inputs.is_empty() {
50 return SimplePath::default();
51 }
52
53 let (inputs, output, size_dict) = tensor_legs_to_digit(inputs, output);
55
56 let path = match &self.opt_method {
58 OptMethod::Greedy => optimize_greedy_rust(
59 inputs,
60 output,
61 size_dict,
62 None,
63 None,
64 None,
65 Some(42),
66 false,
67 true,
68 ),
69 &OptMethod::RandomGreedy(ntrials) => {
70 optimize_random_greedy_rust(
71 inputs,
72 output,
73 size_dict,
74 ntrials,
75 None,
76 None,
77 None,
78 Some(42),
79 false,
80 true,
81 )
82 .0
83 }
84 OptMethod::Optimal => {
85 optimize_optimal_rust(inputs, output, size_dict, None, None, None, false, true)
86 }
87 };
88
89 path.into_iter()
91 .map(|pair| {
92 let [a, b] = pair[..] else {
93 panic!("Expected two indices in contraction path pair")
94 };
95 (a as _, b as _)
96 })
97 .collect_vec()
98 }
99}
100
101fn tensor_legs_to_digit(
103 inputs: &[Tensor],
104 output: &Tensor,
105) -> (Vec<Vec<char>>, Vec<char>, FxHashMap<char, f32>) {
106 fn leg_to_char(leg: usize) -> char {
107 char::from_u32(leg.try_into().unwrap()).unwrap()
108 }
109 let mut new_inputs = vec![Vec::new(); inputs.len()];
110 let new_output = output.legs().iter().copied().map(leg_to_char).collect();
111 let mut new_size_dict = FxHashMap::default();
112
113 for (tensor, labels) in zip(inputs, new_inputs.iter_mut()) {
114 labels.reserve_exact(tensor.legs().len());
115 for (leg, dim) in tensor.edges() {
116 let character = leg_to_char(*leg);
117 labels.push(character);
118 new_size_dict.insert(character, *dim as f32);
119 }
120 }
121 (new_inputs, new_output, new_size_dict)
122}
123
124impl FindPath for Cotengrust<'_> {
125 fn find_path(&mut self) {
126 let mut nested_paths = FxHashMap::default();
128 let mut inputs = self.tensor.tensors().clone();
129 for (index, input_tensor) in inputs.iter_mut().enumerate() {
130 if input_tensor.is_composite() {
131 let mut ct = Cotengrust::new(input_tensor, self.opt_method);
132 ct.find_path();
133 nested_paths.insert(index, ct.get_best_path().clone());
134 *input_tensor = input_tensor.external_tensor();
135 }
136 }
137
138 let external_tensor = self.tensor.external_tensor();
140 let outer_path = self.optimize_single(&inputs, &external_tensor);
141 self.best_path = ContractionPath {
142 nested: nested_paths,
143 toplevel: outer_path,
144 };
145
146 let (op_cost, mem_cost) =
148 contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
149 self.best_size = mem_cost;
150 self.best_flops = op_cost;
151 }
152
153 fn get_best_path(&self) -> &ContractionPath {
154 &self.best_path
155 }
156
157 fn get_best_replace_path(&self) -> ContractionPath {
158 ssa_replace_ordering(&self.best_path)
159 }
160
161 fn get_best_flops(&self) -> f64 {
162 self.best_flops
163 }
164
165 fn get_best_size(&self) -> f64 {
166 self.best_size
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    /// Three tensors where legs 2, 3 and 4 are each shared by two tensors.
    fn setup_simple() -> Tensor {
        let dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &dims),
            Tensor::new_from_map(vec![4, 5, 6], &dims),
        ];
        Tensor::new_composite(children)
    }

    /// Six tensors with larger bond dimensions.
    fn setup_complex() -> Tensor {
        let dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &dims),
            Tensor::new_from_map(vec![4, 5, 6], &dims),
            Tensor::new_from_map(vec![6, 8, 9], &dims),
            Tensor::new_from_map(vec![10, 8, 9], &dims),
            Tensor::new_from_map(vec![5, 1, 0], &dims),
        ];
        Tensor::new_composite(children)
    }

    /// Two identical tensors (an inner product) plus two tensors sharing leg 1.
    fn setup_simple_inner_product() -> Tensor {
        let dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![4, 3, 2], &dims),
            Tensor::new_from_map(vec![0, 1, 5], &dims),
            Tensor::new_from_map(vec![1, 6], &dims),
        ];
        Tensor::new_composite(children)
    }

    /// Three vectors with no shared legs (a pure outer product).
    fn setup_simple_outer_product() -> Tensor {
        let dims = FxHashMap::from_iter([(0, 3), (1, 2), (2, 2)]);
        let children = vec![
            Tensor::new_from_map(vec![0], &dims),
            Tensor::new_from_map(vec![1], &dims),
            Tensor::new_from_map(vec![2], &dims),
        ];
        Tensor::new_composite(children)
    }

    /// Two pairs of vectors sharing a leg, whose results form an outer product.
    fn setup_complex_outer_product() -> Tensor {
        let dims = FxHashMap::from_iter([(0, 5), (1, 4)]);
        let children = vec![
            Tensor::new_from_map(vec![0], &dims),
            Tensor::new_from_map(vec![0], &dims),
            Tensor::new_from_map(vec![1], &dims),
            Tensor::new_from_map(vec![1], &dims),
        ];
        Tensor::new_composite(children)
    }

    /// Runs the greedy optimizer on `tn` and checks both costs, the SSA path
    /// and the replace-ordered path against the expected values.
    fn check_greedy(
        tn: &Tensor,
        expected_flops: f64,
        expected_size: f64,
        expected_path: &ContractionPath,
        expected_replace_path: &ContractionPath,
    ) {
        let mut optimizer = Cotengrust::new(tn, OptMethod::Greedy);
        optimizer.find_path();

        assert_eq!(optimizer.get_best_flops(), expected_flops);
        assert_eq!(optimizer.get_best_size(), expected_size);
        assert_eq!(optimizer.get_best_path(), expected_path);
        assert_eq!(&optimizer.get_best_replace_path(), expected_replace_path);
    }

    #[test]
    fn test_contract_order_greedy_simple() {
        check_greedy(
            &setup_simple(),
            600.,
            538.,
            &path![(0, 1), (3, 2)],
            &path![(0, 1), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_simple_inner() {
        check_greedy(
            &setup_simple_inner_product(),
            228.,
            121.,
            &path![(0, 1), (2, 3), (4, 5)],
            &path![(0, 1), (2, 3), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_simple_outer() {
        check_greedy(
            &setup_simple_outer_product(),
            16.,
            19.,
            &path![(2, 1), (0, 3)],
            &path![(2, 1), (0, 2)],
        );
    }

    #[test]
    fn test_contract_order_greedy_complex_outer() {
        check_greedy(
            &setup_complex_outer_product(),
            10.,
            11.,
            &path![(0, 1), (2, 3), (5, 4)],
            &path![(0, 1), (2, 3), (2, 0)],
        );
    }

    #[test]
    fn test_contract_order_greedy_complex() {
        check_greedy(
            &setup_complex(),
            529815.,
            89478.,
            &path![(1, 5), (3, 4), (6, 0), (7, 2), (9, 8)],
            &path![(1, 5), (3, 4), (1, 0), (3, 2), (3, 1)],
        );
    }
}