tnc/contractionpath.rs
1use rustc_hash::FxHashMap;
2use serde::{Deserialize, Serialize};
3
4use crate::tensornetwork::tensor::TensorIndex;
5use crate::utils::traits::{HashMapInsertNew, WithCapacity};
6
7mod candidates;
8pub mod communication_schemes;
9pub mod contraction_cost;
10pub mod contraction_tree;
11pub mod paths;
12pub mod repartitioning;
13
14/// A simple, flat contraction path. If you only need a reference, prefer
15/// [`SimplePathRef`].
16pub type SimplePath = Vec<(TensorIndex, TensorIndex)>;
17
18/// Reference to a [`SimplePath`].
19pub type SimplePathRef<'a> = &'a [(TensorIndex, TensorIndex)];
20
21/// A (possibly nested) contraction path.
22///
23/// It specifies the overall contraction path to contract a tensor network, but also
24/// allows to specify additional contraction paths for each tensor, in order to deal
25/// with composite tensors that have to be contracted first.
26#[derive(Debug, Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
27pub struct ContractionPath {
28 /// Nested contraction paths for composite tensors.
29 pub nested: FxHashMap<TensorIndex, ContractionPath>,
30 /// The top-level contraction path for the tensor network itself.
31 pub toplevel: SimplePath,
32}
33
34impl ContractionPath {
35 /// Creates a contraction path with nested paths.
36 #[inline]
37 pub fn nested(
38 nested: Vec<(TensorIndex, ContractionPath)>,
39 toplevel: Vec<(TensorIndex, TensorIndex)>,
40 ) -> Self {
41 Self {
42 nested: FxHashMap::from_iter(nested),
43 toplevel,
44 }
45 }
46
47 /// Creates a plain contraction path without nested paths.
48 ///
49 /// # Examples
50 /// ```
51 /// # use tnc::contractionpath::{ContractionPath, SimplePath};
52 /// let path: SimplePath = vec![(0, 1), (0, 2), (0, 3)];
53 /// let contraction_path = ContractionPath::simple(path.clone());
54 /// assert!(contraction_path.is_simple());
55 /// assert_eq!(contraction_path.toplevel, path);
56 /// ```
57 #[inline]
58 pub fn simple(path: SimplePath) -> Self {
59 Self {
60 nested: FxHashMap::default(),
61 toplevel: path,
62 }
63 }
64
65 /// Creates a contraction path from a single contraction of two tensors.
66 ///
67 /// # Examples
68 /// ```
69 /// # use tnc::contractionpath::ContractionPath;
70 /// let contraction_path = ContractionPath::single(0, 1);
71 /// assert!(contraction_path.is_simple());
72 /// assert_eq!(contraction_path.toplevel, vec![(0, 1)]);
73 /// ```
74 #[inline]
75 pub fn single(a: TensorIndex, b: TensorIndex) -> Self {
76 Self::simple(vec![(a, b)])
77 }
78
79 /// The length of the contraction path, that is, the number of top-level
80 /// contractions.
81 ///
82 /// # Examples
83 /// ```
84 /// # use tnc::contractionpath::ContractionPath;
85 /// let contraction_path = ContractionPath::simple(vec![(0, 1), (0, 2), (0, 3)]);
86 /// assert_eq!(contraction_path.len(), 3);
87 /// ```
88 #[inline]
89 pub fn len(&self) -> usize {
90 self.toplevel.len()
91 }
92
93 /// Whether there are any top-level contractions in this contraction path.
94 ///
95 /// # Examples
96 /// ```
97 /// # use tnc::contractionpath::ContractionPath;
98 /// assert!(ContractionPath::default().is_empty());
99 /// assert!(!ContractionPath::simple(vec![(0, 1)]).is_empty());
100 /// ```
101 #[inline]
102 pub fn is_empty(&self) -> bool {
103 self.toplevel.is_empty()
104 }
105
106 /// Returns whether this path has no nested paths.
107 ///
108 /// # Examples
109 /// ```
110 /// # use tnc::contractionpath::ContractionPath;
111 /// # use tnc::path;
112 /// let simple_path = path![(0, 1), (0, 2), (0, 3)];
113 /// assert!(simple_path.is_simple());
114 /// let nested_path = path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
115 /// assert!(!nested_path.is_simple());
116 /// ```
117 #[inline]
118 pub fn is_simple(&self) -> bool {
119 self.nested.is_empty()
120 }
121
122 /// Converts this path to its toplevel component.
123 ///
124 /// # Panics
125 /// - Panics when this path has nested components
126 ///
127 /// # Examples
128 /// ```
129 /// # use tnc::contractionpath::ContractionPath;
130 /// # use tnc::path;
131 /// let contractions = vec![(0, 1), (0, 2), (0, 3)];
132 /// let simple_path = ContractionPath::simple(contractions.clone());
133 /// assert_eq!(simple_path.into_simple(), contractions);
134 /// ```
135 #[inline]
136 pub fn into_simple(self) -> SimplePath {
137 assert!(self.is_simple());
138 self.toplevel
139 }
140}
141
/// Macro to create (nested) contraction paths, assuming the left tensor is replaced
/// in each contraction.
///
/// For instance, `path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)]` creates a nested
/// contraction path that
/// - recursively contracts the composite tensor 2 with the contraction path `[(0, 2), (0, 1)]`
/// - contracts tensors 0 and 1, replacing tensor 0 with the result
/// - contracts tensors 0 and (now contracted) tensor 2, replacing tensor 0 with the result
///
/// Nested paths are written as `(index, [...])` entries inside a leading `{ ... }`
/// group. The bracketed contents use the same syntax as the macro itself, so paths can
/// be nested arbitrarily deep. A trailing comma after the last contraction is accepted.
///
/// # Examples
/// ```
/// # use tnc::path;
/// # use tnc::contractionpath::ContractionPath;
/// assert_eq!(path![], ContractionPath::default());
/// assert_eq!(path![(0, 1), (0, 2)], ContractionPath::simple(vec![(0, 1), (0, 2)]));
/// let nested = path![{(2, [(0, 1)])}, (0, 2)];
/// assert_eq!(nested.nested[&2], path![(0, 1)]);
/// ```
#[macro_export]
macro_rules! path {
    [] => {
        $crate::contractionpath::ContractionPath::default()
    };
    [$( ($t0:expr, $t1:expr) ),+ $(,)?] => {
        $crate::contractionpath::ContractionPath::simple(vec![$( ($t0, $t1) ),+])
    };
    [ { $( ( $index:expr, [ $( $tok:tt )* ] ) ),* $(,)? } $(, ($t0:expr, $t1:expr) )* $(,)? ] => {
        $crate::contractionpath::ContractionPath::nested(
            // Recurse through `$crate` so the macro also works in downstream crates
            // that invoke it as `tnc::path!` without importing `path` by name.
            vec![ $( ($index, $crate::path![ $( $tok )* ]) ),* ],
            vec![ $( ($t0, $t1) ),* ]
        )
    };
}
165
166/// The contraction ordering labels [`Tensor`] objects from each possible contraction with a
167/// unique identifier in SSA format. As only a subset of these [`Tensor`] objects are seen in
168/// a contraction path, the tensors in the optimal path search are not sequential. This converts
169/// the output to strictly obey an SSA format.
170///
171/// # Arguments
172/// * `path` - Output path as `&[(usize, usize, usize)]` after an `find_path` call.
173/// * `n` - Number of initial input tensors.
174///
175/// # Returns
176/// Identical path using SSA format
177fn ssa_ordering(path: &[(usize, usize, usize)], mut n: usize) -> ContractionPath {
178 let mut ssa_path = Vec::with_capacity(path.len());
179 let mut hs = FxHashMap::with_capacity(path.len());
180 let path_len = n;
181 for (u1, u2, u3) in path {
182 let t1 = if *u1 >= path_len { hs[u1] } else { *u1 };
183 let t2 = if *u2 >= path_len { hs[u2] } else { *u2 };
184 hs.entry(*u3).or_insert(n);
185 n += 1;
186 ssa_path.push((t1, t2));
187 }
188 ContractionPath::simple(ssa_path)
189}
190
191/// Accepts a contraction `path` that is in SSA format and returns a contraction path
192/// assuming that all contracted tensors replace the left input tensor and no tensor
193/// is popped.
194pub(super) fn ssa_replace_ordering(path: &ContractionPath) -> ContractionPath {
195 let nested = path
196 .nested
197 .iter()
198 .map(|(index, local_path)| (*index, ssa_replace_ordering(local_path)))
199 .collect();
200
201 let mut hs = FxHashMap::with_capacity(path.len());
202 let mut toplevel = Vec::with_capacity(path.len());
203 let mut n = path.len() + 1;
204 for (t0, t1) in &path.toplevel {
205 let new_t0 = *hs.get(t0).unwrap_or(t0);
206 let new_t1 = *hs.get(t1).unwrap_or(t1);
207
208 hs.insert_new(n, new_t0);
209 toplevel.push((new_t0, new_t1));
210 n += 1;
211 }
212
213 ContractionPath { nested, toplevel }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_simple_macro() {
        let expected = ContractionPath {
            nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
            toplevel: vec![(0, 1)],
        };

        let built = path![{ (2, [(1, 2), (1, 3)]) }, (0, 1)];

        assert_eq!(built, expected);
    }

    #[test]
    fn test_path_macro() {
        // Composite tensor 4 itself contains the composite tensor 2.
        let doubly_nested = ContractionPath {
            nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
            toplevel: vec![(1, 3)],
        };
        let expected = ContractionPath {
            nested: FxHashMap::from_iter([
                (2, ContractionPath::simple(vec![(1, 2), (1, 3)])),
                (3, ContractionPath::simple(vec![(4, 1), (3, 4), (3, 5)])),
                (4, doubly_nested),
                (5, ContractionPath::simple(vec![(1, 2), (1, 3)])),
            ]),
            toplevel: vec![(0, 1), (0, 2), (0, 3)],
        };

        let built = path![
            {
                (2, [(1, 2), (1, 3)]),
                (4, [{(2, [(1, 2), (1, 3)])}, (1, 3)]),
                (5, [(1, 2), (1, 3)]),
                (3, [(4, 1), (3, 4), (3, 5)]),
            },
            (0, 1),
            (0, 2),
            (0, 3)
        ];

        assert_eq!(built, expected);
    }

    #[test]
    fn test_ssa_ordering() {
        // The intermediate labels 15, 44, 8, 22, 12 and 99 are arbitrary. With 7 inputs
        // they must become 7, 8, 9, ... in order of creation.
        let optimizer_output = [
            (0, 3, 15),
            (1, 2, 44),
            (6, 4, 8),
            (5, 15, 22),
            (8, 44, 12),
            (12, 22, 99),
        ];
        let expected = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];

        assert_eq!(ssa_ordering(&optimizer_output, 7), expected);
    }

    #[test]
    fn test_ssa_replace_ordering() {
        let ssa = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];
        let expected = path![(0, 3), (1, 2), (6, 4), (5, 0), (6, 1), (6, 5)];

        assert_eq!(ssa_replace_ordering(&ssa), expected);
    }

    #[test]
    fn test_ssa_replace_ordering_nested() {
        let ssa = path![
            {
                (1, [(2, 1), (0, 3)]),
                (6, [(0, 2), (1, 3), (4, 5)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 7),
            (9, 8),
            (11, 10)
        ];
        let expected = path![
            {
                (1, [(2, 1), (0, 2)]),
                (6, [(0, 2), (1, 3), (0, 1)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 0),
            (6, 1),
            (6, 5)
        ];

        assert_eq!(ssa_replace_ordering(&ssa), expected);
    }
}