1use log::debug;
3use ndarray::Axis;
4use tblis::{tensor_mult, TensorView};
5
6use crate::{
7 contractionpath::ContractionPath,
8 tensornetwork::{
9 tensor::Tensor,
10 tensordata::{DataTensor, TensorData},
11 },
12};
13
14pub fn contract_tensor_network(mut tn: Tensor, contract_path: &ContractionPath) -> Tensor {
35 debug!(len = tn.tensors().len(); "Start contracting tensor network");
36
37 for (index, inner_path) in &contract_path.nested {
39 let composite = std::mem::take(&mut tn.tensors[*index]);
40 let contracted = contract_tensor_network(composite, inner_path);
41 tn.tensors[*index] = contracted;
42 }
43
44 for (i, j) in &contract_path.toplevel {
46 debug!(i, j; "Contracting tensors");
47 tn.contract_tensors(*i, *j);
48 debug!(i, j; "Finished contracting tensors");
49 }
50 debug!("Completed tensor network contraction");
51
52 tn.tensors
53 .retain(|x| !matches!(x.tensor_data(), TensorData::Uncontracted) || x.is_composite());
54 assert!(tn.tensors().len() <= 1, "Not fully contracted");
55 tn.tensors.pop().unwrap_or(tn)
56}
57
/// Pairwise contraction of two children held inside a composite tensor.
trait TensorContraction {
    /// Contracts the child at `tensor_a_loc` with the child at `tensor_b_loc`,
    /// updating the receiver's child list in place.
    fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize);
}
62
63impl TensorContraction for Tensor {
64 fn contract_tensors(&mut self, tensor_a_loc: usize, tensor_b_loc: usize) {
65 let tensor_a = std::mem::take(&mut self.tensors[tensor_a_loc]);
66 let tensor_b = std::mem::take(&mut self.tensors[tensor_b_loc]);
67
68 let mut tensor_symmetric_difference = &tensor_b ^ &tensor_a;
69
70 let Tensor {
71 legs: a_legs,
72 tensordata: a_data,
73 ..
74 } = tensor_a;
75
76 let Tensor {
77 legs: b_legs,
78 tensordata: b_data,
79 ..
80 } = tensor_b;
81
82 let result = contract_ndarrays(
83 &tensor_symmetric_difference.legs,
84 &a_legs,
85 a_data.into_data(),
86 &b_legs,
87 b_data.into_data(),
88 );
89
90 tensor_symmetric_difference.set_tensor_data(TensorData::Matrix(result));
91 self.tensors[tensor_a_loc] = tensor_symmetric_difference;
92 }
93}
94
95fn contract_ndarrays(
96 out_labels: &[usize],
97 a_labels: &[usize],
98 a_data: DataTensor,
99 b_labels: &[usize],
100 b_data: DataTensor,
101) -> DataTensor {
102 assert_eq!(a_labels.len(), a_data.ndim());
103 assert_eq!(b_labels.len(), b_data.ndim());
104
105 let mut out_shape = Vec::with_capacity(out_labels.len());
107 for label in out_labels {
108 if let Some(a_index) = a_labels.iter().position(|l| l == label) {
109 out_shape.push(a_data.len_of(Axis(a_index)));
110 } else if let Some(b_index) = b_labels.iter().position(|l| l == label) {
111 out_shape.push(b_data.len_of(Axis(b_index)));
112 } else {
113 panic!("Out label {label} not found in input a ({a_labels:?}) or b ({b_labels:?})");
114 }
115 }
116
117 let a_view = TensorView::new(a_labels, a_data.shape(), a_data.strides(), a_data.as_ptr());
119 let b_view = TensorView::new(b_labels, b_data.shape(), b_data.strides(), b_data.as_ptr());
120 let out_data = tensor_mult(out_labels, &out_shape, a_view, b_view);
121
122 DataTensor::from_shape_vec(out_shape, out_data).unwrap()
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use num_complex::Complex64;
    use rustc_hash::FxHashMap;
    use serde::Deserialize;

    use crate::{
        path,
        tensornetwork::{contraction::TensorContraction, tensor::Tensor, tensordata::TensorData},
    };

    /// One tensor fixture as stored in `contraction_test_data.json`.
    #[derive(Debug, Deserialize)]
    struct TestTensor {
        legs: Vec<usize>,
        shape: Vec<u64>,
        data: Vec<Complex64>,
    }

    type TestData = FxHashMap<String, TestTensor>;

    static TEST_DATA: &str = include_str!("contraction_test_data.json");

    fn load_test_data() -> TestData {
        serde_json::from_str(TEST_DATA).unwrap()
    }

    /// Removes fixture `key` from `fixtures` and builds a tensor carrying its data.
    fn take_tensor(fixtures: &mut TestData, key: &str) -> Tensor {
        let fixture = fixtures.remove(key).unwrap();
        let mut tensor = Tensor::new(fixture.legs, fixture.shape);
        tensor.set_tensor_data(TensorData::new_from_data(
            &tensor.shape().unwrap(),
            fixture.data,
        ));
        tensor
    }

    #[test]
    fn test_tensor_contraction() {
        let mut fixtures = load_test_data();
        let t1 = take_tensor(&mut fixtures, "A");
        let t2 = take_tensor(&mut fixtures, "B");
        let t3 = take_tensor(&mut fixtures, "C");
        let expected_12 = take_tensor(&mut fixtures, "AxB");
        let expected_23 = take_tensor(&mut fixtures, "BxC");

        let mut first_pair = Tensor::new_composite(vec![t1.clone(), t2.clone(), t3.clone()]);
        first_pair.contract_tensors(0, 1);
        assert_abs_diff_eq!(first_pair.tensor(0), &expected_12, epsilon = 1e-14);

        let mut second_pair = Tensor::new_composite(vec![t1, t2, t3]);
        second_pair.contract_tensors(1, 2);
        assert_abs_diff_eq!(second_pair.tensor(1), &expected_23, epsilon = 1e-14);
    }

    #[test]
    fn test_tn_contraction() {
        let mut fixtures = load_test_data();
        let t1 = take_tensor(&mut fixtures, "A");
        let t2 = take_tensor(&mut fixtures, "B");
        let t3 = take_tensor(&mut fixtures, "C");
        let expected = take_tensor(&mut fixtures, "ABxC");

        let network = Tensor::new_composite(vec![t1, t2, t3]);
        let contract_path = path![(0, 1), (0, 2)];

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &expected, epsilon = 1e-14);
    }

    #[test]
    fn test_outer_product_contraction() {
        let bond_dims = FxHashMap::from_iter([(0, 3), (1, 2)]);

        let mut left = Tensor::new_from_map(vec![0], &bond_dims);
        left.set_tensor_data(TensorData::new_from_data(
            &[3],
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 5.0),
                Complex64::new(3.0, -1.0),
            ],
        ));

        let mut right = Tensor::new_from_map(vec![1], &bond_dims);
        right.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![Complex64::new(-4.0, 2.0), Complex64::new(0.0, -1.0)],
        ));

        let network = Tensor::new_composite(vec![left, right]);
        let contract_path = path![(0, 1)];

        // No shared legs: the result is the outer product with `right`'s leg first.
        let mut expected = Tensor::new_from_map(vec![1, 0], &bond_dims);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2, 3],
            vec![
                Complex64::new(-4.0, 2.0),
                Complex64::new(-18.0, -16.0),
                Complex64::new(-10.0, 10.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(5.0, -2.0),
                Complex64::new(-1.0, -3.0),
            ],
        ));

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &expected);
    }

    #[test]
    fn dimension_order() {
        let mut ket0 = Tensor::new_from_const(vec![0], 2);
        ket0.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![Complex64::ONE, Complex64::ZERO],
        ));

        let mut mat = Tensor::new_from_const(vec![1, 0], 2);
        mat.set_tensor_data(TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(4.0, 0.0),
            ],
        ));

        let network = Tensor::new_composite(vec![ket0, mat]);
        let contract_path = path![(0, 1)];

        // Contracting leg 0 against |0> selects the first column of `mat`.
        let mut expected = Tensor::new_from_const(vec![1], 2);
        expected.set_tensor_data(TensorData::new_from_data(
            &[2],
            vec![Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ));

        let result = contract_tensor_network(network, &contract_path);
        assert_abs_diff_eq!(result, &expected);
    }
}