// tnc/contractionpath.rs

1use rustc_hash::FxHashMap;
2use serde::{Deserialize, Serialize};
3
4use crate::tensornetwork::tensor::TensorIndex;
5use crate::utils::traits::{HashMapInsertNew, WithCapacity};
6
7mod candidates;
8pub mod communication_schemes;
9pub mod contraction_cost;
10pub mod contraction_tree;
11pub mod paths;
12pub mod repartitioning;
13
14/// A simple, flat contraction path. If you only need a reference, prefer
15/// [`SimplePathRef`].
16pub type SimplePath = Vec<(TensorIndex, TensorIndex)>;
17
18/// Reference to a [`SimplePath`].
19pub type SimplePathRef<'a> = &'a [(TensorIndex, TensorIndex)];
20
21/// A (possibly nested) contraction path.
22///
23/// It specifies the overall contraction path to contract a tensor network, but also
24/// allows to specify additional contraction paths for each tensor, in order to deal
25/// with composite tensors that have to be contracted first.
26#[derive(Debug, Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
27pub struct ContractionPath {
28    /// Nested contraction paths for composite tensors.
29    pub nested: FxHashMap<TensorIndex, ContractionPath>,
30    /// The top-level contraction path for the tensor network itself.
31    pub toplevel: SimplePath,
32}
33
34impl ContractionPath {
35    /// Creates a contraction path with nested paths.
36    #[inline]
37    pub fn nested(
38        nested: Vec<(TensorIndex, ContractionPath)>,
39        toplevel: Vec<(TensorIndex, TensorIndex)>,
40    ) -> Self {
41        Self {
42            nested: FxHashMap::from_iter(nested),
43            toplevel,
44        }
45    }
46
47    /// Creates a plain contraction path without nested paths.
48    ///
49    /// # Examples
50    /// ```
51    /// # use tnc::contractionpath::{ContractionPath, SimplePath};
52    /// let path: SimplePath = vec![(0, 1), (0, 2), (0, 3)];
53    /// let contraction_path = ContractionPath::simple(path.clone());
54    /// assert!(contraction_path.is_simple());
55    /// assert_eq!(contraction_path.toplevel, path);
56    /// ```
57    #[inline]
58    pub fn simple(path: SimplePath) -> Self {
59        Self {
60            nested: FxHashMap::default(),
61            toplevel: path,
62        }
63    }
64
65    /// Creates a contraction path from a single contraction of two tensors.
66    ///
67    /// # Examples
68    /// ```
69    /// # use tnc::contractionpath::ContractionPath;
70    /// let contraction_path = ContractionPath::single(0, 1);
71    /// assert!(contraction_path.is_simple());
72    /// assert_eq!(contraction_path.toplevel, vec![(0, 1)]);
73    /// ```
74    #[inline]
75    pub fn single(a: TensorIndex, b: TensorIndex) -> Self {
76        Self::simple(vec![(a, b)])
77    }
78
79    /// The length of the contraction path, that is, the number of top-level
80    /// contractions.
81    ///
82    /// # Examples
83    /// ```
84    /// # use tnc::contractionpath::ContractionPath;
85    /// let contraction_path = ContractionPath::simple(vec![(0, 1), (0, 2), (0, 3)]);
86    /// assert_eq!(contraction_path.len(), 3);
87    /// ```
88    #[inline]
89    pub fn len(&self) -> usize {
90        self.toplevel.len()
91    }
92
93    /// Whether there are any top-level contractions in this contraction path.
94    ///
95    /// # Examples
96    /// ```
97    /// # use tnc::contractionpath::ContractionPath;
98    /// assert!(ContractionPath::default().is_empty());
99    /// assert!(!ContractionPath::simple(vec![(0, 1)]).is_empty());
100    /// ```
101    #[inline]
102    pub fn is_empty(&self) -> bool {
103        self.toplevel.is_empty()
104    }
105
106    /// Returns whether this path has no nested paths.
107    ///
108    /// # Examples
109    /// ```
110    /// # use tnc::contractionpath::ContractionPath;
111    /// # use tnc::path;
112    /// let simple_path = path![(0, 1), (0, 2), (0, 3)];
113    /// assert!(simple_path.is_simple());
114    /// let nested_path = path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
115    /// assert!(!nested_path.is_simple());
116    /// ```
117    #[inline]
118    pub fn is_simple(&self) -> bool {
119        self.nested.is_empty()
120    }
121
122    /// Converts this path to its toplevel component.
123    ///
124    /// # Panics
125    /// - Panics when this path has nested components
126    ///
127    /// # Examples
128    /// ```
129    /// # use tnc::contractionpath::ContractionPath;
130    /// # use tnc::path;
131    /// let contractions = vec![(0, 1), (0, 2), (0, 3)];
132    /// let simple_path = ContractionPath::simple(contractions.clone());
133    /// assert_eq!(simple_path.into_simple(), contractions);
134    /// ```
135    #[inline]
136    pub fn into_simple(self) -> SimplePath {
137        assert!(self.is_simple());
138        self.toplevel
139    }
140}
141
/// Macro to create (nested) contraction paths, assuming the left tensor is replaced
/// in each contraction.
///
/// For instance, `path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)]` creates a nested
/// contraction path that
/// - recursively contracts the composite tensor 2 with the contraction path `[(0, 2), (0, 1)]`
/// - contracts tensors 0 and 1, replacing tensor 0 with the result
/// - contracts tensors 0 and (now contracted) tensor 2, replacing tensor 0 with the result
///
/// `path![]` creates an empty path. A trailing comma after the last top-level
/// contraction is accepted.
///
/// # Examples
/// ```
/// use tnc::contractionpath::ContractionPath;
///
/// let simple = tnc::path![(0, 1), (0, 2)];
/// assert_eq!(simple, ContractionPath::simple(vec![(0, 1), (0, 2)]));
///
/// // Nested paths also work when the macro is invoked by its full path.
/// let nested = tnc::path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2),];
/// assert!(!nested.is_simple());
/// assert_eq!(nested.nested[&2], ContractionPath::simple(vec![(0, 2), (0, 1)]));
/// ```
#[macro_export]
macro_rules! path {
    [] => {
        $crate::contractionpath::ContractionPath::default()
    };
    // `+` (not `*`) so that the empty input stays with the arm above.
    [$( ($t0:expr, $t1:expr) ),+ $(,)?] => {
        $crate::contractionpath::ContractionPath::simple(vec![$( ($t0, $t1) ),+])
    };
    [ { $( ( $index:expr, [ $( $tok:tt )* ] ) ),* $(,)? } $(, ($t0:expr, $t1:expr) )* $(,)? ] => {
        $crate::contractionpath::ContractionPath::nested(
            // `$crate::` so the recursion resolves even when the caller has not
            // imported `path!` (e.g. when invoking it as `tnc::path![...]`).
            vec![ $( ($index, $crate::path![ $( $tok )* ]) ),* ],
            vec![ $( ($t0, $t1) ),* ]
        )
    };
}
165
166/// The contraction ordering labels [`Tensor`] objects from each possible contraction with a
167/// unique identifier in SSA format. As only a subset of these [`Tensor`] objects are seen in
168/// a contraction path, the tensors in the optimal path search are not sequential. This converts
169/// the output to strictly obey an SSA format.
170///
171/// # Arguments
172/// * `path` - Output path as `&[(usize, usize, usize)]` after an `find_path` call.
173/// * `n` - Number of initial input tensors.
174///
175/// # Returns
176/// Identical path using SSA format
177fn ssa_ordering(path: &[(usize, usize, usize)], mut n: usize) -> ContractionPath {
178    let mut ssa_path = Vec::with_capacity(path.len());
179    let mut hs = FxHashMap::with_capacity(path.len());
180    let path_len = n;
181    for (u1, u2, u3) in path {
182        let t1 = if *u1 >= path_len { hs[u1] } else { *u1 };
183        let t2 = if *u2 >= path_len { hs[u2] } else { *u2 };
184        hs.entry(*u3).or_insert(n);
185        n += 1;
186        ssa_path.push((t1, t2));
187    }
188    ContractionPath::simple(ssa_path)
189}
190
191/// Accepts a contraction `path` that is in SSA format and returns a contraction path
192/// assuming that all contracted tensors replace the left input tensor and no tensor
193/// is popped.
194pub(super) fn ssa_replace_ordering(path: &ContractionPath) -> ContractionPath {
195    let nested = path
196        .nested
197        .iter()
198        .map(|(index, local_path)| (*index, ssa_replace_ordering(local_path)))
199        .collect();
200
201    let mut hs = FxHashMap::with_capacity(path.len());
202    let mut toplevel = Vec::with_capacity(path.len());
203    let mut n = path.len() + 1;
204    for (t0, t1) in &path.toplevel {
205        let new_t0 = *hs.get(t0).unwrap_or(t0);
206        let new_t1 = *hs.get(t1).unwrap_or(t1);
207
208        hs.insert_new(n, new_t0);
209        toplevel.push((new_t0, new_t1));
210        n += 1;
211    }
212
213    ContractionPath { nested, toplevel }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_simple_macro() {
        assert_eq!(
            path![{ (2, [(1, 2), (1, 3)]) }, (0, 1)],
            ContractionPath {
                nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
                toplevel: vec![(0, 1)]
            }
        );
    }

    #[test]
    fn test_path_macro() {
        assert_eq!(
            path![
                {
                (2, [(1, 2), (1, 3)]),
                (4, [{(2, [(1, 2), (1, 3)])}, (1, 3)]),
                (5, [(1, 2), (1, 3)]),
                (3, [(4, 1), (3, 4), (3, 5)]),
                },
                (0, 1),
                (0, 2),
                (0, 3)
            ],
            ContractionPath {
                nested: FxHashMap::from_iter([
                    (2, ContractionPath::simple(vec![(1, 2), (1, 3)])),
                    (3, ContractionPath::simple(vec![(4, 1), (3, 4), (3, 5)])),
                    (
                        4,
                        ContractionPath {
                            nested: FxHashMap::from_iter([(
                                2,
                                ContractionPath::simple(vec![(1, 2), (1, 3)])
                            )]),
                            toplevel: vec![(1, 3)]
                        }
                    ),
                    (5, ContractionPath::simple(vec![(1, 2), (1, 3)]))
                ]),
                toplevel: vec![(0, 1), (0, 2), (0, 3)]
            }
        );
    }

    #[test]
    fn test_path_empty_macro() {
        let path = path![];
        assert!(path.is_empty());
        assert!(path.is_simple());
        assert_eq!(path, ContractionPath::default());
    }

    #[test]
    fn test_len_counts_only_toplevel() {
        let path = path![{(2, [(0, 1), (0, 2)])}, (0, 1)];
        assert_eq!(path.len(), 1);
        assert!(!path.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_into_simple_panics_on_nested() {
        let path = path![{(2, [(0, 1)])}, (0, 2)];
        let _ = path.into_simple();
    }

    #[test]
    fn test_ssa_ordering() {
        let path = vec![
            (0, 3, 15),
            (1, 2, 44),
            (6, 4, 8),
            (5, 15, 22),
            (8, 44, 12),
            (12, 22, 99),
        ];
        let new_path = ssa_ordering(&path, 7);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)]
        );
    }

    #[test]
    fn test_ssa_ordering_empty() {
        assert_eq!(ssa_ordering(&[], 3), ContractionPath::default());
    }

    #[test]
    fn test_ssa_replace_ordering() {
        let path = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];
        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 0), (6, 1), (6, 5)]
        );
    }

    #[test]
    fn test_ssa_replace_ordering_empty() {
        assert_eq!(
            ssa_replace_ordering(&ContractionPath::default()),
            ContractionPath::default()
        );
    }

    #[test]
    fn test_ssa_replace_ordering_nested() {
        let path = path![
            {
            (1, [(2, 1), (0, 3)]),
            (6, [(0, 2), (1, 3), (4, 5)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 7),
            (9, 8),
            (11, 10)
        ];

        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![
                {
                (1, [(2, 1), (0, 2)]),
                (6, [(0, 2), (1, 3), (0, 1)])
                },
                (0, 3),
                (1, 2),
                (6, 4),
                (5, 0),
                (6, 1),
                (6, 5)
            ]
        );
    }
}