Skip to main content

tnc/contractionpath/
contraction_cost.rs

1//! Different methods to compute the computational and memory cost of contraction
2//! paths.
3
4use num_complex::Complex64;
5
6use crate::{
7    contractionpath::{ContractionPath, SimplePathRef},
8    tensornetwork::tensor::Tensor,
9};
10
11/// Returns Schroedinger contraction time complexity of contracting two [`Tensor`]
12/// objects. Considers cost of complex operations.
13///
14/// # Examples
15/// ```
16/// # use tnc::tensornetwork::tensor::Tensor;
17/// # use tnc::contractionpath::contraction_cost::contract_cost_tensors;
18/// # use rustc_hash::FxHashMap;
19/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
20/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims);
21/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
22/// // result = [0, 1, 2, 3, 4] // cost of (9-1)*54*5005 = 350350;
23/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
24/// assert_eq!(contract_cost_tensors(&tn.tensor(0), &tn.tensor(1)), 350350.);
25/// ```
26pub fn contract_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
27    let final_dims = t_1 ^ t_2;
28    let shared_dims = t_1 & t_2;
29
30    let single_loop_cost = shared_dims.size();
31    (single_loop_cost - 1f64).mul_add(2f64, single_loop_cost * 6f64) * final_dims.size()
32}
33
34/// Returns Schroedinger contraction time complexity of contracting two [`Tensor`]
35/// objects. Naive op cost, does not consider costs of multiplication.
36///
37/// # Examples
38/// ```
39/// # use tnc::tensornetwork::tensor::Tensor;
40/// # use tnc::contractionpath::contraction_cost::contract_op_cost_tensors;
41/// # use rustc_hash::FxHashMap;
42/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
43/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims);
44/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims);
45/// // result = [0, 1, 2, 3, 4] // cost of 5*7*9*11*13 = 45045;
46/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
47/// assert_eq!(contract_op_cost_tensors(&tn.tensor(0), &tn.tensor(1)), 45045.);
48/// ```
49#[inline]
50pub fn contract_op_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
51    let all_dims = t_1 | t_2;
52    all_dims.size()
53}
54
55/// Returns Schroedinger contraction space complexity of contracting two [`Tensor`]
56/// objects.
57///
58/// # Examples
59/// ```
60/// # use tnc::tensornetwork::tensor::Tensor;
61/// # use tnc::contractionpath::contraction_cost::contract_size_tensors;
62/// # use rustc_hash::FxHashMap;
63/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
64/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims); // 315 entries
65/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims); // 1287 entries
66/// // result = [0, 1, 3, 4] //  5005 entries -> total 6607 entries
67/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
68/// assert_eq!(contract_size_tensors(&tn.tensor(0), &tn.tensor(1)), 6607.);
69/// ```
70#[inline]
71pub fn contract_size_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
72    let diff = t_1 ^ t_2;
73    diff.size() + t_1.size() + t_2.size()
74}
75
76/// Returns contraction space complexity in bytes of contracting two [`Tensor`]
77/// objects.
78///
79/// # Examples
80/// ```
81/// # use tnc::tensornetwork::tensor::Tensor;
82/// # use tnc::contractionpath::contraction_cost::contract_size_tensors_bytes;
83/// # use rustc_hash::FxHashMap;
84/// let bond_dims = FxHashMap::from_iter([(0, 5),(1, 7), (2, 9), (3, 11), (4, 13)]);
85/// let tensor1 = Tensor::new_from_map(vec![0, 1, 2], &bond_dims); // 315 entries
86/// let tensor2 = Tensor::new_from_map(vec![2, 3, 4], &bond_dims); // 1287 entries
87/// // result = [0, 1, 3, 4] //  5005 entries -> total 6607 entries
88/// let tn = Tensor::new_composite(vec![tensor1, tensor2]);
89/// assert_eq!(contract_size_tensors_bytes(&tn.tensor(0), &tn.tensor(1)), 6607. * 16.);
90/// ```
91#[inline]
92pub fn contract_size_tensors_bytes(i: &Tensor, j: &Tensor) -> f64 {
93    contract_size_tensors(i, j) * std::mem::size_of::<Complex64>() as f64
94}
95
96/// Returns Schroedinger contraction time and space complexity of fully contracting
97/// the input tensors.
98///
99/// # Arguments
100/// * `inputs` - Tensors to contract
101/// * `contract_path`  - Contraction order (replace path)
102/// * `only_count_ops` - If `true`, ignores cost of complex multiplication and addition and only counts number of operations
103#[inline]
104pub fn contract_path_cost(
105    inputs: &[Tensor],
106    contract_path: &ContractionPath,
107    only_count_ops: bool,
108) -> (f64, f64) {
109    let cost_function = if only_count_ops {
110        contract_op_cost_tensors
111    } else {
112        contract_cost_tensors
113    };
114    contract_path_custom_cost(inputs, contract_path, cost_function, contract_size_tensors)
115}
116
117/// Returns Schroedinger contraction time and space complexity of fully contracting
118/// the input tensors.
119///
120/// # Arguments
121/// * `inputs` - Tensors to contract
122/// * `contract_path`  - Contraction order (replace path)
123/// * `cost_function` - Function to calculate cost of contracting two tensors
124fn contract_path_custom_cost(
125    inputs: &[Tensor],
126    contract_path: &ContractionPath,
127    cost_function: fn(&Tensor, &Tensor) -> f64,
128    size_function: fn(&Tensor, &Tensor) -> f64,
129) -> (f64, f64) {
130    let mut op_cost = 0f64;
131    let mut mem_cost = 0f64;
132    let mut inputs = inputs.to_vec();
133
134    for (i, path) in &contract_path.nested {
135        let costs =
136            contract_path_custom_cost(inputs[*i].tensors(), path, cost_function, size_function);
137        op_cost += costs.0;
138        mem_cost = mem_cost.max(costs.1);
139        inputs[*i] = inputs[*i].external_tensor();
140    }
141
142    for &(i, j) in &contract_path.toplevel {
143        op_cost += cost_function(&inputs[i], &inputs[j]);
144        let ij = &inputs[i] ^ &inputs[j];
145        let new_mem_cost = size_function(&inputs[i], &inputs[j]);
146        mem_cost = mem_cost.max(new_mem_cost);
147        inputs[i] = ij;
148    }
149
150    (op_cost, mem_cost)
151}
152
153/// Returns Schroedinger contraction time complexity using the critical path metric
154/// and using the sum metric. Additionally returns the space complexity.
155#[inline]
156pub fn communication_path_op_costs(
157    inputs: &[Tensor],
158    contract_path: SimplePathRef,
159    only_count_ops: bool,
160    tensor_cost: Option<&[f64]>,
161) -> ((f64, f64), f64) {
162    let (parallel_cost, _) =
163        communication_path_cost(inputs, contract_path, only_count_ops, true, tensor_cost);
164    let (serial_cost, mem_cost) =
165        communication_path_cost(inputs, contract_path, only_count_ops, false, tensor_cost);
166    ((parallel_cost, serial_cost), mem_cost)
167}
168
169/// Returns Schroedinger contraction time and space complexity of fully contracting
170/// the input tensors assuming all operations occur in parallel.
171///
172/// # Arguments
173/// * `inputs` - Tensors to contract
174/// * `contract_path`  - Contraction order (replace path)
175/// * `only_count_ops` - If `true`, ignores cost of complex multiplication and addition and only counts number of operations
176/// * `only_circital_path` - If `true`, only counts the cost along the critical path, otherwise the sum of all costs
177/// * `tensor_costs` - Initial cost for each tensor
178pub fn communication_path_cost(
179    inputs: &[Tensor],
180    contract_path: SimplePathRef,
181    only_count_ops: bool,
182    only_critical_path: bool,
183    tensor_cost: Option<&[f64]>,
184) -> (f64, f64) {
185    let cost_function = if only_count_ops {
186        contract_op_cost_tensors
187    } else {
188        contract_cost_tensors
189    };
190    let tensor_cost = if let Some(tensor_cost) = tensor_cost {
191        assert_eq!(inputs.len(), tensor_cost.len());
192        tensor_cost
193    } else {
194        &vec![0f64; inputs.len()]
195    };
196    if inputs.len() == 1 {
197        return (tensor_cost[0], tensor_cost[0]);
198    }
199
200    communication_path_custom_cost(
201        inputs,
202        contract_path,
203        cost_function,
204        only_critical_path,
205        tensor_cost,
206    )
207}
208
209/// Returns Schroedinger contraction time and space complexity of fully contracting
210/// the input tensors assuming all operations occur in parallel.
211///
212/// # Arguments
213/// * `inputs` - Tensors to contract
214/// * `contract_path`  - Contraction order (replace path)
215/// * `cost_function` - Function to calculate cost of contracting two tensors
216/// * `tensor_costs` - Initial cost for each tensor
217fn communication_path_custom_cost(
218    inputs: &[Tensor],
219    contract_path: SimplePathRef,
220    cost_function: fn(&Tensor, &Tensor) -> f64,
221    only_critical_path: bool,
222    tensor_cost: &[f64],
223) -> (f64, f64) {
224    let mut op_cost = 0f64;
225    let mut mem_cost = 0f64;
226    let mut inputs = inputs.to_vec();
227    let mut tensor_cost = tensor_cost.to_vec();
228
229    for &(i, j) in contract_path {
230        let ij = &inputs[i] ^ &inputs[j];
231        let new_mem_cost = contract_size_tensors(&inputs[i], &inputs[j]);
232        mem_cost = mem_cost.max(new_mem_cost);
233
234        op_cost = if only_critical_path {
235            cost_function(&inputs[i], &inputs[j]) + tensor_cost[i].max(tensor_cost[j])
236        } else {
237            cost_function(&inputs[i], &inputs[j]) + tensor_cost[i] + tensor_cost[j]
238        };
239        tensor_cost[i] = op_cost;
240        inputs[i] = ij;
241    }
242
243    (op_cost, mem_cost)
244}
245
246/// Computes the max memory requirements for contracting the tensor network using the
247/// given path. Uses `memory_estimator` to compute the memory required to contract
248/// two tensors.
249///
250/// Candidates for `memory_estimator` are e.g.:
251/// - [`contract_size_tensors`]
252/// - [`contract_size_tensors_bytes`]
253#[inline]
254pub fn compute_memory_requirements(
255    inputs: &[Tensor],
256    contract_path: &ContractionPath,
257    memory_estimator: fn(&Tensor, &Tensor) -> f64,
258) -> f64 {
259    fn id(_: &Tensor, _: &Tensor) -> f64 {
260        0.0
261    }
262    let (_, mem) = contract_path_custom_cost(inputs, contract_path, id, memory_estimator);
263    mem
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Three leaf tensors: the first two share legs 2 and 3, the first and third share leg 4.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let leaves = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        Tensor::new_composite(leaves)
    }

    /// Two composite tensors: the `setup_simple` network and a two-tensor network.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 5),
            (1, 2),
            (2, 6),
            (3, 8),
            (4, 1),
            (5, 3),
            (6, 4),
            (7, 3),
            (8, 2),
            (9, 2),
        ]);
        let first = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ]);
        let second = Tensor::new_composite(vec![
            Tensor::new_from_map(vec![5, 6, 8], &bond_dims),
            Tensor::new_from_map(vec![7, 8, 9], &bond_dims),
        ]);
        Tensor::new_composite(vec![first, second])
    }

    /// The `setup_simple` network plus a fourth tensor sharing legs 5 and 6.
    fn setup_parallel() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        let leaves = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![5, 6], &bond_dims),
        ];
        Tensor::new_composite(leaves)
    }

    #[test]
    fn test_contract_path_cost() {
        let tn = setup_simple();
        let tensors = tn.tensors();
        assert_eq!(
            contract_path_cost(tensors, &path![(0, 1), (0, 2)], false),
            (4540., 538.)
        );
        assert_eq!(
            contract_path_cost(tensors, &path![(0, 2), (0, 1)], false),
            (49296., 1176.)
        );
    }

    #[test]
    fn test_contract_complex_path_cost() {
        let tn = setup_complex();
        let costs = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            false,
        );
        assert_eq!(costs, (11188., 538.));
    }

    #[test]
    fn test_contract_path_cost_only_ops() {
        let tn = setup_simple();
        let tensors = tn.tensors();
        assert_eq!(
            contract_path_cost(tensors, &path![(0, 1), (0, 2)], true),
            (600., 538.)
        );
        assert_eq!(
            contract_path_cost(tensors, &path![(0, 2), (0, 1)], true),
            (6336., 1176.)
        );
    }

    #[test]
    fn test_contract_path_complex_cost_only_ops() {
        let tn = setup_complex();
        let costs = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            true,
        );
        assert_eq!(costs, (1464., 538.));
    }

    #[test]
    fn test_communication_path_cost_only_ops() {
        let tn = setup_parallel();
        let costs =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], true, true, None);
        assert_eq!(costs, (490., 538.));
    }

    #[test]
    fn test_communication_path_cost() {
        let tn = setup_parallel();
        let costs =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 1)], false, true, None);
        assert_eq!(costs, (7564., 538.));
    }

    #[test]
    fn test_communication_path_cost_only_ops_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        let costs = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            true,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(costs, (520., 538.));
    }

    #[test]
    fn test_communication_path_cost_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        let costs = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 1)],
            false,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(costs, (7594., 538.));
    }
}