tnc/mpi/communication.rs

use log::debug;
use mpi::topology::{Process, SimpleCommunicator};
use mpi::traits::{BufferMut, Communicator, Destination, Root, Source};
use mpi::Rank;

use crate::contractionpath::{ContractionPath, SimplePath, SimplePathRef};
use crate::mpi::mpi_types::{MessageBinaryBlob, RankTensorMapping};
use crate::mpi::serialization::{deserialize, deserialize_tensor, serialize, serialize_tensor};
use crate::tensornetwork::contraction::contract_tensor_network;
use crate::tensornetwork::tensor::Tensor;

/// Broadcasts a vector of `data` from `root` to all processes in the communicator. For
/// the receivers, `data` can just be an empty vector.
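///
/// A minimal usage sketch, marked `ignore` because it needs a running MPI
/// environment (the ranks and values here are only illustrative):
///
/// ```ignore
/// let universe = mpi::initialize().unwrap();
/// let world = universe.world();
/// let root = world.process_at_rank(0);
///
/// // Only the root provides the actual contents; the other ranks pass an
/// // empty vector that is resized and filled by the broadcast.
/// let mut values: Vec<u64> = if root.is_self() { vec![1, 2, 3] } else { Vec::new() };
/// broadcast_vec(&mut values, &root);
/// // Every rank now holds [1, 2, 3].
/// ```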
fn broadcast_vec<T>(data: &mut Vec<T>, root: &Process)
where
    T: Clone + Default,
    Vec<T>: BufferMut,
{
    // Broadcast length
    let mut len = if root.is_self() { data.len() } else { 0 };
    root.broadcast_into(&mut len);

    // Broadcast data
    if !root.is_self() {
        data.resize(len, Default::default());
    }
    root.broadcast_into(data);
}

/// Broadcasts a contraction `path` from `root` to all processes in the communicator. For
/// the receivers, `path` can just be left empty.
pub fn broadcast_path(path: &mut SimplePath, root: &Process) {
    // Serialize path
    let mut data = if root.is_self() {
        serialize(&path)
    } else {
        Default::default()
    };

    // Broadcast data
    broadcast_vec(&mut data, root);

    // Deserialize path
    if !root.is_self() {
        *path = deserialize(&data);
    }

    debug!(path:serde; "Received broadcasted path");
}

/// Broadcasts a value by serializing it and sending it as a byte array.
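///
/// A usage sketch (illustrative only; assumes an initialized MPI universe and
/// a value type that implements the required `serde` traits):
///
/// ```ignore
/// let universe = mpi::initialize().unwrap();
/// let world = universe.world();
/// let root = world.process_at_rank(0);
///
/// // Only the value on the root rank matters; the other ranks pass a
/// // placeholder that is replaced by the deserialized broadcast result.
/// let config: Vec<(usize, usize)> =
///     if root.is_self() { vec![(0, 1), (1, 2)] } else { Vec::new() };
/// let config = broadcast_serializing(config, &root);
/// ```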
pub fn broadcast_serializing<T>(data: T, root: &Process) -> T
where
    T: serde::Serialize + serde::de::DeserializeOwned + Clone,
{
    let mut raw_value = if root.is_self() {
        serialize(&data)
    } else {
        Default::default()
    };

    broadcast_vec(&mut raw_value, root);

    if root.is_self() {
        data
    } else {
        deserialize(&raw_value)
    }
}

/// Sends the `tensor` to `receiver` via MPI.
fn send_tensor(tensor: &Tensor, receiver: Rank, world: &SimpleCommunicator) {
    let data = serialize_tensor(tensor);
    world.process_at_rank(receiver).send(&data);
}

/// Receives a tensor from `sender` via MPI.
fn receive_tensor(sender: Rank, world: &SimpleCommunicator) -> Tensor {
    // Receive the buffer
    let (data, _status) = world
        .process_at_rank(sender)
        .receive_vec::<MessageBinaryBlob>();

    deserialize_tensor(&data)
}

/// Determines the rank-to-tensor mapping for the given contraction `path`.
/// Rank 0 is reserved for the final tensor; the remaining tensors are assigned
/// to ranks 1, 2, ... in iteration order.
fn get_tensor_mapping(path: &ContractionPath, size: Rank) -> RankTensorMapping {
    let mut tensor_mapping = RankTensorMapping::with_capacity(size as usize);

    let Some((final_tensor, _)) = path.toplevel.last() else {
        // Empty path
        return tensor_mapping;
    };

    // Reserve rank 0 for the final tensor
    let mut used_ranks = 1;
    for index in path.nested.keys() {
        if index == final_tensor {
            // Assign the final tensor to rank 0
            tensor_mapping.insert(0, *index);
        } else {
            // Assign the next available rank to tensor `index`
            tensor_mapping.insert(used_ranks, *index);
            used_ranks += 1;
        }
    }
    assert!(
        used_ranks <= size,
        "Not enough MPI ranks available, got {size} but need {used_ranks}!"
    );
    tensor_mapping
}

/// Information needed for communication during contraction of the tensor network.
pub struct Communication {
    /// A mapping between MPI ranks and their owned composite tensors. In slice
    /// groups, only the slice root rank is assigned the tensor.
    tensor_mapping: RankTensorMapping,
}

/// Distributes the partitioned tensor network to the various processes via MPI.
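///
/// A sketch of the intended SPMD call pattern, where every rank executes the
/// same code. `build_tensor_network` and `find_contraction_path` are
/// hypothetical, application-provided helpers, not part of this module:
///
/// ```ignore
/// let universe = mpi::initialize().unwrap();
/// let world = universe.world();
///
/// // For this sketch, the full tensor network and its partitioned contraction
/// // path are simply constructed on every rank; only rank 0 reads them when
/// // scattering.
/// let r_tn = build_tensor_network();
/// let path = find_contraction_path(&r_tn);
///
/// let (local_tn, local_path, communication) =
///     scatter_tensor_network(&r_tn, &path, world.rank(), world.size(), &world);
/// ```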
pub fn scatter_tensor_network(
    r_tn: &Tensor,
    path: &ContractionPath,
    rank: Rank,
    size: Rank,
    world: &SimpleCommunicator,
) -> (Tensor, ContractionPath, Communication) {
    debug!(rank, size; "Scattering tensor network");
    let root = world.process_at_rank(0);

    // Get information about used ranks
    let tensor_mapping = if rank == 0 {
        get_tensor_mapping(path, size)
    } else {
        Default::default()
    };

    // Tell each rank which tensor it is responsible for (if any)
    let tensor_mapping = broadcast_serializing(tensor_mapping, &root);
    let is_tensor_owner = tensor_mapping.tensor(rank).is_some();
    debug!(tensor_mapping:serde, is_tensor_owner; "Scattered organizational data");

    // Send the local paths
    let local_path = if rank == 0 {
        debug!("Sending local paths");
        let mut local_path = None;
        for (i, local) in &path.nested {
            let target_rank = tensor_mapping.rank(*i);
            if target_rank == 0 {
                // This is the path for the root, no need to send it
                local_path = Some(local.clone());
                continue;
            }

            world.process_at_rank(target_rank).send(&serialize(&local));
        }
        local_path.unwrap()
    } else if is_tensor_owner {
        debug!("Receiving local path");
        let (raw_path, _status) = world.process_at_rank(0).receive_vec::<u8>();
        deserialize(&raw_path)
    } else {
        Default::default()
    };

    // Send the tensors
    let local_tn = if rank == 0 {
        debug!("Sending tensors");
        let mut local_tn = None;
        for &(target_rank, tensor_index) in &tensor_mapping {
            let tensor = r_tn.tensor(tensor_index);
            if target_rank == 0 {
                // This is the tensor for the root, no need to send it
                local_tn = Some(tensor.clone());
                continue;
            }

            send_tensor(tensor, target_rank, world);
        }
        local_tn.unwrap()
    } else if is_tensor_owner {
        debug!("Receiving tensor");
        receive_tensor(0, world)
    } else {
        Default::default()
    };
    debug!("Scattered tensor network");

    // Return the local tensor, path, and communication information
    (local_tn, local_path, Communication { tensor_mapping })
}

/// Uses the `path` as a communication blueprint to iteratively send tensors and contract them in a fan-in pattern.
/// Assumes that `path` is a valid contraction path.
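///
/// A sketch of the intended call sequence, continuing the example on
/// [`scatter_tensor_network`]. The local tensors must already be contracted
/// down to a leaf, as asserted at the top of this function, and
/// `&path.toplevel` is assumed here to be usable as a [`SimplePathRef`]:
///
/// ```ignore
/// // Contract the local part of the network, then fan the partial results in
/// // towards rank 0, which ends up with the fully contracted tensor.
/// let mut local_tn = contract_tensor_network(local_tn, &local_path);
/// intermediate_reduce_tensor_network(
///     &mut local_tn,
///     &path.toplevel,
///     world.rank(),
///     &world,
///     &communication,
/// );
/// ```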
pub fn intermediate_reduce_tensor_network(
    local_tn: &mut Tensor,
    path: SimplePathRef,
    rank: Rank,
    world: &SimpleCommunicator,
    communication: &Communication,
) {
    debug!(rank, path:serde; "Reducing tensor network (intermediate)");
    assert!(local_tn.is_leaf());

    let mut final_rank = 0;
    for (x, y) in path {
        let receiver = communication.tensor_mapping.rank(*x);
        let sender = communication.tensor_mapping.rank(*y);
        final_rank = receiver;
        if receiver == rank {
            // Receive tensor
            debug!(sender; "Start receiving tensor");
            let received_tensor = receive_tensor(sender, world);
            debug!(sender; "Finish receiving tensor");

            // Add local tensor and received tensor into a new tensor network
            let tensor_network =
                Tensor::new_composite(vec![std::mem::take(local_tn), received_tensor]);

            // Contract tensors
            let result = contract_tensor_network(tensor_network, &ContractionPath::single(0, 1));
            *local_tn = result;
        }
        if sender == rank {
            debug!(receiver; "Start sending tensor");
            send_tensor(local_tn, receiver, world);
            debug!(receiver; "Finish sending tensor");
        }
    }

    // Only needed if the rank holding the final contracted tensor is not rank 0
    if final_rank != 0 {
        debug!(rank, final_rank; "Final rank is not 0");
        if rank == 0 {
            debug!(sender = final_rank; "Receiving final tensor");
            let received_tensor = receive_tensor(final_rank, world);
            *local_tn = received_tensor;
        }
        if rank == final_rank {
            debug!(receiver = 0; "Sending final tensor");
            send_tensor(local_tn, 0, world);
        }
    }
    debug!("Reduced tensor network");
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    #[test]
    fn test_tensor_mapping() {
        let path = path![
            {
            (0, [(0, 2), (0, 1)]),
            (1, [(0, 1), (0, 1)]),
            (2, [(0, 1)])
            },
            (0, 2),
            (0, 1)
        ];

        let tensor_mapping = get_tensor_mapping(&path, 4);

        assert_eq!(tensor_mapping.len(), 3);
        assert_eq!(tensor_mapping.rank(0), 0);
        assert_eq!(tensor_mapping.rank(1), 2);
        assert_eq!(tensor_mapping.rank(2), 1);
        assert_eq!(tensor_mapping.tensor(0), Some(0));
        assert_eq!(tensor_mapping.tensor(1), Some(2));
        assert_eq!(tensor_mapping.tensor(2), Some(1));
        assert_eq!(tensor_mapping.tensor(3), None);
    }
}