tnc/contractionpath.rs
1//! Find and evaluate the cost of contraction paths. Also contains functionality for
2//! rebalancing of partitioned tensor networks.
3
4use rustc_hash::FxHashMap;
5use serde::{Deserialize, Serialize};
6
7use crate::tensornetwork::tensor::TensorIndex;
8use crate::utils::traits::{HashMapInsertNew, WithCapacity};
9
10mod candidates;
11pub mod communication_schemes;
12pub mod contraction_cost;
13pub mod contraction_tree;
14pub mod paths;
15pub mod repartitioning;
16
17/// A simple, flat contraction path. If you only need a reference, prefer
18/// [`SimplePathRef`].
19pub type SimplePath = Vec<(TensorIndex, TensorIndex)>;
20
21/// Reference to a [`SimplePath`].
22pub type SimplePathRef<'a> = &'a [(TensorIndex, TensorIndex)];
23
24/// A (possibly nested) contraction path.
25///
26/// It specifies the overall contraction path to contract a tensor network, but also
27/// allows to specify additional contraction paths for each tensor, in order to deal
28/// with composite tensors that have to be contracted first.
29#[derive(Debug, Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
30pub struct ContractionPath {
31 /// Nested contraction paths for composite tensors.
32 pub nested: FxHashMap<TensorIndex, ContractionPath>,
33 /// The top-level contraction path for the tensor network itself.
34 pub toplevel: SimplePath,
35}
36
37impl ContractionPath {
38 /// Creates a contraction path with nested paths.
39 #[inline]
40 pub fn nested(
41 nested: Vec<(TensorIndex, ContractionPath)>,
42 toplevel: Vec<(TensorIndex, TensorIndex)>,
43 ) -> Self {
44 Self {
45 nested: FxHashMap::from_iter(nested),
46 toplevel,
47 }
48 }
49
50 /// Creates a plain contraction path without nested paths.
51 ///
52 /// # Examples
53 /// ```
54 /// # use tnc::contractionpath::{ContractionPath, SimplePath};
55 /// let path: SimplePath = vec![(0, 1), (0, 2), (0, 3)];
56 /// let contraction_path = ContractionPath::simple(path.clone());
57 /// assert!(contraction_path.is_simple());
58 /// assert_eq!(contraction_path.toplevel, path);
59 /// ```
60 #[inline]
61 pub fn simple(path: SimplePath) -> Self {
62 Self {
63 nested: FxHashMap::default(),
64 toplevel: path,
65 }
66 }
67
68 /// Creates a contraction path from a single contraction of two tensors.
69 ///
70 /// # Examples
71 /// ```
72 /// # use tnc::contractionpath::ContractionPath;
73 /// let contraction_path = ContractionPath::single(0, 1);
74 /// assert!(contraction_path.is_simple());
75 /// assert_eq!(contraction_path.toplevel, vec![(0, 1)]);
76 /// ```
77 #[inline]
78 pub fn single(a: TensorIndex, b: TensorIndex) -> Self {
79 Self::simple(vec![(a, b)])
80 }
81
82 /// The length of the contraction path, that is, the number of top-level
83 /// contractions.
84 ///
85 /// # Examples
86 /// ```
87 /// # use tnc::contractionpath::ContractionPath;
88 /// let contraction_path = ContractionPath::simple(vec![(0, 1), (0, 2), (0, 3)]);
89 /// assert_eq!(contraction_path.len(), 3);
90 /// ```
91 #[inline]
92 pub fn len(&self) -> usize {
93 self.toplevel.len()
94 }
95
96 /// Whether there are any top-level contractions in this contraction path.
97 ///
98 /// # Examples
99 /// ```
100 /// # use tnc::contractionpath::ContractionPath;
101 /// assert!(ContractionPath::default().is_empty());
102 /// assert!(!ContractionPath::simple(vec![(0, 1)]).is_empty());
103 /// ```
104 #[inline]
105 pub fn is_empty(&self) -> bool {
106 self.toplevel.is_empty()
107 }
108
109 /// Returns whether this path has no nested paths.
110 ///
111 /// # Examples
112 /// ```
113 /// # use tnc::contractionpath::ContractionPath;
114 /// # use tnc::path;
115 /// let simple_path = path![(0, 1), (0, 2), (0, 3)];
116 /// assert!(simple_path.is_simple());
117 /// let nested_path = path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
118 /// assert!(!nested_path.is_simple());
119 /// ```
120 #[inline]
121 pub fn is_simple(&self) -> bool {
122 self.nested.is_empty()
123 }
124
125 /// Converts this path to its toplevel component.
126 ///
127 /// # Panics
128 /// - Panics when this path has nested components
129 ///
130 /// # Examples
131 /// ```
132 /// # use tnc::contractionpath::ContractionPath;
133 /// # use tnc::path;
134 /// let contractions = vec![(0, 1), (0, 2), (0, 3)];
135 /// let simple_path = ContractionPath::simple(contractions.clone());
136 /// assert_eq!(simple_path.into_simple(), contractions);
137 /// ```
138 #[inline]
139 pub fn into_simple(self) -> SimplePath {
140 assert!(self.is_simple());
141 self.toplevel
142 }
143}
144
/// Creates a (possibly nested) [`ContractionPath`](crate::contractionpath::ContractionPath),
/// assuming the left tensor is replaced in each contraction.
///
/// - `path![]` creates an empty path.
/// - `path![(0, 1), (0, 2)]` creates a simple path without nested paths.
/// - `path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)]` creates a nested contraction
///   path that
///   - recursively contracts the composite tensor 2 with the contraction path
///     `[(0, 2), (0, 1)]`,
///   - contracts tensors 0 and 1, replacing tensor 0 with the result,
///   - contracts tensors 0 and (now contracted) tensor 2, replacing tensor 0 with
///     the result.
///
/// The paths inside `[...]` are parsed recursively with this same macro, so they may
/// be nested as well. Trailing commas are accepted everywhere.
///
/// # Examples
/// ```
/// # use tnc::contractionpath::ContractionPath;
/// // Works without importing the macro, including for nested paths.
/// let path = tnc::path![{(2, [(0, 2), (0, 1)])}, (0, 1), (0, 2)];
/// assert_eq!(path.toplevel, vec![(0, 1), (0, 2)]);
/// assert_eq!(path.nested[&2], ContractionPath::simple(vec![(0, 2), (0, 1)]));
/// ```
#[macro_export]
macro_rules! path {
    [] => {
        $crate::contractionpath::ContractionPath::default()
    };
    [$( ($t0:expr, $t1:expr) ),+ $(,)?] => {
        $crate::contractionpath::ContractionPath::simple(vec![$( ($t0, $t1) ),+])
    };
    [ { $( ( $index:expr, [ $( $tok:tt )* ] ) ),* $(,)? } $(, ($t0:expr, $t1:expr) )* $(,)? ] => {
        $crate::contractionpath::ContractionPath::nested(
            // `$crate::path!` (not bare `path!`) so the recursion resolves even when the
            // caller invoked this macro by path without importing it.
            vec![ $( ($index, $crate::path![ $( $tok )* ]) ),* ],
            vec![ $( ($t0, $t1) ),* ]
        )
    };
}
168
169/// The contraction ordering labels [`Tensor`] objects from each possible contraction with a
170/// unique identifier in SSA format. As only a subset of these [`Tensor`] objects are seen in
171/// a contraction path, the tensors in the optimal path search are not sequential. This converts
172/// the output to strictly obey an SSA format.
173///
174/// # Arguments
175/// * `path` - Output path as `&[(usize, usize, usize)]` after an `find_path` call.
176/// * `n` - Number of initial input tensors.
177///
178/// # Returns
179/// Identical path using SSA format
180fn ssa_ordering(path: &[(usize, usize, usize)], mut n: usize) -> ContractionPath {
181 let mut ssa_path = Vec::with_capacity(path.len());
182 let mut hs = FxHashMap::with_capacity(path.len());
183 let path_len = n;
184 for (u1, u2, u3) in path {
185 let t1 = if *u1 >= path_len { hs[u1] } else { *u1 };
186 let t2 = if *u2 >= path_len { hs[u2] } else { *u2 };
187 hs.entry(*u3).or_insert(n);
188 n += 1;
189 ssa_path.push((t1, t2));
190 }
191 ContractionPath::simple(ssa_path)
192}
193
194/// Accepts a contraction `path` that is in SSA format and returns a contraction path
195/// assuming that all contracted tensors replace the left input tensor and no tensor
196/// is popped.
197pub(super) fn ssa_replace_ordering(path: &ContractionPath) -> ContractionPath {
198 let nested = path
199 .nested
200 .iter()
201 .map(|(index, local_path)| (*index, ssa_replace_ordering(local_path)))
202 .collect();
203
204 let mut hs = FxHashMap::with_capacity(path.len());
205 let mut toplevel = Vec::with_capacity(path.len());
206 let mut n = path.len() + 1;
207 for (t0, t1) in &path.toplevel {
208 let new_t0 = *hs.get(t0).unwrap_or(t0);
209 let new_t1 = *hs.get(t1).unwrap_or(t1);
210
211 hs.insert_new(n, new_t0);
212 toplevel.push((new_t0, new_t1));
213 n += 1;
214 }
215
216 ContractionPath { nested, toplevel }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_simple_macro() {
        assert_eq!(
            path![{ (2, [(1, 2), (1, 3)]) }, (0, 1)],
            ContractionPath {
                nested: FxHashMap::from_iter([(2, ContractionPath::simple(vec![(1, 2), (1, 3)]))]),
                toplevel: vec![(0, 1)]
            }
        );
    }

    #[test]
    fn test_path_macro() {
        assert_eq!(
            path![
                {
                    (2, [(1, 2), (1, 3)]),
                    (4, [{(2, [(1, 2), (1, 3)])}, (1, 3)]),
                    (5, [(1, 2), (1, 3)]),
                    (3, [(4, 1), (3, 4), (3, 5)]),
                },
                (0, 1),
                (0, 2),
                (0, 3)
            ],
            ContractionPath {
                nested: FxHashMap::from_iter([
                    (2, ContractionPath::simple(vec![(1, 2), (1, 3)])),
                    (3, ContractionPath::simple(vec![(4, 1), (3, 4), (3, 5)])),
                    (
                        4,
                        ContractionPath {
                            nested: FxHashMap::from_iter([(
                                2,
                                ContractionPath::simple(vec![(1, 2), (1, 3)])
                            )]),
                            toplevel: vec![(1, 3)]
                        }
                    ),
                    (5, ContractionPath::simple(vec![(1, 2), (1, 3)]))
                ]),
                toplevel: vec![(0, 1), (0, 2), (0, 3)]
            }
        );
    }

    #[test]
    fn test_path_empty_macro() {
        assert_eq!(path![], ContractionPath::default());
        assert!(path![].is_empty());
        assert!(path![].is_simple());
    }

    #[test]
    #[should_panic]
    fn test_into_simple_panics_on_nested_path() {
        let _ = path![{ (1, [(0, 1)]) }, (0, 1)].into_simple();
    }

    #[test]
    fn test_ssa_ordering() {
        let path = vec![
            (0, 3, 15),
            (1, 2, 44),
            (6, 4, 8),
            (5, 15, 22),
            (8, 44, 12),
            (12, 22, 99),
        ];
        let new_path = ssa_ordering(&path, 7);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)]
        );
    }

    #[test]
    fn test_ssa_ordering_empty() {
        assert_eq!(ssa_ordering(&[], 3), ContractionPath::default());
    }

    #[test]
    fn test_ssa_replace_ordering() {
        let path = path![(0, 3), (1, 2), (6, 4), (5, 7), (9, 8), (11, 10)];
        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![(0, 3), (1, 2), (6, 4), (5, 0), (6, 1), (6, 5)]
        );
    }

    #[test]
    fn test_ssa_replace_ordering_empty() {
        assert_eq!(
            ssa_replace_ordering(&ContractionPath::default()),
            ContractionPath::default()
        );
    }

    #[test]
    fn test_ssa_replace_ordering_single_contraction() {
        assert_eq!(ssa_replace_ordering(&path![(0, 1)]), path![(0, 1)]);
    }

    #[test]
    fn test_ssa_replace_ordering_nested() {
        let path = path![
            {
                (1, [(2, 1), (0, 3)]),
                (6, [(0, 2), (1, 3), (4, 5)])
            },
            (0, 3),
            (1, 2),
            (6, 4),
            (5, 7),
            (9, 8),
            (11, 10)
        ];

        let new_path = ssa_replace_ordering(&path);

        assert_eq!(
            new_path,
            path![
                {
                    (1, [(2, 1), (0, 2)]),
                    (6, [(0, 2), (1, 3), (0, 1)])
                },
                (0, 3),
                (1, 2),
                (6, 4),
                (5, 0),
                (6, 1),
                (6, 5)
            ]
        );
    }
}