1use log::debug;
3use tetra::contract;
4
5use crate::{
6 contractionpath::ContractionPath,
7 tensornetwork::{tensor::Tensor, tensordata::TensorData},
8};
9
10pub fn contract_tensor_network(mut tn: Tensor, contract_path: &ContractionPath) -> Tensor {
31 debug!(len = tn.tensors().len(); "Start contracting tensor network");
32
33 for (index, inner_path) in &contract_path.nested {
35 let composite = std::mem::take(&mut tn.tensors[*index]);
36 let contracted = contract_tensor_network(composite, inner_path);
37 tn.tensors[*index] = contracted;
38 }
39
40 for (i, j) in &contract_path.toplevel {
42 debug!(i, j; "Contracting tensors");
43 tn.contract_tensors(*i, *j);
44 debug!(i, j; "Finished contracting tensors");
45 }
46 debug!("Completed tensor network contraction");
47
48 tn.tensors
49 .retain(|x| !matches!(x.tensor_data(), TensorData::Uncontracted) || x.is_composite());
50 assert!(tn.tensors().len() <= 1, "Not fully contracted");
51 tn.tensors.pop().unwrap_or(tn)
52}
53
54trait TensorContraction {
55 fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize);
57}
58
59impl TensorContraction for Tensor {
60 fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize) {
61 let tensor_a = std::mem::take(&mut self.tensors[tensor_a_loc]);
62 let tensor_b = std::mem::take(&mut self.tensors[tensor_b_loc]);
63
64 let mut tensor_symmetric_difference = &tensor_b ^ &tensor_a;
65
66 let Tensor {
67 legs: a_legs,
68 tensordata: a_data,
69 ..
70 } = tensor_a;
71
72 let Tensor {
73 legs: b_legs,
74 tensordata: b_data,
75 ..
76 } = tensor_b;
77
78 let result = contract(
79 &tensor_symmetric_difference.legs,
80 &a_legs,
81 a_data.into_data(),
82 &b_legs,
83 b_data.into_data(),
84 );
85
86 tensor_symmetric_difference.set_tensor_data(TensorData::Matrix(result));
87 self.tensors[tensor_a_loc] = tensor_symmetric_difference;
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    use float_cmp::assert_approx_eq;
    use num_complex::Complex64;
    use rustc_hash::FxHashMap;
    use serde::Deserialize;

    use crate::{
        path,
        tensornetwork::{contraction::TensorContraction, tensor::Tensor, tensordata::TensorData},
    };

    /// One tensor record from the JSON fixture: leg ids, dimensions and flattened data.
    #[derive(Debug, Deserialize)]
    struct TestTensor {
        legs: Vec<usize>,
        shape: Vec<u64>,
        data: Vec<Complex64>,
    }

    /// Fixture tensors keyed by name (e.g. `"A"`, `"AxB"`).
    type TestData = FxHashMap<String, TestTensor>;

    static TEST_DATA: &str = include_str!("contraction_test_data.json");

    /// Parses the embedded JSON fixture.
    fn load_test_data() -> TestData {
        serde_json::from_str(TEST_DATA).unwrap()
    }

    /// Removes fixture entry `name` and builds a `Tensor` carrying its data.
    fn take_tensor(data: &mut TestData, name: &str) -> Tensor {
        let TestTensor {
            legs,
            shape,
            data: values,
        } = data.remove(name).unwrap();
        let mut tensor = Tensor::new(legs, shape);
        tensor.set_tensor_data(TensorData::new_from_data(
            &tensor.shape().unwrap(),
            values,
            None,
        ));
        tensor
    }

    #[test]
    fn test_tensor_contraction() {
        let mut data = load_test_data();
        let t1 = take_tensor(&mut data, "A");
        let t2 = take_tensor(&mut data, "B");
        let t3 = take_tensor(&mut data, "C");
        let t12 = take_tensor(&mut data, "AxB");
        let t23 = take_tensor(&mut data, "BxC");

        let mut tn_12 = Tensor::new_composite(vec![t1.clone(), t2.clone(), t3.clone()]);
        tn_12.contract_tensors(0, 1);
        assert_approx_eq!(&Tensor, tn_12.tensor(0), &t12, epsilon = 1e-14);

        let mut tn_23 = Tensor::new_composite(vec![t1, t2, t3]);
        tn_23.contract_tensors(1, 2);
        assert_approx_eq!(&Tensor, tn_23.tensor(1), &t23, epsilon = 1e-14);
    }

    #[test]
    fn test_tn_contraction() {
        let mut data = load_test_data();
        let t1 = take_tensor(&mut data, "A");
        let t2 = take_tensor(&mut data, "B");
        let t3 = take_tensor(&mut data, "C");
        let expected = take_tensor(&mut data, "ABxC");

        let network = Tensor::new_composite(vec![t1, t2, t3]);
        let contract_path = path![(0, 1), (0, 2)];

        let result = contract_tensor_network(network, &contract_path);
        assert_approx_eq!(&Tensor, &result, &expected, epsilon = 1e-14);
    }

    #[test]
    fn test_outer_product_contraction() {
        let bond_dims = FxHashMap::from_iter([(0, 3), (1, 2)]);

        let mut left = Tensor::new_from_map(vec![0], &bond_dims);
        left.set_tensor_data(TensorData::new_from_data(
            &[3],
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 5.0),
                Complex64::new(3.0, -1.0),
            ],
            None,
        ));

        let mut right = Tensor::new_from_map(vec![1], &bond_dims);
        right.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![Complex64::new(-4.0, 2.0), Complex64::new(0.0, -1.0)],
            None,
        ));

        let mut expected = Tensor::new_from_map(vec![1, 0], &bond_dims);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2, 3],
            vec![
                Complex64::new(-4.0, 2.0),
                Complex64::new(-18.0, -16.0),
                Complex64::new(-10.0, 10.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(5.0, -2.0),
                Complex64::new(-1.0, -3.0),
            ],
            None,
        ));

        let network = Tensor::new_composite(vec![left, right]);
        let contract_path = path![(0, 1)];

        let result = contract_tensor_network(network, &contract_path);
        assert_approx_eq!(&Tensor, &result, &expected);
    }
}