tnc/contractionpath/
repartitioning.rs

1//! Methods for improving the partitioning of a tensor network to improve time-to-solution.
2
3use itertools::Itertools;
4use rand::Rng;
5use rustc_hash::FxHashMap;
6
7use crate::{
8    contractionpath::{
9        communication_schemes::CommunicationScheme,
10        contraction_cost::{communication_path_op_costs, contract_path_cost},
11        paths::{
12            cotengrust::{Cotengrust, OptMethod},
13            FindPath,
14        },
15        ContractionPath,
16    },
17    tensornetwork::{partitioning::partition_tensor_network, tensor::Tensor},
18};
19
20pub mod genetic;
21pub mod simulated_annealing;
22
23/// Given a `tensor` and a `partitioning` for it, this constructs the partitioned
24/// tensor and finds a contraction path for it.
25pub fn compute_solution<R>(
26    tensor: &Tensor,
27    partitioning: &[usize],
28    communication_scheme: CommunicationScheme,
29    rng: Option<&mut R>,
30) -> (Tensor, ContractionPath, f64, f64)
31where
32    R: Rng,
33{
34    // Partition the tensor network with the proposed solution
35    let partitioned_tn = partition_tensor_network(tensor.clone(), partitioning);
36
37    // Find contraction path
38    let mut greedy = Cotengrust::new(&partitioned_tn, OptMethod::Greedy);
39    greedy.find_path();
40    let path = greedy.get_best_replace_path();
41
42    // Store the local paths (and costs)
43    let mut latency_map =
44        FxHashMap::from_iter((0..partitioned_tn.tensors().len()).map(|i| (i, 0.0)));
45    for (i, local_path) in &path.nested {
46        let (local_cost, _) =
47            contract_path_cost(partitioned_tn.tensor(*i).tensors(), local_path, true);
48        latency_map.insert(*i, local_cost);
49    }
50
51    // Find communication path separately
52    let children_tensors = partitioned_tn
53        .tensors()
54        .iter()
55        .map(Tensor::external_tensor)
56        .collect_vec();
57    let communication_path =
58        communication_scheme.communication_path(&children_tensors, &latency_map, rng);
59    let tensor_costs = (0..children_tensors.len())
60        .map(|i| latency_map[&i])
61        .collect_vec();
62    let ((parallel_cost, sum_cost), _) = communication_path_op_costs(
63        &children_tensors,
64        &communication_path,
65        true,
66        Some(&tensor_costs),
67    );
68
69    // Add the communication path to the local paths
70    let final_path = ContractionPath {
71        nested: path.nested,
72        toplevel: communication_path,
73    };
74
75    (partitioned_tn, final_path, parallel_cost, sum_cost)
76}