Skip to main content

tnc/tensornetwork/
contraction.rs

//! Functionality to contract tensor networks.
2use log::debug;
3use ndarray::Axis;
4use tblis::{tensor_mult, TensorView};
5
6use crate::{
7    contractionpath::ContractionPath,
8    tensornetwork::{
9        tensor::Tensor,
10        tensordata::{DataTensor, TensorData},
11    },
12};
13
14/// Fully contracts `tn` based on the given `contract_path` using ReplaceLeft format.
15/// Returns the resulting tensor.
16///
17/// # Examples
18/// ```
19/// # use tnc::{
20/// #   contractionpath::paths::{cotengrust::{Cotengrust, OptMethod}, FindPath},
21/// #   builders::sycamore_circuit::sycamore_circuit,
22/// #   tensornetwork::tensor::Tensor,
23/// #   tensornetwork::contraction::contract_tensor_network,
24/// # };
25/// # use rand::rngs::StdRng;
26/// # use rand::SeedableRng;
27/// let mut r = StdRng::seed_from_u64(42);
28/// let mut r_tn = sycamore_circuit(2, 1, &mut r).into_expectation_value_network();
29/// let mut opt = Cotengrust::new(&r_tn, OptMethod::Greedy);
30/// opt.find_path();
31/// let opt_path = opt.get_best_replace_path();
32/// let result = contract_tensor_network(r_tn, &opt_path);
33/// ```
34pub fn contract_tensor_network(mut tn: Tensor, contract_path: &ContractionPath) -> Tensor {
35    debug!(len = tn.tensors().len(); "Start contracting tensor network");
36
37    // Contract child composite tensors first
38    for (index, inner_path) in &contract_path.nested {
39        let composite = std::mem::take(&mut tn.tensors[*index]);
40        let contracted = contract_tensor_network(composite, inner_path);
41        tn.tensors[*index] = contracted;
42    }
43
44    // Contract all leaf tensors
45    for (i, j) in &contract_path.toplevel {
46        debug!(i, j; "Contracting tensors");
47        tn.contract_tensors(*i, *j);
48        debug!(i, j; "Finished contracting tensors");
49    }
50    debug!("Completed tensor network contraction");
51
52    tn.tensors
53        .retain(|x| !matches!(x.tensor_data(), TensorData::Uncontracted) || x.is_composite());
54    assert!(tn.tensors().len() <= 1, "Not fully contracted");
55    tn.tensors.pop().unwrap_or(tn)
56}
57
/// In-place pairwise contraction of two children of a composite tensor.
trait TensorContraction {
    /// Contracts the child at position `tensor_a_loc` with the child at position
    /// `tensor_b_loc`, updating `self` in place.
    fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize);
}
62
63impl TensorContraction for Tensor {
64    fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize) {
65        let tensor_a = std::mem::take(&mut self.tensors[tensor_a_loc]);
66        let tensor_b = std::mem::take(&mut self.tensors[tensor_b_loc]);
67
68        let mut tensor_symmetric_difference = &tensor_b ^ &tensor_a;
69
70        let Tensor {
71            legs: a_legs,
72            tensordata: a_data,
73            ..
74        } = tensor_a;
75
76        let Tensor {
77            legs: b_legs,
78            tensordata: b_data,
79            ..
80        } = tensor_b;
81
82        let result = contract_ndarrays(
83            &tensor_symmetric_difference.legs,
84            &a_legs,
85            a_data.into_data(),
86            &b_legs,
87            b_data.into_data(),
88        );
89
90        tensor_symmetric_difference.set_tensor_data(TensorData::Matrix(result));
91        self.tensors[tensor_a_loc] = tensor_symmetric_difference;
92    }
93}
94
95fn contract_ndarrays(
96    out_labels: &[usize],
97    a_labels: &[usize],
98    a_data: DataTensor,
99    b_labels: &[usize],
100    b_data: DataTensor,
101) -> DataTensor {
102    assert_eq!(a_labels.len(), a_data.ndim());
103    assert_eq!(b_labels.len(), b_data.ndim());
104
105    // Find output shape
106    let mut out_shape = Vec::with_capacity(out_labels.len());
107    for label in out_labels {
108        if let Some(a_index) = a_labels.iter().position(|l| l == label) {
109            out_shape.push(a_data.len_of(Axis(a_index)));
110        } else if let Some(b_index) = b_labels.iter().position(|l| l == label) {
111            out_shape.push(b_data.len_of(Axis(b_index)));
112        } else {
113            panic!("Out label {label} not found in input a ({a_labels:?}) or b ({b_labels:?})");
114        }
115    }
116
117    // Contract with TBLIS
118    let a_view = TensorView::new(a_labels, a_data.shape(), a_data.strides(), a_data.as_ptr());
119    let b_view = TensorView::new(b_labels, b_data.shape(), b_data.strides(), b_data.as_ptr());
120    let out_data = tensor_mult(out_labels, &out_shape, a_view, b_view);
121
122    DataTensor::from_shape_vec(out_shape, out_data).unwrap()
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use num_complex::Complex64;
    use rustc_hash::FxHashMap;
    use serde::Deserialize;

    use crate::{
        path,
        tensornetwork::{contraction::TensorContraction, tensor::Tensor, tensordata::TensorData},
    };

    /// A single tensor entry from `contraction_test_data.json`.
    #[derive(Debug, Deserialize)]
    struct TestTensor {
        legs: Vec<usize>,
        shape: Vec<u64>,
        data: Vec<Complex64>,
    }

    type TestData = FxHashMap<String, TestTensor>;

    static TEST_DATA: &str = include_str!("contraction_test_data.json");

    fn load_test_data() -> TestData {
        serde_json::from_str(TEST_DATA).unwrap()
    }

    /// Removes `key` from `data` and builds a leaf tensor with the stored legs,
    /// shape and values.
    fn take_tensor(data: &mut TestData, key: &str) -> Tensor {
        let TestTensor {
            legs,
            shape,
            data: values,
        } = data.remove(key).unwrap();
        let mut tensor = Tensor::new(legs, shape);
        tensor.set_tensor_data(TensorData::new_from_data(&tensor.shape().unwrap(), values));
        tensor
    }

    /// Turns `(re, im)` pairs into complex numbers.
    fn complex_values(pairs: &[(f64, f64)]) -> Vec<Complex64> {
        pairs.iter().map(|&(re, im)| Complex64::new(re, im)).collect()
    }

    #[test]
    fn test_tensor_contraction() {
        let mut data = load_test_data();

        // t1 is of shape [3, 2, 7]
        let t1 = take_tensor(&mut data, "A");
        // t2 is of shape [7, 8, 6]
        let t2 = take_tensor(&mut data, "B");
        // t3 is of shape [3, 5, 8]
        let t3 = take_tensor(&mut data, "C");
        // t12 is of shape [8, 6, 3, 2]
        let t12 = take_tensor(&mut data, "AxB");
        // t23 is of shape [3, 5, 7, 6]
        let t23 = take_tensor(&mut data, "BxC");

        let mut tn_12 = Tensor::new_composite(vec![t1.clone(), t2.clone(), t3.clone()]);
        tn_12.contract_tensors(0, 1);
        assert_abs_diff_eq!(tn_12.tensor(0), &t12, epsilon = 1e-14);

        let mut tn_23 = Tensor::new_composite(vec![t1, t2, t3]);
        tn_23.contract_tensors(1, 2);
        assert_abs_diff_eq!(tn_23.tensor(1), &t23, epsilon = 1e-14);
    }

    #[test]
    fn test_tn_contraction() {
        let mut data = load_test_data();

        // t1 is of shape [3, 2, 7]
        let t1 = take_tensor(&mut data, "A");
        // t2 is of shape [7, 8, 6]
        let t2 = take_tensor(&mut data, "B");
        // t3 is of shape [3, 5, 8]
        let t3 = take_tensor(&mut data, "C");
        // tout is of shape [5, 6, 2]
        let tout = take_tensor(&mut data, "ABxC");

        let network = Tensor::new_composite(vec![t1, t2, t3]);
        let contract_path = path![(0, 1), (0, 2)];

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &tout, epsilon = 1e-14);
    }

    #[test]
    fn test_outer_product_contraction() {
        let bond_dims = FxHashMap::from_iter([(0, 3), (1, 2)]);

        let mut left = Tensor::new_from_map(vec![0], &bond_dims);
        left.set_tensor_data(TensorData::new_from_data(
            &[3],
            complex_values(&[(1.0, 0.0), (2.0, 5.0), (3.0, -1.0)]),
        ));

        let mut right = Tensor::new_from_map(vec![1], &bond_dims);
        right.set_tensor_data(TensorData::new_from_data(
            &[2],
            complex_values(&[(-4.0, 2.0), (0.0, -1.0)]),
        ));

        let network = Tensor::new_composite(vec![left, right]);
        let contract_path = path![(0, 1)];

        let mut expected = Tensor::new_from_map(vec![1, 0], &bond_dims);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2, 3],
            complex_values(&[
                (-4.0, 2.0),
                (-18.0, -16.0),
                (-10.0, 10.0),
                (0.0, -1.0),
                (5.0, -2.0),
                (-1.0, -3.0),
            ]),
        ));

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &expected);
    }

    #[test]
    fn dimension_order() {
        let mut ket0 = Tensor::new_from_const(vec![0], 2);
        ket0.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![Complex64::ONE, Complex64::ZERO],
        ));

        let mut mat = Tensor::new_from_const(vec![1, 0], 2);
        mat.set_tensor_data(TensorData::new_from_data(
            &[2, 2],
            complex_values(&[(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)]),
        ));

        let network = Tensor::new_composite(vec![ket0, mat]);
        let contract_path = path![(0, 1)];

        let mut expected = Tensor::new_from_const(vec![1], 2);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2],
            complex_values(&[(1.0, 0.0), (3.0, 0.0)]),
        ));

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &expected);
    }
}