1use num_complex::Complex64;
5
6use crate::{
7 contractionpath::{ContractionPath, SimplePathRef},
8 tensornetwork::tensor::Tensor,
9};
10
11pub fn contract_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
27 let final_dims = t_1 ^ t_2;
28 let shared_dims = t_1 & t_2;
29
30 let single_loop_cost = shared_dims.size();
31 (single_loop_cost - 1f64).mul_add(2f64, single_loop_cost * 6f64) * final_dims.size()
32}
33
34#[inline]
50pub fn contract_op_cost_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
51 let all_dims = t_1 | t_2;
52 all_dims.size()
53}
54
55#[inline]
71pub fn contract_size_tensors(t_1: &Tensor, t_2: &Tensor) -> f64 {
72 let diff = t_1 ^ t_2;
73 diff.size() + t_1.size() + t_2.size()
74}
75
76#[inline]
92pub fn contract_size_tensors_bytes(i: &Tensor, j: &Tensor) -> f64 {
93 contract_size_tensors(i, j) * std::mem::size_of::<Complex64>() as f64
94}
95
96#[inline]
104pub fn contract_path_cost(
105 inputs: &[Tensor],
106 contract_path: &ContractionPath,
107 only_count_ops: bool,
108) -> (f64, f64) {
109 let cost_function = if only_count_ops {
110 contract_op_cost_tensors
111 } else {
112 contract_cost_tensors
113 };
114 contract_path_custom_cost(inputs, contract_path, cost_function, contract_size_tensors)
115}
116
117fn contract_path_custom_cost(
125 inputs: &[Tensor],
126 contract_path: &ContractionPath,
127 cost_function: fn(&Tensor, &Tensor) -> f64,
128 size_function: fn(&Tensor, &Tensor) -> f64,
129) -> (f64, f64) {
130 let mut op_cost = 0f64;
131 let mut mem_cost = 0f64;
132 let mut inputs = inputs.to_vec();
133
134 for (i, path) in &contract_path.nested {
135 let costs =
136 contract_path_custom_cost(inputs[*i].tensors(), path, cost_function, size_function);
137 op_cost += costs.0;
138 mem_cost = mem_cost.max(costs.1);
139 inputs[*i] = inputs[*i].external_tensor();
140 }
141
142 for &(i, j) in &contract_path.toplevel {
143 op_cost += cost_function(&inputs[i], &inputs[j]);
144 let ij = &inputs[i] ^ &inputs[j];
145 let new_mem_cost = size_function(&inputs[i], &inputs[j]);
146 mem_cost = mem_cost.max(new_mem_cost);
147 inputs[i] = ij;
148 }
149
150 (op_cost, mem_cost)
151}
152
153#[inline]
156pub fn communication_path_op_costs(
157 inputs: &[Tensor],
158 contract_path: SimplePathRef,
159 only_count_ops: bool,
160 tensor_cost: Option<&[f64]>,
161) -> ((f64, f64), f64) {
162 let (parallel_cost, _) =
163 communication_path_cost(inputs, contract_path, only_count_ops, true, tensor_cost);
164 let (serial_cost, mem_cost) =
165 communication_path_cost(inputs, contract_path, only_count_ops, false, tensor_cost);
166 ((parallel_cost, serial_cost), mem_cost)
167}
168
169pub fn communication_path_cost(
179 inputs: &[Tensor],
180 contract_path: SimplePathRef,
181 only_count_ops: bool,
182 only_critical_path: bool,
183 tensor_cost: Option<&[f64]>,
184) -> (f64, f64) {
185 let cost_function = if only_count_ops {
186 contract_op_cost_tensors
187 } else {
188 contract_cost_tensors
189 };
190 let tensor_cost = if let Some(tensor_cost) = tensor_cost {
191 assert_eq!(inputs.len(), tensor_cost.len());
192 tensor_cost
193 } else {
194 &vec![0f64; inputs.len()]
195 };
196 if inputs.len() == 1 {
197 return (tensor_cost[0], tensor_cost[0]);
198 }
199
200 communication_path_custom_cost(
201 inputs,
202 contract_path,
203 cost_function,
204 only_critical_path,
205 tensor_cost,
206 )
207}
208
209fn communication_path_custom_cost(
218 inputs: &[Tensor],
219 contract_path: SimplePathRef,
220 cost_function: fn(&Tensor, &Tensor) -> f64,
221 only_critical_path: bool,
222 tensor_cost: &[f64],
223) -> (f64, f64) {
224 let mut op_cost = 0f64;
225 let mut mem_cost = 0f64;
226 let mut inputs = inputs.to_vec();
227 let mut tensor_cost = tensor_cost.to_vec();
228
229 for &(i, j) in contract_path {
230 let ij = &inputs[i] ^ &inputs[j];
231 let new_mem_cost = contract_size_tensors(&inputs[i], &inputs[j]);
232 mem_cost = mem_cost.max(new_mem_cost);
233
234 op_cost = if only_critical_path {
235 cost_function(&inputs[i], &inputs[j]) + tensor_cost[i].max(tensor_cost[j])
236 } else {
237 cost_function(&inputs[i], &inputs[j]) + tensor_cost[i] + tensor_cost[j]
238 };
239 tensor_cost[i] = op_cost;
240 inputs[i] = ij;
241 }
242
243 (op_cost, mem_cost)
244}
245
246#[inline]
254pub fn compute_memory_requirements(
255 inputs: &[Tensor],
256 contract_path: &ContractionPath,
257 memory_estimator: fn(&Tensor, &Tensor) -> f64,
258) -> f64 {
259 fn id(_: &Tensor, _: &Tensor) -> f64 {
260 0.0
261 }
262 let (_, mem) = contract_path_custom_cost(inputs, contract_path, id, memory_estimator);
263 mem
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    use rustc_hash::FxHashMap;

    use crate::path;
    use crate::tensornetwork::tensor::Tensor;

    /// Three-tensor network used by the plain contraction path tests.
    fn setup_simple() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ])
    }

    /// Two composite tensors (three and two children) used by the nested
    /// contraction path tests.
    fn setup_complex() -> Tensor {
        let bond_dims = FxHashMap::from_iter([
            (0, 5),
            (1, 2),
            (2, 6),
            (3, 8),
            (4, 1),
            (5, 3),
            (6, 4),
            (7, 3),
            (8, 2),
            (9, 2),
        ]);
        let t1_tensors = vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
        ];
        let t1 = Tensor::new_composite(t1_tensors);

        let t2_tensors = vec![
            Tensor::new_from_map(vec![5, 6, 8], &bond_dims),
            Tensor::new_from_map(vec![7, 8, 9], &bond_dims),
        ];
        let t2 = Tensor::new_composite(t2_tensors);
        Tensor::new_composite(vec![t1, t2])
    }

    /// Four-tensor network used by the communication path tests.
    fn setup_parallel() -> Tensor {
        let bond_dims =
            FxHashMap::from_iter([(0, 5), (1, 2), (2, 6), (3, 8), (4, 1), (5, 3), (6, 4)]);
        Tensor::new_composite(vec![
            Tensor::new_from_map(vec![4, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![0, 1, 3, 2], &bond_dims),
            Tensor::new_from_map(vec![4, 5, 6], &bond_dims),
            Tensor::new_from_map(vec![5, 6], &bond_dims),
        ])
    }

    #[test]
    fn test_contract_path_cost() {
        let tn = setup_simple();
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], false);
        assert_eq!(op_cost, 4540.);
        assert_eq!(mem_cost, 538.);
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], false);
        assert_eq!(op_cost, 49296.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_complex_path_cost() {
        let tn = setup_complex();
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            false,
        );
        assert_eq!(op_cost, 11188.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_contract_path_cost_only_ops() {
        let tn = setup_simple();
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 1), (0, 2)], true);
        assert_eq!(op_cost, 600.);
        assert_eq!(mem_cost, 538.);
        let (op_cost, mem_cost) = contract_path_cost(tn.tensors(), &path![(0, 2), (0, 1)], true);
        assert_eq!(op_cost, 6336.);
        assert_eq!(mem_cost, 1176.);
    }

    #[test]
    fn test_contract_path_complex_cost_only_ops() {
        let tn = setup_complex();
        let (op_cost, mem_cost) = contract_path_cost(
            tn.tensors(),
            &path![{(0, [(0, 1), (0, 2)]), (1, [(0, 1)])}, (0, 1)],
            true,
        );
        assert_eq!(op_cost, 1464.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops() {
        let tn = setup_parallel();
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], true, true, None);
        assert_eq!(op_cost, 490.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost() {
        let tn = setup_parallel();
        // The last step joins the two intermediate results (slots 0 and 2).
        // The path previously ended in `(0, 1)`, which re-contracted the
        // already consumed input 1 and never used the result of `(2, 3)`.
        //
        // (0, 1): (47 * 2 + 48 * 6) * 10 = 3820
        // (2, 3): (11 * 2 + 12 * 6) * 1  = 94
        // (0, 2): (0 * 2 + 1 * 6) * 10 + max(3820, 94) = 3880
        let (op_cost, mem_cost) =
            communication_path_cost(tn.tensors(), &[(0, 1), (2, 3), (0, 2)], false, true, None);
        assert_eq!(op_cost, 3880.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_only_ops_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            true,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(op_cost, 520.);
        assert_eq!(mem_cost, 538.);
    }

    #[test]
    fn test_communication_path_cost_with_partition_cost() {
        let tn = setup_parallel();
        let tensor_cost = vec![20., 30., 80., 10.];
        // Same corrected path as in `test_communication_path_cost`.
        //
        // (0, 1): 3820 + max(20, 30) = 3850
        // (2, 3): 94 + max(80, 10)   = 174
        // (0, 2): 60 + max(3850, 174) = 3910
        let (op_cost, mem_cost) = communication_path_cost(
            tn.tensors(),
            &[(0, 1), (2, 3), (0, 2)],
            false,
            true,
            Some(&tensor_cost),
        );
        assert_eq!(op_cost, 3910.);
        assert_eq!(mem_cost, 538.);
    }
}