tnc/contractionpath/
repartitioning.rs1use itertools::Itertools;
4use rand::Rng;
5use rustc_hash::FxHashMap;
6
7use crate::{
8 contractionpath::{
9 communication_schemes::CommunicationScheme,
10 contraction_cost::{communication_path_op_costs, contract_path_cost},
11 paths::{
12 cotengrust::{Cotengrust, OptMethod},
13 FindPath,
14 },
15 ContractionPath,
16 },
17 tensornetwork::{partitioning::partition_tensor_network, tensor::Tensor},
18};
19
20pub mod genetic;
21pub mod simulated_annealing;
22
23pub fn compute_solution<R>(
26 tensor: &Tensor,
27 partitioning: &[usize],
28 communication_scheme: CommunicationScheme,
29 rng: Option<&mut R>,
30) -> (Tensor, ContractionPath, f64, f64)
31where
32 R: Rng,
33{
34 let partitioned_tn = partition_tensor_network(tensor.clone(), partitioning);
36
37 let mut greedy = Cotengrust::new(&partitioned_tn, OptMethod::Greedy);
39 greedy.find_path();
40 let path = greedy.get_best_replace_path();
41
42 let mut latency_map =
44 FxHashMap::from_iter((0..partitioned_tn.tensors().len()).map(|i| (i, 0.0)));
45 for (i, local_path) in &path.nested {
46 let (local_cost, _) =
47 contract_path_cost(partitioned_tn.tensor(*i).tensors(), local_path, true);
48 latency_map.insert(*i, local_cost);
49 }
50
51 let children_tensors = partitioned_tn
53 .tensors()
54 .iter()
55 .map(Tensor::external_tensor)
56 .collect_vec();
57 let communication_path =
58 communication_scheme.communication_path(&children_tensors, &latency_map, rng);
59 let tensor_costs = (0..children_tensors.len())
60 .map(|i| latency_map[&i])
61 .collect_vec();
62 let ((parallel_cost, sum_cost), _) = communication_path_op_costs(
63 &children_tensors,
64 &communication_path,
65 true,
66 Some(&tensor_costs),
67 );
68
69 let final_path = ContractionPath {
71 nested: path.nested,
72 toplevel: communication_path,
73 };
74
75 (partitioned_tn, final_path, parallel_cost, sum_cost)
76}