Skip to main content

tnc/
contractionpath.rs

1//! Find and evaluate the cost of contraction paths. Also contains functionality for
2//! rebalancing of partitioned tensor networks.
3
4use rustc_hash::FxHashMap;
5use serde::{Deserialize, Serialize};
6
7use crate::tensornetwork::tensor::TensorIndex;
8use crate::utils::traits::{HashMapInsertNew, WithCapacity};
9
10mod candidates;
11pub mod communication_schemes;
12pub mod contraction_cost;
13pub mod contraction_tree;
14pub mod paths;
15pub mod repartitioning;
16
17/// A simple, flat contraction path. If you only need a reference, prefer
18/// [`SimplePathRef`].
19pub type SimplePath = Vec<(TensorIndex, TensorIndex)>;
20
21/// Reference to a [`SimplePath`].
22pub type SimplePathRef<'a> = &'a [(TensorIndex, TensorIndex)];
23
24/// A (possibly nested) contraction path.
25///
26/// It specifies the overall contraction path to contract a tensor network, but also
27/// allows to specify additional contraction paths for each tensor, in order to deal
28/// with composite tensors that have to be contracted first.
29#[derive(Debug, Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
30pub struct ContractionPath {
31    /// Nested contraction paths for composite tensors.
32    pub nested: FxHashMap<TensorIndex, ContractionPath>,
33    /// The top-level contraction path for the tensor network itself.
34    pub toplevel: SimplePath,
35}
36
37impl ContractionPath {
38    /// Creates a contraction path with nested paths.
39    #[inline]
40    pub fn nested(
41        nested: Vec<(TensorIndex, ContractionPath)>,
42        toplevel: Vec<(TensorIndex, TensorIndex)>,
43    ) -> Self {
44        Self {
45            nested: FxHashMap::from_iter(nested),
46            toplevel,
47        }
48    }
49
50    /// Creates a plain contraction path without nested paths.
51    ///
52    /// # Examples
53    /// ```
54    /// # use tnc::contractionpath::{ContractionPath, SimplePath};
55    /// let path: SimplePath = vec![(0, 1), (0, 2), (0, 3)];
56    /// let contraction_path = ContractionPath::simple(path.clone());
57    /// assert!(contraction_path.is_simple());
58    /// assert_eq!(contraction_path.toplevel, path);
59    /// ```
60    #[inline]
61    pub fn simple(path: SimplePath) -> Self {
62        Self {
63            nested: FxHashMap::default(),
64            toplevel: path,
65        }
66    }
67
68    /// Creates a contraction path from a single contraction of two tensors.
69    ///
70    /// # Examples
71    /// ```
72    /// # use tnc::contractionpath::ContractionPath;
73    /// let contraction_path = ContractionPath::single(0, 1);
74    /// assert!(contraction_path.is_simple());
75    /// assert_eq!(contraction_path.toplevel, vec![(0, 1)]);
76    /// ```
77    #[inline]
78    pub fn single(a: TensorIndex, b: TensorIndex) -> Self {
79        Self::simple(vec![(a, b)])
80    }
81
82    /// The length of the contraction path, that is, the number of top-level
83    /// contractions.
84    ///
85    /// # Examples
86    /// ```
87    /// # use tnc::contractionpath::ContractionPath;
88    /// let contraction_path = ContractionPath::simple(vec![(0, 1), (0, 2), (0, 3)]);
89    /// assert_eq!(contraction_path.len(), 3);
90    /// ```
91    #[inline]
92    pub fn len(&self) -> usize {
93        self.toplevel.len()
94    }
95
96    /// Whether there are any top-level contractions in this contraction path.
97    ///
98    /// # Examples
99    /// ```
100    /// # use tnc::contractionpath::ContractionPath;
101    /// assert!(ContractionPath::default().is_empty());
102    /// assert!(!ContractionPath::simple(vec![(0, 1)]).is_empty());
103    /// ```
104    #[inline]
105    pub fn is_empty(&self) -> bool {
106        self.toplevel.is_empty()
107    }
108
109    /// Returns whether this path has no nested paths.
110    ///
111    /// # Examples
112    /// ```
113    /// # use tnc::contractionpath::ContractionPath;
114    /// # use tnc::path;
115    /// let simple_path = path![(0, 1), (0, 2), (0, 3)];
116    /// assert!(simple_path.is_simple());
117    /// let nested_path = path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
118    /// assert!(!nested_path.is_simple());
119    /// ```
120    #[inline]
121    pub fn is_simple(&self) -> bool {
122        self.nested.is_empty()
123    }
124
125    /// Converts this path to its toplevel component.
126    ///
127    /// # Panics
128    /// - Panics when this path has nested components
129    ///
130    /// # Examples
131    /// ```
132    /// # use tnc::contractionpath::ContractionPath;
133    /// # use tnc::path;
134    /// let contractions = vec![(0, 1), (0, 2), (0, 3)];
135    /// let simple_path = ContractionPath::simple(contractions.clone());
136    /// assert_eq!(simple_path.into_simple(), contractions);
137    /// ```
138    #[inline]
139    pub fn into_simple(self) -> SimplePath {
140        assert!(self.is_simple());
141        self.toplevel
142    }
143}
144
/// Macro to create (nested) contraction paths, assuming the left tensor is replaced
/// in each contraction.
///
/// For instance, `path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)]` creates a nested
/// contraction path that
/// - recursively contracts the composite tensor 2 with the contraction path `[(0, 2), (0, 1)]`
/// - contracts tensors 0 and 1, replacing tensor 0 with the result
/// - contracts tensors 0 and (now contracted) tensor 2, replacing tensor 0 with the result
///
/// Supported forms:
/// - `path![]`: the empty path (same as `ContractionPath::default()`).
/// - `path![(a, b), ...]`: a simple path without nested components.
/// - `path![{(i, [...]), ...}, (a, b), ...]`: nested paths for composite tensors (each
///   inner `[...]` accepts the same syntax recursively), followed by the top-level
///   contractions.
///
/// A trailing comma is accepted after the last element of every list.
#[macro_export]
macro_rules! path {
    [] => {
        $crate::contractionpath::ContractionPath::default()
    };
    [$( ($t0:expr, $t1:expr) ),* $(,)?] => {
        $crate::contractionpath::ContractionPath::simple(vec![$( ($t0, $t1) ),*])
    };
    [ { $( ( $index:expr, [ $( $tok:tt )* ] ) ),* $(,)? } $(, ($t0:expr, $t1:expr) )* $(,)? ] => {
        $crate::contractionpath::ContractionPath::nested(
            // Recurse through `$crate::` so that the macro also expands when it is invoked
            // by its full path (e.g. `tnc::path![...]`) without `path` being imported.
            vec![ $( ($index, $crate::path![ $( $tok )* ]) ),* ],
            vec![ $( ($t0, $t1) ),* ]
        )
    };
}
168
169/// The contraction ordering labels [`Tensor`] objects from each possible contraction with a
170/// unique identifier in SSA format. As only a subset of these [`Tensor`] objects are seen in
171/// a contraction path, the tensors in the optimal path search are not sequential. This converts
172/// the output to strictly obey an SSA format.
173///
174/// # Arguments
175/// * `path` - Output path as `&[(usize, usize, usize)]` after an `find_path` call.
176/// * `n` - Number of initial input tensors.
177///
178/// # Returns
179/// Identical path using SSA format
180fn ssa_ordering(path: &[(usize, usize, usize)], mut n: usize) -> ContractionPath {
181    let mut ssa_path = Vec::with_capacity(path.len());
182    let mut hs = FxHashMap::with_capacity(path.len());
183    let path_len = n;
184    for (u1, u2, u3) in path {
185        let t1 = if *u1 >= path_len { hs[u1] } else { *u1 };
186        let t2 = if *u2 >= path_len { hs[u2] } else { *u2 };
187        hs.entry(*u3).or_insert(n);
188        n += 1;
189        ssa_path.push((t1, t2));
190    }
191    ContractionPath::simple(ssa_path)
192}
193
194/// Accepts a contraction `path` that is in SSA format and returns a contraction path
195/// assuming that all contracted tensors replace the left input tensor and no tensor
196/// is popped.
197pub(super) fn ssa_replace_ordering(path: &ContractionPath) -> ContractionPath {
198    let nested = path
199        .nested
200        .iter()
201        .map(|(index, local_path)| (*index, ssa_replace_ordering(local_path)))
202        .collect();
203
204    let mut hs = FxHashMap::with_capacity(path.len());
205    let mut toplevel = Vec::with_capacity(path.len());
206    for (n, (t0, t1)) in (path.len() + 1..).zip(&path.toplevel) {
207        let new_t0 = *hs.get(t0).unwrap_or(t0);
208        let new_t1 = *hs.get(t1).unwrap_or(t1);
209
210        hs.insert_new(n, new_t0);
211        toplevel.push((new_t0, new_t1));
212    }
213
214    ContractionPath { nested, toplevel }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_simple_macro() {
        assert_eq!(
            path![{ (2, [(1, 2), (1, 3)]) }, (0, 1)],
            ContractionPath {
                nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
                toplevel: vec![(0, 1)]
            }
        );
    }

    #[test]
    fn test_path_macro() {
        assert_eq!(
            path![
                {
                (2, [(1, 2), (1, 3)]),
                (4, [{(2, [(1, 2), (1, 3)])}, (1, 3)]),
                (5, [(1, 2), (1, 3)]),
                (3, [(4, 1), (3, 4), (3, 5)]),
                },
                (0, 1),
                (0, 2),
                (0, 3)
            ],
            ContractionPath {
                nested: FxHashMap::from_iter([
                    (2, ContractionPath::simple(vec![(1, 2), (1, 3)])),
                    (3, ContractionPath::simple(vec![(4, 1), (3, 4), (3, 5)])),
                    (
                        4,
                        ContractionPath {
                            nested: FxHashMap::from_iter([(
                                2,
                                ContractionPath::simple(vec![(1, 2), (1, 3)])
                            )]),
                            toplevel: vec![(1, 3)]
                        }
                    ),
                    (5, ContractionPath::simple(vec![(1, 2), (1, 3)]))
                ]),
                toplevel: vec![(0, 1), (0, 2), (0, 3)]
            }
        );
    }

    #[test]
    fn test_path_empty_macro() {
        assert_eq!(path![], ContractionPath::default());
        assert!(path![].is_empty());
        assert!(path![].is_simple());
    }

    #[test]
    fn test_into_simple() {
        assert_eq!(path![(0, 1), (0, 2)].into_simple(), vec![(0, 1), (0, 2)]);
    }

    #[test]
    #[should_panic]
    fn test_into_simple_nested_panics() {
        let _ = path![{ (1, [(0, 1)]) }, (0, 1)].into_simple();
    }

    #[test]
    fn test_ssa_ordering() {
        let path = vec![
            (0, 3, 15),
            (1, 2, 44),
            (6, 4, 8),
            (5, 15, 22),
            (8, 44, 12),
            (12, 22, 99),
        ];
        let new_path = ssa_ordering(&path, 7);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)]
        );
    }

    #[test]
    fn test_ssa_ordering_empty() {
        assert_eq!(ssa_ordering(&[], 4), path![]);
    }

    #[test]
    fn test_ssa_ordering_chain() {
        // Each contraction consumes the intermediate produced by the previous one.
        let path = vec![(0, 1, 10), (10, 2, 20), (20, 3, 30)];
        assert_eq!(ssa_ordering(&path, 4), path![(0, 1), (4, 2), (5, 3)]);
    }

    #[test]
    fn test_ssa_replace_ordering() {
        let path = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];
        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 0), (6, 1), (6, 5)]
        );
    }

    #[test]
    fn test_ssa_replace_ordering_empty() {
        assert_eq!(ssa_replace_ordering(&path![]), path![]);
    }

    #[test]
    fn test_ssa_replace_ordering_chain() {
        // Intermediates 4 and 5 are stored in the slot of their left operand (0).
        let path = path![(0, 1), (4, 2), (5, 3)];
        assert_eq!(ssa_replace_ordering(&path), path![(0, 1), (0, 2), (0, 3)]);
    }

    #[test]
    fn test_ssa_replace_ordering_nested() {
        let path = path![
            {
            (1, [(2, 1), (0, 3)]),
            (6, [(0, 2), (1, 3), (4, 5)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 7),
            (9, 8),
            (11, 10)
        ];

        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![
                {
                (1, [(2, 1), (0, 2)]),
                (6, [(0, 2), (1, 3), (0, 1)])
                },
                (0, 3),
                (1, 2),
                (6, 4),
                (5, 0),
                (6, 1),
                (6, 5)
            ]
        );
    }
}