// tnc/contractionpath/paths/hyperoptimization.rs

1use std::{
2    iter::zip,
3    process::{Command, Stdio},
4    time::Duration,
5};
6
7use rustc_hash::FxHashMap;
8use serde::Serialize;
9use serde_pickle::{DeOptions, SerOptions};
10
11use crate::{
12    contractionpath::{
13        contraction_cost::contract_path_cost,
14        paths::{CostType, FindPath},
15        ssa_replace_ordering, ContractionPath,
16    },
17    tensornetwork::tensor::Tensor,
18};
19
20/// Creates an interface to access `Cotengra` methods in Rust. Specifically exposes
21/// `search` method of `HyperOptimizer`.
22pub struct Hyperoptimizer<'a> {
23    tensor: &'a Tensor,
24    hyper_options: HyperOptions,
25    best_flops: f64,
26    best_size: f64,
27    best_path: ContractionPath,
28}
29
30impl<'a> Hyperoptimizer<'a> {
31    pub fn new(tensor: &'a Tensor, minimize: CostType, hyper_options: HyperOptions) -> Self {
32        assert_eq!(
33            minimize,
34            CostType::Flops,
35            "Currently, only Flops is supported"
36        );
37        Self {
38            tensor,
39            hyper_options,
40            best_flops: f64::INFINITY,
41            best_size: f64::INFINITY,
42            best_path: ContractionPath::default(),
43        }
44    }
45}
46
47/// The keyword options for the cotengra Hyperoptimizer.
48///
49/// Unassigned options will not be passed to the function and hence the Python
50/// default values will be used. Please see the cotengra documentation for details on
51/// the parameters.
52#[derive(Serialize, Default)]
53pub struct HyperOptions {
54    max_time: Option<u64>,
55    max_repeats: Option<usize>,
56}
57
58impl HyperOptions {
59    /// Creates the default HyperOptimizer options.
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    /// Sets the `max_time` argument for the HyperOptimizer.
65    pub fn with_max_time(mut self, time: &Duration) -> Self {
66        self.max_time = Some(time.as_secs());
67        self
68    }
69
70    /// Sets the `max_repeats` argument for the HyperOptimizer.
71    pub fn with_max_repeats(mut self, repeats: usize) -> Self {
72        self.max_repeats = Some(repeats);
73        self
74    }
75}
76
77/// Runs the Hyperoptimizer of cotengra on the given inputs. Additional inputs to the
78/// Hyperoptimizer can be passed with the [`HyperOptions`] struct.
79///
80/// # Python Dependency
81/// Python 3 must be installed with `cotengra` and `kahypar` packages installed.
82/// Can also work with virtual environments if the binary is run from a terminal with
83/// actived virtual environment.
84fn python_hyperoptimizer(
85    inputs: &[Vec<char>],
86    outputs: &[char],
87    size_dict: &FxHashMap<char, f32>,
88    hyper_options: &HyperOptions,
89) -> Vec<(usize, usize)> {
90    // Python code to be executed (WARNING: command line length limits might silently
91    // truncate the code! These are usually around >100,000 characters. Make sure the
92    // code is not too long.)
93    const PYTHON_CODE: &str = include_str!("hyperoptimization.py");
94
95    // Spawn python process
96    let mut child = Command::new("python3")
97        .arg("-c")
98        .arg(PYTHON_CODE)
99        .stdin(Stdio::piped())
100        .stdout(Stdio::piped())
101        .spawn()
102        .unwrap();
103    let mut stdin = child.stdin.take().unwrap();
104
105    // Send serialized data
106    serde_pickle::to_writer(
107        &mut stdin,
108        &(inputs, outputs, size_dict, hyper_options),
109        SerOptions::default(),
110    )
111    .unwrap();
112
113    // Wait for completion
114    let out = child.wait_with_output().unwrap();
115
116    // Deserialize SSA path
117    serde_pickle::from_slice(&out.stdout, DeOptions::default()).unwrap()
118}
119
120/// Converts tensor leg inputs to chars. Creates new inputs, outputs and size_dict that can be fed to Cotengra.
121fn tensor_legs_to_chars(
122    inputs: &[Tensor],
123    output: &Tensor,
124) -> (Vec<Vec<char>>, Vec<char>, FxHashMap<char, f32>) {
125    fn leg_to_char(leg: usize) -> char {
126        char::from_u32(leg.try_into().unwrap()).unwrap()
127    }
128    let mut new_inputs = vec![Vec::new(); inputs.len()];
129    let new_output = output.legs().iter().copied().map(leg_to_char).collect();
130    let mut new_size_dict = FxHashMap::default();
131
132    for (tensor, labels) in zip(inputs, new_inputs.iter_mut()) {
133        labels.reserve_exact(tensor.legs().len());
134        for (leg, dim) in tensor.edges() {
135            let character = leg_to_char(*leg);
136            labels.push(character);
137            new_size_dict.insert(character, *dim as f32);
138        }
139    }
140    (new_inputs, new_output, new_size_dict)
141}
142
143impl FindPath for Hyperoptimizer<'_> {
144    fn find_path(&mut self) {
145        let (inputs, outputs, size_dict) =
146            tensor_legs_to_chars(self.tensor.tensors(), &self.tensor.external_tensor());
147
148        let ssa_path = python_hyperoptimizer(&inputs, &outputs, &size_dict, &self.hyper_options);
149
150        self.best_path = ContractionPath::simple(ssa_path);
151
152        let (op_cost, mem_cost) =
153            contract_path_cost(self.tensor.tensors(), &self.get_best_replace_path(), true);
154
155        self.best_flops = op_cost;
156        self.best_size = mem_cost;
157    }
158
159    fn get_best_flops(&self) -> f64 {
160        self.best_flops
161    }
162
163    fn get_best_size(&self) -> f64 {
164        self.best_size
165    }
166
167    fn get_best_path(&self) -> &ContractionPath {
168        &self.best_path
169    }
170
171    fn get_best_replace_path(&self) -> ContractionPath {
172        ssa_replace_ordering(&self.best_path)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    use rustc_hash::FxHashMap;

    use crate::{
        contractionpath::paths::{CostType, FindPath},
        path,
        tensornetwork::tensor::Tensor,
    };

    /// Small three-tensor network.
    fn setup_simple() -> Tensor {
        let bond_dims: FxHashMap<_, _> =
            [(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]
                .into_iter()
                .collect();
        let legs = [vec![4, 3, 2], vec![0, 1, 3, 2], vec![4, 5, 6]];
        Tensor::new_composite(
            legs.into_iter()
                .map(|tensor_legs| Tensor::new_from_map(tensor_legs, &bond_dims))
                .collect::<Vec<_>>(),
        )
    }

    /// Six-tensor network with larger bond dimensions.
    fn setup_complex() -> Tensor {
        let bond_dims: FxHashMap<_, _> = [
            (0, 27),
            (1, 18),
            (2, 12),
            (3, 15),
            (4, 5),
            (5, 3),
            (6, 18),
            (7, 22),
            (8, 45),
            (9, 65),
            (10, 5),
            (11, 17),
        ]
        .into_iter()
        .collect();
        let legs = [
            vec![4, 3, 2],
            vec![0, 1, 3, 2],
            vec![4, 5, 6],
            vec![6, 8, 9],
            vec![10, 8, 9],
            vec![5, 1, 0],
        ];
        Tensor::new_composite(
            legs.into_iter()
                .map(|tensor_legs| Tensor::new_from_map(tensor_legs, &bond_dims))
                .collect::<Vec<_>>(),
        )
    }

    /// Runs a flops-minimizing hyperoptimization of `tn` limited to `seconds` seconds.
    fn optimized(tn: &Tensor, seconds: u64) -> Hyperoptimizer<'_> {
        let options = HyperOptions::new().with_max_time(&Duration::from_secs(seconds));
        let mut optimizer = Hyperoptimizer::new(tn, CostType::Flops, options);
        optimizer.find_path();
        optimizer
    }

    #[test]
    fn test_hyper_tree_contract_order_simple() {
        let network = setup_simple();
        let optimizer = optimized(&network, 25);

        assert_eq!(optimizer.best_flops, 600.);
        assert_eq!(optimizer.best_size, 538.);
        assert_eq!(optimizer.get_best_path(), &path![(0, 1), (2, 3)]);
        assert_eq!(optimizer.get_best_replace_path(), path![(0, 1), (2, 0)]);
    }

    #[test]
    fn test_hyper_tree_contract_order_complex() {
        let network = setup_complex();
        let optimizer = optimized(&network, 45);

        assert_eq!(optimizer.best_flops, 529815.);
        assert_eq!(optimizer.best_size, 89478.);
        assert_eq!(
            optimizer.best_path,
            path![(1, 5), (0, 6), (3, 4), (2, 8), (7, 9)]
        );
        assert_eq!(
            optimizer.get_best_replace_path(),
            path![(1, 5), (0, 1), (3, 4), (2, 3), (0, 2)]
        );
    }
}