tnc/contractionpath/
paths.rs

1//! Contraction path finders.
2
3use crate::contractionpath::ContractionPath;
4
5pub mod branchbound;
6pub mod cotengrust;
7pub mod hyperoptimization;
8#[cfg(feature = "cotengra")]
9pub mod tree_annealing;
10#[cfg(feature = "cotengra")]
11pub mod tree_reconfiguration;
12#[cfg(feature = "cotengra")]
13pub mod tree_tempering;
14pub mod weighted_branchbound;
15
16/// An optimizer for finding a contraction path.
17pub trait FindPath {
18    /// Finds a contraction path.
19    fn find_path(&mut self);
20
21    /// Returns the best found contraction path in SSA format.
22    fn get_best_path(&self) -> &ContractionPath;
23
24    /// Returns the best found contraction path in ReplaceLeft format.
25    fn get_best_replace_path(&self) -> ContractionPath;
26
27    /// Returns the total op count of the best path found.
28    fn get_best_flops(&self) -> f64;
29
30    /// Returns the max memory (in number of elements) of the best path found.
31    fn get_best_size(&self) -> f64;
32}
33
/// Selects which cost a contraction-path optimizer tries to minimize.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CostType {
    /// Total number of flops, i.e. operations performed.
    Flops,
    /// Size of the largest single contraction.
    Size,
}
42
43pub(crate) fn validate_path(path: &ContractionPath) {
44    let mut contracted = Vec::<usize>::new();
45    for nested in path.nested.values() {
46        validate_path(nested);
47    }
48
49    for (u, v) in &path.toplevel {
50        assert!(
51            !contracted.contains(u),
52            "Contracting already contracted tensors: {u:?}, path: {path:?}"
53        );
54        contracted.push(*v);
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    #[test]
    #[should_panic(
        expected = "Contracting already contracted tensors: 1, path: ContractionPath { nested: {}, toplevel: [(0, 1), (1, 2)] }"
    )]
    fn test_validate_paths() {
        // Tensor 1 is consumed by the first step and then reused as a left operand.
        let invalid_path = path![(0, 1), (1, 2)];
        validate_path(&invalid_path);
    }

    #[test]
    fn test_validate_paths_accepts_reused_left_operand() {
        // Only right operands are consumed, so contracting repeatedly into tensor 0
        // must be accepted.
        let valid_path = path![(0, 1), (0, 2), (0, 3)];
        validate_path(&valid_path);
    }
}