tnc/
contractionpath.rs

1//! Find and evaluate the cost of contraction paths. Also contains functionality for
2//! rebalancing of partitioned tensor networks.
3
4use rustc_hash::FxHashMap;
5use serde::{Deserialize, Serialize};
6
7use crate::tensornetwork::tensor::TensorIndex;
8use crate::utils::traits::{HashMapInsertNew, WithCapacity};
9
10mod candidates;
11pub mod communication_schemes;
12pub mod contraction_cost;
13pub mod contraction_tree;
14pub mod paths;
15pub mod repartitioning;
16
17/// A simple, flat contraction path. If you only need a reference, prefer
18/// [`SimplePathRef`].
19pub type SimplePath = Vec<(TensorIndex, TensorIndex)>;
20
21/// Reference to a [`SimplePath`].
22pub type SimplePathRef<'a> = &'a [(TensorIndex, TensorIndex)];
23
24/// A (possibly nested) contraction path.
25///
26/// It specifies the overall contraction path to contract a tensor network, but also
27/// allows to specify additional contraction paths for each tensor, in order to deal
28/// with composite tensors that have to be contracted first.
29#[derive(Debug, Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
30pub struct ContractionPath {
31    /// Nested contraction paths for composite tensors.
32    pub nested: FxHashMap<TensorIndex, ContractionPath>,
33    /// The top-level contraction path for the tensor network itself.
34    pub toplevel: SimplePath,
35}
36
37impl ContractionPath {
38    /// Creates a contraction path with nested paths.
39    #[inline]
40    pub fn nested(
41        nested: Vec<(TensorIndex, ContractionPath)>,
42        toplevel: Vec<(TensorIndex, TensorIndex)>,
43    ) -> Self {
44        Self {
45            nested: FxHashMap::from_iter(nested),
46            toplevel,
47        }
48    }
49
50    /// Creates a plain contraction path without nested paths.
51    ///
52    /// # Examples
53    /// ```
54    /// # use tnc::contractionpath::{ContractionPath, SimplePath};
55    /// let path: SimplePath = vec![(0, 1), (0, 2), (0, 3)];
56    /// let contraction_path = ContractionPath::simple(path.clone());
57    /// assert!(contraction_path.is_simple());
58    /// assert_eq!(contraction_path.toplevel, path);
59    /// ```
60    #[inline]
61    pub fn simple(path: SimplePath) -> Self {
62        Self {
63            nested: FxHashMap::default(),
64            toplevel: path,
65        }
66    }
67
68    /// Creates a contraction path from a single contraction of two tensors.
69    ///
70    /// # Examples
71    /// ```
72    /// # use tnc::contractionpath::ContractionPath;
73    /// let contraction_path = ContractionPath::single(0, 1);
74    /// assert!(contraction_path.is_simple());
75    /// assert_eq!(contraction_path.toplevel, vec![(0, 1)]);
76    /// ```
77    #[inline]
78    pub fn single(a: TensorIndex, b: TensorIndex) -> Self {
79        Self::simple(vec![(a, b)])
80    }
81
82    /// The length of the contraction path, that is, the number of top-level
83    /// contractions.
84    ///
85    /// # Examples
86    /// ```
87    /// # use tnc::contractionpath::ContractionPath;
88    /// let contraction_path = ContractionPath::simple(vec![(0, 1), (0, 2), (0, 3)]);
89    /// assert_eq!(contraction_path.len(), 3);
90    /// ```
91    #[inline]
92    pub fn len(&self) -> usize {
93        self.toplevel.len()
94    }
95
96    /// Whether there are any top-level contractions in this contraction path.
97    ///
98    /// # Examples
99    /// ```
100    /// # use tnc::contractionpath::ContractionPath;
101    /// assert!(ContractionPath::default().is_empty());
102    /// assert!(!ContractionPath::simple(vec![(0, 1)]).is_empty());
103    /// ```
104    #[inline]
105    pub fn is_empty(&self) -> bool {
106        self.toplevel.is_empty()
107    }
108
109    /// Returns whether this path has no nested paths.
110    ///
111    /// # Examples
112    /// ```
113    /// # use tnc::contractionpath::ContractionPath;
114    /// # use tnc::path;
115    /// let simple_path = path![(0, 1), (0, 2), (0, 3)];
116    /// assert!(simple_path.is_simple());
117    /// let nested_path = path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
118    /// assert!(!nested_path.is_simple());
119    /// ```
120    #[inline]
121    pub fn is_simple(&self) -> bool {
122        self.nested.is_empty()
123    }
124
125    /// Converts this path to its toplevel component.
126    ///
127    /// # Panics
128    /// - Panics when this path has nested components
129    ///
130    /// # Examples
131    /// ```
132    /// # use tnc::contractionpath::ContractionPath;
133    /// # use tnc::path;
134    /// let contractions = vec![(0, 1), (0, 2), (0, 3)];
135    /// let simple_path = ContractionPath::simple(contractions.clone());
136    /// assert_eq!(simple_path.into_simple(), contractions);
137    /// ```
138    #[inline]
139    pub fn into_simple(self) -> SimplePath {
140        assert!(self.is_simple());
141        self.toplevel
142    }
143}
144
/// Macro to create (nested) contraction paths, assuming the left tensor is replaced
/// in each contraction.
///
/// For instance, `path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)]` creates a nested
/// contraction path that
/// - recursively contracts the composite tensor 2 with the contraction path `[(0, 2), (0, 1)]`
/// - contracts tensors 0 and 1, replacing tensor 0 with the result
/// - contracts tensors 0 and (now contracted) tensor 2, replacing tensor 0 with the result
///
/// Without a leading `{ ... }` group the macro creates a simple path, and `path![]`
/// creates the empty path. A trailing comma is accepted after the last entry of any
/// list. Nested paths are expanded through `$crate::path!`, so the macro also works
/// when invoked by its full path without importing `path` into scope.
///
/// # Examples
/// ```
/// # use tnc::contractionpath::ContractionPath;
/// let simple = tnc::path![(0, 1), (0, 2),];
/// assert_eq!(simple, ContractionPath::simple(vec![(0, 1), (0, 2)]));
///
/// let nested = tnc::path![{ (2, [(0, 1)]) }, (0, 2)];
/// assert!(!nested.is_simple());
/// ```
#[macro_export]
macro_rules! path {
    [] => {
        $crate::contractionpath::ContractionPath::default()
    };
    [$( ($t0:expr, $t1:expr) ),+ $(,)?] => {
        $crate::contractionpath::ContractionPath::simple(vec![$( ($t0, $t1) ),+])
    };
    [ { $( ( $index:expr, [ $( $tok:tt )* ] ) ),* $(,)? } $(, ($t0:expr, $t1:expr) )* $(,)? ] => {
        $crate::contractionpath::ContractionPath::nested(
            // `$crate::` keeps the recursion resolvable from other crates.
            vec![ $( ($index, $crate::path![ $( $tok )* ]) ),* ],
            vec![ $( ($t0, $t1) ),* ]
        )
    };
}
168
169/// The contraction ordering labels [`Tensor`] objects from each possible contraction with a
170/// unique identifier in SSA format. As only a subset of these [`Tensor`] objects are seen in
171/// a contraction path, the tensors in the optimal path search are not sequential. This converts
172/// the output to strictly obey an SSA format.
173///
174/// # Arguments
175/// * `path` - Output path as `&[(usize, usize, usize)]` after an `find_path` call.
176/// * `n` - Number of initial input tensors.
177///
178/// # Returns
179/// Identical path using SSA format
180fn ssa_ordering(path: &[(usize, usize, usize)], mut n: usize) -> ContractionPath {
181    let mut ssa_path = Vec::with_capacity(path.len());
182    let mut hs = FxHashMap::with_capacity(path.len());
183    let path_len = n;
184    for (u1, u2, u3) in path {
185        let t1 = if *u1 >= path_len { hs[u1] } else { *u1 };
186        let t2 = if *u2 >= path_len { hs[u2] } else { *u2 };
187        hs.entry(*u3).or_insert(n);
188        n += 1;
189        ssa_path.push((t1, t2));
190    }
191    ContractionPath::simple(ssa_path)
192}
193
194/// Accepts a contraction `path` that is in SSA format and returns a contraction path
195/// assuming that all contracted tensors replace the left input tensor and no tensor
196/// is popped.
197pub(super) fn ssa_replace_ordering(path: &ContractionPath) -> ContractionPath {
198    let nested = path
199        .nested
200        .iter()
201        .map(|(index, local_path)| (*index, ssa_replace_ordering(local_path)))
202        .collect();
203
204    let mut hs = FxHashMap::with_capacity(path.len());
205    let mut toplevel = Vec::with_capacity(path.len());
206    let mut n = path.len() + 1;
207    for (t0, t1) in &path.toplevel {
208        let new_t0 = *hs.get(t0).unwrap_or(t0);
209        let new_t1 = *hs.get(t1).unwrap_or(t1);
210
211        hs.insert_new(n, new_t0);
212        toplevel.push((new_t0, new_t1));
213        n += 1;
214    }
215
216    ContractionPath { nested, toplevel }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_empty_macro() {
        assert_eq!(path![], ContractionPath::default());
        assert!(path![].is_empty());
        assert!(path![].is_simple());
    }

    #[test]
    fn test_path_simple_macro() {
        assert_eq!(
            path![(0, 1), (0, 2)],
            ContractionPath::simple(vec![(0, 1), (0, 2)])
        );
        assert_eq!(path![(0, 1)], ContractionPath::single(0, 1));
    }

    #[test]
    fn test_path_single_nested_macro() {
        assert_eq!(
            path![{ (2, [(1, 2), (1, 3)]) }, (0, 1)],
            ContractionPath {
                nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
                toplevel: vec![(0, 1)]
            }
        );
    }

    #[test]
    fn test_path_macro() {
        assert_eq!(
            path![
                {
                (2, [(1, 2), (1, 3)]),
                (4, [{(2, [(1, 2), (1, 3)])}, (1, 3)]),
                (5, [(1, 2), (1, 3)]),
                (3, [(4, 1), (3, 4), (3, 5)]),
                },
                (0, 1),
                (0, 2),
                (0, 3)
            ],
            ContractionPath {
                nested: FxHashMap::from_iter([
                    (2, ContractionPath::simple(vec![(1, 2), (1, 3)])),
                    (3, ContractionPath::simple(vec![(4, 1), (3, 4), (3, 5)])),
                    (
                        4,
                        ContractionPath {
                            nested: FxHashMap::from_iter([(
                                2,
                                ContractionPath::simple(vec![(1, 2), (1, 3)])
                            )]),
                            toplevel: vec![(1, 3)]
                        }
                    ),
                    (5, ContractionPath::simple(vec![(1, 2), (1, 3)]))
                ]),
                toplevel: vec![(0, 1), (0, 2), (0, 3)]
            }
        );
    }

    #[test]
    fn test_into_simple() {
        let contractions = vec![(0, 1), (0, 2)];
        assert_eq!(
            ContractionPath::simple(contractions.clone()).into_simple(),
            contractions
        );
    }

    #[test]
    #[should_panic]
    fn test_into_simple_rejects_nested() {
        let _ = path![{ (1, [(0, 1)]) }, (0, 1)].into_simple();
    }

    #[test]
    fn test_ssa_ordering() {
        let path = vec![
            (0, 3, 15),
            (1, 2, 44),
            (6, 4, 8),
            (5, 15, 22),
            (8, 44, 12),
            (12, 22, 99),
        ];
        let new_path = ssa_ordering(&path, 7);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)]
        );
    }

    #[test]
    fn test_ssa_ordering_empty() {
        assert_eq!(ssa_ordering(&[], 3), path![]);
    }

    #[test]
    fn test_ssa_replace_ordering() {
        let path = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];
        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 0), (6, 1), (6, 5)]
        );
    }

    #[test]
    fn test_ssa_replace_ordering_empty() {
        assert_eq!(ssa_replace_ordering(&ContractionPath::default()), path![]);
    }

    #[test]
    fn test_ssa_replace_ordering_nested() {
        let path = path![
            {
            (1, [(2, 1), (0, 3)]),
            (6, [(0, 2), (1, 3), (4, 5)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 7),
            (9, 8),
            (11, 10)
        ];

        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![
                {
                (1, [(2, 1), (0, 2)]),
                (6, [(0, 2), (1, 3), (0, 1)])
                },
                (0, 3),
                (1, 2),
                (6, 4),
                (5, 0),
                (6, 1),
                (6, 5)
            ]
        );
    }
}