tnc/tensornetwork/
contraction.rs

1//! Functionality to contract tensor networks.
2use log::debug;
3use tetra::contract;
4
5use crate::{
6    contractionpath::ContractionPath,
7    tensornetwork::{tensor::Tensor, tensordata::TensorData},
8};
9
10/// Fully contracts `tn` based on the given `contract_path` using ReplaceLeft format.
11/// Returns the resulting tensor.
12///
13/// # Examples
14/// ```
15/// # use tnc::{
16/// #   contractionpath::paths::{branchbound::BranchBound, CostType, FindPath},
17/// #   builders::sycamore_circuit::sycamore_circuit,
18/// #   tensornetwork::tensor::Tensor,
19/// #   tensornetwork::contraction::contract_tensor_network,
20/// # };
21/// # use rand::rngs::StdRng;
22/// # use rand::SeedableRng;
23/// let mut r = StdRng::seed_from_u64(42);
24/// let mut r_tn = sycamore_circuit(2, 1, &mut r);
25/// let mut opt = BranchBound::new(&r_tn, None, 20., CostType::Flops);
26/// opt.find_path();
27/// let opt_path = opt.get_best_replace_path();
28/// let result = contract_tensor_network(r_tn, &opt_path);
29/// ```
30pub fn contract_tensor_network(mut tn: Tensor, contract_path: &ContractionPath) -> Tensor {
31    debug!(len = tn.tensors().len(); "Start contracting tensor network");
32
33    // Contract child composite tensors first
34    for (index, inner_path) in &contract_path.nested {
35        let composite = std::mem::take(&mut tn.tensors[*index]);
36        let contracted = contract_tensor_network(composite, inner_path);
37        tn.tensors[*index] = contracted;
38    }
39
40    // Contract all leaf tensors
41    for (i, j) in &contract_path.toplevel {
42        debug!(i, j; "Contracting tensors");
43        tn.contract_tensors(*i, *j);
44        debug!(i, j; "Finished contracting tensors");
45    }
46    debug!("Completed tensor network contraction");
47
48    tn.tensors
49        .retain(|x| !matches!(x.tensor_data(), TensorData::Uncontracted) || x.is_composite());
50    assert!(tn.tensors().len() <= 1, "Not fully contracted");
51    tn.tensors.pop().unwrap_or(tn)
52}
53
/// Pairwise contraction of two children held inside a composite tensor.
trait TensorContraction {
    /// Contracts the children at positions `tensor_a_loc` and `tensor_b_loc` of `self`.
    ///
    /// Implementations follow the ReplaceLeft convention: the contracted tensor replaces the
    /// child at `tensor_a_loc`, and the slot at `tensor_b_loc` is vacated.
    fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize);
}
58
59impl TensorContraction for Tensor {
60    fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize) {
61        let tensor_a = std::mem::take(&mut self.tensors[tensor_a_loc]);
62        let tensor_b = std::mem::take(&mut self.tensors[tensor_b_loc]);
63
64        let mut tensor_symmetric_difference = &tensor_b ^ &tensor_a;
65
66        let Tensor {
67            legs: a_legs,
68            tensordata: a_data,
69            ..
70        } = tensor_a;
71
72        let Tensor {
73            legs: b_legs,
74            tensordata: b_data,
75            ..
76        } = tensor_b;
77
78        let result = contract(
79            &tensor_symmetric_difference.legs,
80            &a_legs,
81            a_data.into_data(),
82            &b_legs,
83            b_data.into_data(),
84        );
85
86        tensor_symmetric_difference.set_tensor_data(TensorData::Matrix(result));
87        self.tensors[tensor_a_loc] = tensor_symmetric_difference;
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    use float_cmp::assert_approx_eq;
    use num_complex::Complex64;
    use rustc_hash::FxHashMap;
    use serde::Deserialize;

    use crate::{
        path,
        tensornetwork::{contraction::TensorContraction, tensor::Tensor, tensordata::TensorData},
    };

    /// One tensor as stored in the JSON fixture file.
    #[derive(Debug, Deserialize)]
    struct TestTensor {
        legs: Vec<usize>,
        shape: Vec<u64>,
        data: Vec<Complex64>,
    }

    /// Fixture tensors keyed by name (e.g. `"A"`, `"AxB"`).
    type TestData = FxHashMap<String, TestTensor>;

    static TEST_DATA: &str = include_str!("contraction_test_data.json");

    /// Parses the embedded JSON fixture.
    fn load_test_data() -> TestData {
        serde_json::from_str(TEST_DATA).unwrap()
    }

    /// Removes fixture `key` from `fixtures` and builds a `Tensor` with its data attached.
    fn fixture_tensor(fixtures: &mut TestData, key: &str) -> Tensor {
        let TestTensor { legs, shape, data } = fixtures.remove(key).unwrap();
        let mut tensor = Tensor::new(legs, shape);
        let tensor_data = TensorData::new_from_data(&tensor.shape().unwrap(), data, None);
        tensor.set_tensor_data(tensor_data);
        tensor
    }

    #[test]
    fn test_tensor_contraction() {
        let mut fixtures = load_test_data();

        let t1 = fixture_tensor(&mut fixtures, "A"); // shape [3, 2, 7]
        let t2 = fixture_tensor(&mut fixtures, "B"); // shape [7, 8, 6]
        let t3 = fixture_tensor(&mut fixtures, "C"); // shape [3, 5, 8]
        let t12 = fixture_tensor(&mut fixtures, "AxB"); // shape [8, 6, 3, 2]
        let t23 = fixture_tensor(&mut fixtures, "BxC"); // shape [3, 5, 7, 6]

        // ReplaceLeft: the product of children 0 and 1 lands in slot 0.
        let mut tn_12 = Tensor::new_composite(vec![t1.clone(), t2.clone(), t3.clone()]);
        tn_12.contract_tensors(0, 1);
        assert_approx_eq!(&Tensor, tn_12.tensor(0), &t12, epsilon = 1e-14);

        // ... and the product of children 1 and 2 lands in slot 1.
        let mut tn_23 = Tensor::new_composite(vec![t1, t2, t3]);
        tn_23.contract_tensors(1, 2);
        assert_approx_eq!(&Tensor, tn_23.tensor(1), &t23, epsilon = 1e-14);
    }

    #[test]
    fn test_tn_contraction() {
        let mut fixtures = load_test_data();

        let expected = fixture_tensor(&mut fixtures, "ABxC"); // shape [5, 6, 2]
        let tn = Tensor::new_composite(vec![
            fixture_tensor(&mut fixtures, "A"), // shape [3, 2, 7]
            fixture_tensor(&mut fixtures, "B"), // shape [7, 8, 6]
            fixture_tensor(&mut fixtures, "C"), // shape [3, 5, 8]
        ]);
        let contract_path = path![(0, 1), (0, 2)];

        let result = contract_tensor_network(tn, &contract_path);
        assert_approx_eq!(&Tensor, &result, &expected, epsilon = 1e-14);
    }

    #[test]
    fn test_outer_product_contraction() {
        let c = Complex64::new;
        let bond_dims = FxHashMap::from_iter([(0, 3), (1, 2)]);

        let mut left = Tensor::new_from_map(vec![0], &bond_dims);
        left.set_tensor_data(TensorData::new_from_data(
            &[3],
            vec![c(1.0, 0.0), c(2.0, 5.0), c(3.0, -1.0)],
            None,
        ));

        let mut right = Tensor::new_from_map(vec![1], &bond_dims);
        right.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![c(-4.0, 2.0), c(0.0, -1.0)],
            None,
        ));

        // Legs [1, 0]; each value is right[i] * left[j], listed with j varying fastest.
        let mut expected = Tensor::new_from_map(vec![1, 0], &bond_dims);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2, 3],
            vec![
                c(-4.0, 2.0),
                c(-18.0, -16.0),
                c(-10.0, 10.0),
                c(0.0, -1.0),
                c(5.0, -2.0),
                c(-1.0, -3.0),
            ],
            None,
        ));

        let network = Tensor::new_composite(vec![left, right]);
        let contract_path = path![(0, 1)];

        let result = contract_tensor_network(network, &contract_path);
        assert_approx_eq!(&Tensor, &result, &expected);
    }
}