1use std::{
2 iter::zip,
3 process::{Command, Stdio},
4 time::Duration,
5};
6
7use rustc_hash::FxHashMap;
8use serde::Serialize;
9use serde_pickle::{DeOptions, SerOptions};
10
11use crate::{
12 contractionpath::{
13 contraction_cost::contract_path_cost,
14 paths::{CostType, FindPath},
15 ssa_replace_ordering, ContractionPath,
16 },
17 tensornetwork::tensor::Tensor,
18};
19
20pub struct Hyperoptimizer<'a> {
23 tensor: &'a Tensor,
24 hyper_options: HyperOptions,
25 best_flops: f64,
26 best_size: f64,
27 best_path: ContractionPath,
28}
29
30impl<'a> Hyperoptimizer<'a> {
31 pub fn new(tensor: &'a Tensor, minimize: CostType, hyper_options: HyperOptions) -> Self {
32 assert_eq!(
33 minimize,
34 CostType::Flops,
35 "Currently, only Flops is supported"
36 );
37 Self {
38 tensor,
39 hyper_options,
40 best_flops: f64::INFINITY,
41 best_size: f64::INFINITY,
42 best_path: ContractionPath::default(),
43 }
44 }
45}
46
47#[derive(Serialize, Default)]
53pub struct HyperOptions {
54 max_time: Option<u64>,
55 max_repeats: Option<usize>,
56}
57
58impl HyperOptions {
59 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn with_max_time(mut self, time: &Duration) -> Self {
66 self.max_time = Some(time.as_secs());
67 self
68 }
69
70 pub fn with_max_repeats(mut self, repeats: usize) -> Self {
72 self.max_repeats = Some(repeats);
73 self
74 }
75}
76
77fn python_hyperoptimizer(
85 inputs: &[Vec<char>],
86 outputs: &[char],
87 size_dict: &FxHashMap<char, f32>,
88 hyper_options: &HyperOptions,
89) -> Vec<(usize, usize)> {
90 const PYTHON_CODE: &str = include_str!("hyperoptimization.py");
94
95 let mut child = Command::new("python3")
97 .arg("-c")
98 .arg(PYTHON_CODE)
99 .stdin(Stdio::piped())
100 .stdout(Stdio::piped())
101 .spawn()
102 .unwrap();
103 let mut stdin = child.stdin.take().unwrap();
104
105 serde_pickle::to_writer(
107 &mut stdin,
108 &(inputs, outputs, size_dict, hyper_options),
109 SerOptions::default(),
110 )
111 .unwrap();
112
113 let out = child.wait_with_output().unwrap();
115
116 serde_pickle::from_slice(&out.stdout, DeOptions::default()).unwrap()
118}
119
120fn tensor_legs_to_chars(
122 inputs: &[Tensor],
123 output: &Tensor,
124) -> (Vec<Vec<char>>, Vec<char>, FxHashMap<char, f32>) {
125 fn leg_to_char(leg: usize) -> char {
126 char::from_u32(leg.try_into().unwrap()).unwrap()
127 }
128 let mut new_inputs = vec![Vec::new(); inputs.len()];
129 let new_output = output.legs().iter().copied().map(leg_to_char).collect();
130 let mut new_size_dict = FxHashMap::default();
131
132 for (tensor, labels) in zip(inputs, new_inputs.iter_mut()) {
133 labels.reserve_exact(tensor.legs().len());
134 for (leg, dim) in tensor.edges() {
135 let character = leg_to_char(*leg);
136 labels.push(character);
137 new_size_dict.insert(character, *dim as f32);
138 }
139 }
140 (new_inputs, new_output, new_size_dict)
141}
142
143impl FindPath for Hyperoptimizer<'_> {
144 fn find_path(&mut self) {
145 let (inputs, outputs, size_dict) =
146 tensor_legs_to_chars(self.tensor.tensors(), &self.tensor.external_tensor());
147
148 let ssa_path = python_hyperoptimizer(&inputs, &outputs, &size_dict, &self.hyper_options);
149
150 self.best_path = ContractionPath::simple(ssa_path);
151
152 let (op_cost, mem_cost) =
153 contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
154
155 self.best_flops = op_cost;
156 self.best_size = mem_cost;
157 }
158
159 fn get_best_flops(&self) -> f64 {
160 self.best_flops
161 }
162
163 fn get_best_size(&self) -> f64 {
164 self.best_size
165 }
166
167 fn get_best_path(&self) -> &ContractionPath {
168 &self.best_path
169 }
170
171 fn get_best_replace_path(&self) -> ContractionPath {
172 ssa_replace_ordering(&self.best_path)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Three-tensor network used by the simple test.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        Tensor::new_composite(children)
    }

    /// Six-tensor network used by the complex test.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]);
        let children = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![6, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![10, 8, 9], &bond_dims),
            Tensor::new_from_map(vec![5, 1, 0], &bond_dims),
        ];
        Tensor::new_composite(children)
    }

    /// Builds a Flops-minimizing optimizer for `tn` with a time budget of
    /// `max_secs` seconds and runs the search once.
    fn run_optimizer(tn: &Tensor, max_secs: u64) -> Hyperoptimizer<'_> {
        let options = HyperOptions::new().with_max_time(&Duration::from_secs(max_secs));
        let mut optimizer = Hyperoptimizer::new(tn, CostType::Flops, options);
        optimizer.find_path();
        optimizer
    }

    #[test]
    fn test_hyper_tree_contract_order_simple() {
        let tn = setup_simple();
        let opt = run_optimizer(&tn, 25);

        assert_eq!(opt.best_flops, 600.);
        assert_eq!(opt.best_size, 538.);
        assert_eq!(opt.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(opt.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    fn test_hyper_tree_contract_order_complex() {
        let tn = setup_complex();
        let opt = run_optimizer(&tn, 45);

        assert_eq!(opt.best_flops, 529815.);
        assert_eq!(opt.best_size, 89478.);
        assert_eq!(opt.best_path, path![(1, 5), (0, 6), (3, 4), (2, 8), (7, 9)]);
        assert_eq!(
            opt.get_best_replace_path(),
            path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)]
        );
    }
}