tnc/contractionpath/
paths.rs

1//! Contraction path finders.
2
3use crate::contractionpath::ContractionPath;
4
5pub mod branchbound;
6pub mod cotengrust;
7#[cfg(feature = "cotengra")]
8pub mod hyperoptimization;
9#[cfg(feature = "cotengra")]
10pub mod tree_annealing;
11#[cfg(feature = "cotengra")]
12pub mod tree_reconfiguration;
13#[cfg(feature = "cotengra")]
14pub mod tree_tempering;
15pub mod weighted_branchbound;
16
17/// An optimizer for finding a contraction path.
18pub trait FindPath {
19    /// Finds a contraction path.
20    fn find_path(&mut self);
21
22    /// Returns the best found contraction path in SSA format.
23    fn get_best_path(&self) -> &ContractionPath;
24
25    /// Returns the best found contraction path in ReplaceLeft format.
26    fn get_best_replace_path(&self) -> ContractionPath;
27
28    /// Returns the total op count of the best path found.
29    fn get_best_flops(&self) -> f64;
30
31    /// Returns the max memory (in number of elements) of the best path found.
32    fn get_best_size(&self) -> f64;
33}
34
/// Metric that a path optimizer should optimize for.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CostType {
    /// Count of flops, i.e. operations performed.
    Flops,
    /// Size of the largest single contraction.
    Size,
}
43
44pub(crate) fn validate_path(path: &ContractionPath) {
45    let mut contracted = Vec::<usize>::new();
46    for nested in path.nested.values() {
47        validate_path(nested);
48    }
49
50    for (u, v) in &path.toplevel {
51        assert!(
52            !contracted.contains(u),
53            "Contracting already contracted tensors: {u:?}, path: {path:?}"
54        );
55        contracted.push(*v);
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    #[test]
    #[should_panic(
        expected = "Contracting already contracted tensors: 1, path: ContractionPath { nested: {}, toplevel: [(0, 1), (1, 2)] }"
    )]
    fn test_validate_paths() {
        let invalid_path = path![(0, 1), (1, 2)];
        validate_path(&invalid_path);
    }

    #[test]
    fn test_validate_paths_accepts_valid_path() {
        // Tensors 1, 3 and finally 2 are consumed; no index is used after it
        // has been contracted away, so validation must not panic.
        let valid_path = path![(0, 1), (2, 3), (0, 2)];
        validate_path(&valid_path);
    }
}