tnc/contractionpath/
contraction_cost.rs

1//! Different methods to compute the computational and memory cost of contraction
2//! paths.
3
4use num_complex::Complex64;
5
6use crate::{
7    contractionpath::{ContractionPath, SimplePathRef},
8    tensornetwork::tensor::{EdgeIndex, Tensor},
9};
10
11/// Returns Schroedinger contraction time complexity of contracting two [`Tensor`]
12/// objects. Considers cost of complex operations.
13///
14/// # Examples
15/// ```
16/// # use tnc::tensornetwork::tensor::Tensor;
17/// # use tnc::contractionpath::contraction_cost::contract_cost_tensors;
18/// # use rustc_hash::FxHashMap;
19/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
20/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims);
21/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
22/// // result = [0, 1, 2, 3, 4] // cost of (9-1)*54*5005 = 350350;
23/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
24/// assert_eq!(contract_cost_tensors(&tn.tensor(0), &tn.tensor(1)), 350350.);
25/// ```
26pub fn contract_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
27    let final_dims = t_1 ^ t_2;
28    let shared_dims = t_1 & t_2;
29
30    let single_loop_cost = shared_dims.size();
31    (single_loop_cost - 1f64).mul_add(2f64, single_loop_cost * 6f64) * final_dims.size()
32}
33
34/// Returns Schroedinger contraction time complexity of contracting two [`Tensor`]
35/// objects. Naive op cost, does not consider costs of multiplication.
36///
37/// # Examples
38/// ```
39/// # use tnc::tensornetwork::tensor::Tensor;
40/// # use tnc::contractionpath::contraction_cost::contract_op_cost_tensors;
41/// # use rustc_hash::FxHashMap;
42/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
43/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims);
44/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
45/// // result = [0, 1, 2, 3, 4] // cost of 5*7*9*11*13 = 45045;
46/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
47/// assert_eq!(contract_op_cost_tensors(&tn.tensor(0), &tn.tensor(1)), 45045.);
48/// ```
49#[inline]
50pub fn contract_op_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
51    let all_dims = t_1 | t_2;
52    all_dims.size()
53}
54
55/// Returns Schroedinger contraction space complexity of contracting two [`Tensor`]
56/// objects.
57///
58/// # Examples
59/// ```
60/// # use tnc::tensornetwork::tensor::Tensor;
61/// # use tnc::contractionpath::contraction_cost::contract_size_tensors;
62/// # use rustc_hash::FxHashMap;
63/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
64/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims); // 315 entries
65/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims); // 1287 entries
66/// // result = [0, 1, 3, 4] //  5005 entries -> total 6607 entries
67/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
68/// assert_eq!(contract_size_tensors(&tn.tensor(0), &tn.tensor(1)), 6607.);
69/// ```
70#[inline]
71pub fn contract_size_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
72    let diff = t_1 ^ t_2;
73    diff.size() + t_1.size() + t_2.size()
74}
75
76/// Returns a rather exact estimate of the memory requirements for
77/// contracting tensors `i` and `j`.
78///
79/// This takes into account if tensors need to be transposed (which doubles the
80/// required memory). It does not include memory of additional data like shape,
81/// bonddims, legs, etc..
82///
83/// # Examples
84/// ```
85/// # use tnc::tensornetwork::tensor::Tensor;
86/// # use tnc::contractionpath::contraction_cost::contract_size_tensors_exact;
87/// # use rustc_hash::FxHashMap;
88/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11)]);
89/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims); // requires 5040 bytes
90/// let tensor2 = Tensor::new_from_map(vec![3, 2], &bond_dims);    // requires 1584 bytes
91/// // result = [0, 1, 3], requires 6160 bytes
92/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
93/// assert_eq!(contract_size_tensors_exact(&tn.tensor(0), &tn.tensor(1)), 12784.);
94/// ```
95pub fn contract_size_tensors_exact(i: &Tensor, j: &Tensor) -> f64 {
96    /// Checks if `prefix` is a prefix of `list`.
97    #[inline]
98    fn is_prefix(prefix: &[EdgeIndex], list: &[EdgeIndex]) -> bool {
99        if prefix.len() > list.len() {
100            return false;
101        }
102        list.iter().zip(prefix.iter()).all(|(a, b)| a == b)
103    }
104
105    /// Checks if `suffix` is a suffix of `list`.
106    #[inline]
107    fn is_suffix(suffix: &[EdgeIndex], list: &[EdgeIndex]) -> bool {
108        if suffix.len() > list.len() {
109            return false;
110        }
111        list.iter()
112            .rev()
113            .zip(suffix.iter().rev())
114            .all(|(a, b)| a == b)
115    }
116
117    let ij = i ^ j;
118    let contracted_legs = i & j;
119    let i_needs_transpose = !is_suffix(contracted_legs.legs(), i.legs());
120    let j_needs_transpose = !is_prefix(contracted_legs.legs(), j.legs());
121
122    let i_size = i.size();
123    let j_size = j.size();
124    let ij_size = ij.size();
125
126    let elements = match (i_needs_transpose, j_needs_transpose) {
127        (true, true) => (2.0 * i_size + j_size)
128            .max(i_size + 2.0 * j_size)
129            .max(i_size + j_size + ij_size),
130        (true, false) => (2.0 * i_size + j_size).max(i_size + j_size + ij_size),
131        (false, true) => (i_size + 2.0 * j_size).max(i_size + j_size + ij_size),
132        (false, false) => i_size + j_size + ij_size,
133    };
134
135    elements * std::mem::size_of::<Complex64>() as f64
136}
137
138/// Returns Schroedinger contraction time and space complexity of fully contracting
139/// the input tensors.
140///
141/// # Arguments
142/// * `inputs` - Tensors to contract
143/// * `contract_path`  - Contraction order (replace path)
144/// * `only_count_ops` - If `true`, ignores cost of complex multiplication and addition and only counts number of operations
145#[inline]
146pub fn contract_path_cost(
147    inputs: &[Tensor],
148    contract_path: &ContractionPath,
149    only_count_ops: bool,
150) -> (f64, f64) {
151    let cost_function = if only_count_ops {
152        contract_op_cost_tensors
153    } else {
154        contract_cost_tensors
155    };
156    contract_path_custom_cost(inputs, contract_path, cost_function, contract_size_tensors)
157}
158
159/// Returns Schroedinger contraction time and space complexity of fully contracting
160/// the input tensors.
161///
162/// # Arguments
163/// * `inputs` - Tensors to contract
164/// * `contract_path`  - Contraction order (replace path)
165/// * `cost_function` - Function to calculate cost of contracting two tensors
166fn contract_path_custom_cost(
167    inputs: &[Tensor],
168    contract_path: &ContractionPath,
169    cost_function: fn(&Tensor, &Tensor) -> f64,
170    size_function: fn(&Tensor, &Tensor) -> f64,
171) -> (f64, f64) {
172    let mut op_cost = 0f64;
173    let mut mem_cost = 0f64;
174    let mut inputs = inputs.to_vec();
175
176    for (i, path) in &contract_path.nested {
177        let costs =
178            contract_path_custom_cost(inputs[*i].tensors(), path, cost_function, size_function);
179        op_cost += costs.0;
180        mem_cost = mem_cost.max(costs.1);
181        inputs[*i] = inputs[*i].external_tensor();
182    }
183
184    for &(i, j) in &contract_path.toplevel {
185        op_cost += cost_function(&inputs[i], &inputs[j]);
186        let ij = &inputs[i] ^ &inputs[j];
187        let new_mem_cost = size_function(&inputs[i], &inputs[j]);
188        mem_cost = mem_cost.max(new_mem_cost);
189        inputs[i] = ij;
190    }
191
192    (op_cost, mem_cost)
193}
194
195/// Returns Schroedinger contraction time complexity using the critical path metric
196/// and using the sum metric. Additionally returns the space complexity.
197#[inline]
198pub fn communication_path_op_costs(
199    inputs: &[Tensor],
200    contract_path: SimplePathRef,
201    only_count_ops: bool,
202    tensor_cost: Option<&[f64]>,
203) -> ((f64, f64), f64) {
204    let (parallel_cost, _) =
205        communication_path_cost(inputs, contract_path, only_count_ops, true, tensor_cost);
206    let (serial_cost, mem_cost) =
207        communication_path_cost(inputs, contract_path, only_count_ops, false, tensor_cost);
208    ((parallel_cost, serial_cost), mem_cost)
209}
210
211/// Returns Schroedinger contraction time and space complexity of fully contracting
212/// the input tensors assuming all operations occur in parallel.
213///
214/// # Arguments
215/// * `inputs` - Tensors to contract
216/// * `contract_path`  - Contraction order (replace path)
217/// * `only_count_ops` - If `true`, ignores cost of complex multiplication and addition and only counts number of operations
218/// * `only_circital_path` - If `true`, only counts the cost along the critical path, otherwise the sum of all costs
219/// * `tensor_costs` - Initial cost for each tensor
220pub fn communication_path_cost(
221    inputs: &[Tensor],
222    contract_path: SimplePathRef,
223    only_count_ops: bool,
224    only_critical_path: bool,
225    tensor_cost: Option<&[f64]>,
226) -> (f64, f64) {
227    let cost_function = if only_count_ops {
228        contract_op_cost_tensors
229    } else {
230        contract_cost_tensors
231    };
232    let tensor_cost = if let Some(tensor_cost) = tensor_cost {
233        assert_eq!(inputs.len(), tensor_cost.len());
234        tensor_cost
235    } else {
236        &vec![0f64; inputs.len()]
237    };
238    if inputs.len() == 1 {
239        return (tensor_cost[0], tensor_cost[0]);
240    }
241
242    communication_path_custom_cost(
243        inputs,
244        contract_path,
245        cost_function,
246        only_critical_path,
247        tensor_cost,
248    )
249}
250
251/// Returns Schroedinger contraction time and space complexity of fully contracting
252/// the input tensors assuming all operations occur in parallel.
253///
254/// # Arguments
255/// * `inputs` - Tensors to contract
256/// * `contract_path`  - Contraction order (replace path)
257/// * `cost_function` - Function to calculate cost of contracting two tensors
258/// * `tensor_costs` - Initial cost for each tensor
259fn communication_path_custom_cost(
260    inputs: &[Tensor],
261    contract_path: SimplePathRef,
262    cost_function: fn(&Tensor, &Tensor) -> f64,
263    only_critical_path: bool,
264    tensor_cost: &[f64],
265) -> (f64, f64) {
266    let mut op_cost = 0f64;
267    let mut mem_cost = 0f64;
268    let mut inputs = inputs.to_vec();
269    let mut tensor_cost = tensor_cost.to_vec();
270
271    for &(i, j) in contract_path {
272        let ij = &inputs[i] ^ &inputs[j];
273        let new_mem_cost = contract_size_tensors(&inputs[i], &inputs[j]);
274        mem_cost = mem_cost.max(new_mem_cost);
275
276        op_cost = if only_critical_path {
277            cost_function(&inputs[i], &inputs[j]) + tensor_cost[i].max(tensor_cost[j])
278        } else {
279            cost_function(&inputs[i], &inputs[j]) + tensor_cost[i] + tensor_cost[j]
280        };
281        tensor_cost[i] = op_cost;
282        inputs[i] = ij;
283    }
284
285    (op_cost, mem_cost)
286}
287
288/// Computes the max memory requirements for contracting the tensor network using the
289/// given path. Uses `memory_estimator` to compute the memory required to contract
290/// two tensors.
291///
292/// Candidates for `memory_estimator` are e.g.:
293/// - [`contract_size_tensors`]
294/// - [`contract_size_tensors_exact`]
295#[inline]
296pub fn compute_memory_requirements(
297    inputs: &[Tensor],
298    contract_path: &ContractionPath,
299    memory_estimator: fn(&Tensor, &Tensor) -> f64,
300) -> f64 {
301    fn id(_: &Tensor, _: &Tensor) -> f64 {
302        0.0
303    }
304    let (_, mem) = contract_path_custom_cost(inputs, contract_path, id, memory_estimator);
305    mem
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ])
    }

    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 5),
            (1, 2),
            (2, 6),
            (3, 8),
            (4, 1),
            (5, 3),
            (6, 4),
            (7, 3),
            (8, 2),
            (9, 2),
        ]);
        let t1_tensors = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        let t1 = Tensor::new_composite(t1_tensors);

        let t2_tensors = vec![
            Tensor::new_from_map(vec![5, 6, 8], &bond_dims),
            Tensor::new_from_map(vec![7, 8, 9], &bond_dims),
        ];
        let t2 = Tensor::new_composite(t2_tensors);
        Tensor::new_composite(vec![t1, t2])
    }

    fn setup_parallel() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![5, 6], &bond_dims),
        ])
    }

    #[test]
    fn test_contract_path_cost() {
        let tn = setup_simple();
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], false);
        assert_eq!(op_cost, 4540.);
        assert_eq!(mem_cost, 538.);
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], false);
        assert_eq!(op_cost, 49296.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_complex_path_cost() {
        let tn = setup_complex();
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            false,
        );
        assert_eq!(op_cost, 11188.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_contract_path_cost_only_ops() {
        let tn = setup_simple();
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], true);
        assert_eq!(op_cost, 600.);
        assert_eq!(mem_cost, 538.);
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], true);
        assert_eq!(op_cost, 6336.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_path_complex_cost_only_ops() {
        let tn = setup_complex();
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            true,
        );
        assert_eq!(op_cost, 1464.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops() {
        let tn = setup_parallel();
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], true, true, None);
        assert_eq!(op_cost, 490.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost() {
        let tn = setup_parallel();
        // Replace path: after (0, 1) and (2, 3) the intermediates live at positions 0
        // and 2, so the final contraction is (0, 2).
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], false, true, None);
        // (0, 1): 382 * 10 = 3820, (2, 3): 94 * 1 = 94, (0, 2): 6 * 10 = 60
        // critical path: 60 + max(3820, 94) = 3880
        assert_eq!(op_cost, 3880.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            true,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(op_cost, 520.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            false,
            true,
            Some(&tensor_cost),
        );
        // (0, 1): 3820 + max(20, 30) = 3850, (2, 3): 94 + max(80, 10) = 174
        // (0, 2): 60 + max(3850, 174) = 3910
        assert_eq!(op_cost, 3910.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_op_costs() {
        let tn = setup_parallel();
        let ((parallel_cost, serial_cost), mem_cost) =
            communication_path_op_costs(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], true, None);
        assert_eq!(parallel_cost, 490.);
        // sum metric: 480 + 12 + 10 = 502
        assert_eq!(serial_cost, 502.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_compute_memory_requirements() {
        let tn = setup_simple();
        let contract_path = path![(0, 1), (0, 2)];
        assert_eq!(
            compute_memory_requirements(tn.tensors(), &contract_path, contract_size_tensors),
            538.
        );
        // (0, 1): tensor 1 has to be transposed -> 48 + 2 * 480 = 1008 entries of 16 bytes
        assert_eq!(
            compute_memory_requirements(tn.tensors(), &contract_path, contract_size_tensors_exact),
            16128.
        );
    }
}