use log::debug;
use mpi::topology::{Process, SimpleCommunicator};
use mpi::traits::{BufferMut, Communicator, Destination, Root, Source};
use mpi::Rank;

use crate::contractionpath::{ContractionPath, SimplePath, SimplePathRef};
use crate::mpi::mpi_types::{MessageBinaryBlob, RankTensorMapping};
use crate::mpi::serialization::{deserialize, deserialize_tensor, serialize, serialize_tensor};
use crate::tensornetwork::contraction::contract_tensor_network;
use crate::tensornetwork::tensor::Tensor;

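/// Broadcasts a `Vec` from `root` to all ranks in two steps: first the length, so that
/// non-root ranks can resize their buffer, then the payload itself.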
fn broadcast_vec<T>(data: &mut Vec<T>, root: &Process)
where
    T: Clone + Default,
    Vec<T>: BufferMut,
{
    let mut len = if root.is_self() { data.len() } else { 0 };
    root.broadcast_into(&mut len);

    if !root.is_self() {
        data.resize(len, Default::default());
    }
    root.broadcast_into(data);
}

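/// Broadcasts a [`SimplePath`] from `root` to all ranks by serializing it into a buffer,
/// broadcasting that buffer with [`broadcast_vec`], and deserializing it on every
/// non-root rank.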
pub fn broadcast_path(path: &mut SimplePath, root: &Process) {
    let mut data = if root.is_self() {
        serialize(&path)
    } else {
        Default::default()
    };

    broadcast_vec(&mut data, root);

    if !root.is_self() {
        *path = deserialize(&data);
    }

    debug!(path:serde; "Received broadcasted path");
}

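/// Broadcasts any serde-serializable value from `root` to all ranks. The root returns
/// its own value unchanged; every other rank deserializes the broadcast buffer into a
/// fresh value.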
pub fn broadcast_serializing<T>(data: T, root: &Process) -> T
where
    T: serde::Serialize + serde::de::DeserializeOwned + Clone,
{
    let mut raw_value = if root.is_self() {
        serialize(&data)
    } else {
        Default::default()
    };

    broadcast_vec(&mut raw_value, root);

    if root.is_self() {
        data
    } else {
        deserialize(&raw_value)
    }
}

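/// Serializes a tensor and sends it to the rank `receiver` as a point-to-point message.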
fn send_tensor(tensor: &Tensor, receiver: Rank, world: &SimpleCommunicator) {
    let data = serialize_tensor(tensor);
    world.process_at_rank(receiver).send(&data);
}

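/// Receives a serialized tensor from the rank `sender` and deserializes it into a [`Tensor`].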
fn receive_tensor(sender: Rank, world: &SimpleCommunicator) -> Tensor {
    let (data, _status) = world
        .process_at_rank(sender)
        .receive_vec::<MessageBinaryBlob>();

    deserialize_tensor(&data)
}

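/// Assigns one top-level tensor to each MPI rank. The tensor that accumulates the final
/// result (the first entry of the last top-level contraction step) is pinned to rank 0,
/// so no extra transfer is needed at the end; the remaining tensors are assigned to
/// ranks 1, 2, ... in iteration order. Panics if fewer ranks are available than tensors.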
fn get_tensor_mapping(path: &ContractionPath, size: Rank) -> RankTensorMapping {
    let mut tensor_mapping = RankTensorMapping::with_capacity(size as usize);

    let Some((final_tensor, _)) = path.toplevel.last() else {
        return tensor_mapping;
    };

    let mut used_ranks = 1;
    for index in path.nested.keys() {
        if index == final_tensor {
            tensor_mapping.insert(0, *index);
        } else {
            tensor_mapping.insert(used_ranks, *index);
            used_ranks += 1;
        }
    }
    assert!(
        used_ranks <= size,
        "Not enough MPI ranks available, got {size} but need {used_ranks}!"
    );
    tensor_mapping
}

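/// Bookkeeping shared between the scatter and reduce phases: which rank owns which
/// top-level tensor of the contraction.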
pub struct Communication {
    tensor_mapping: RankTensorMapping,
}

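/// Distributes the tensor network and the per-tensor contraction paths from rank 0 to
/// the ranks chosen by [`get_tensor_mapping`]. Every rank must take part in this call;
/// only rank 0 needs the complete tensor network `r_tn` and full `path`. Returns the
/// local tensor, its local contraction path, and the [`Communication`] bookkeeping
/// consumed later by [`intermediate_reduce_tensor_network`].
///
/// A rough, illustrative call pattern (not compiled as a doctest): it assumes `r_tn`
/// and `path` have already been built on rank 0 and uses the `mpi` crate's standard
/// initialization; the variable names are hypothetical.
///
/// ```ignore
/// let universe = mpi::initialize().unwrap();
/// let world = universe.world();
/// let (mut local_tn, local_path, comm) =
///     scatter_tensor_network(&r_tn, &path, world.rank(), world.size(), &world);
/// // `local_tn` is then contracted locally along `local_path` before calling
/// // `intermediate_reduce_tensor_network` with `comm`.
/// ```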
pub fn scatter_tensor_network(
    r_tn: &Tensor,
    path: &ContractionPath,
    rank: Rank,
    size: Rank,
    world: &SimpleCommunicator,
) -> (Tensor, ContractionPath, Communication) {
    debug!(rank, size; "Scattering tensor network");
    let root = world.process_at_rank(0);

    let tensor_mapping = if rank == 0 {
        get_tensor_mapping(path, size)
    } else {
        Default::default()
    };

    let tensor_mapping = broadcast_serializing(tensor_mapping, &root);
    let is_tensor_owner = tensor_mapping.tensor(rank).is_some();
    debug!(tensor_mapping:serde, is_tensor_owner; "Scattered organizational data");

    let local_path = if rank == 0 {
        debug!("Sending local paths");
        let mut local_path = None;
        for (i, local) in &path.nested {
            let target_rank = tensor_mapping.rank(*i);
            if target_rank == 0 {
                local_path = Some(local.clone());
                continue;
            }

            world.process_at_rank(target_rank).send(&serialize(&local));
        }
        local_path.unwrap()
    } else if is_tensor_owner {
        debug!("Receiving local path");
        let (raw_path, _status) = world.process_at_rank(0).receive_vec::<u8>();
        deserialize(&raw_path)
    } else {
        Default::default()
    };

    let local_tn = if rank == 0 {
        debug!("Sending tensors");
        let mut local_tn = None;
        for &(target_rank, tensor_index) in &tensor_mapping {
            let tensor = r_tn.tensor(tensor_index);
            if target_rank == 0 {
                local_tn = Some(tensor.clone());
                continue;
            }

            send_tensor(tensor, target_rank, world);
        }
        local_tn.unwrap()
    } else if is_tensor_owner {
        debug!("Receiving tensor");
        receive_tensor(0, world)
    } else {
        Default::default()
    };
    debug!("Scattered tensor network");

    (local_tn, local_path, Communication { tensor_mapping })
}

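/// Performs the top-level (cross-rank) contraction. For every step `(x, y)` in `path`,
/// the rank owning tensor `y` sends its local result to the rank owning tensor `x`,
/// which contracts the two tensors. If the last step does not leave the result on
/// rank 0, it is forwarded there, so rank 0 always ends up with the final tensor.
/// Expects `local_tn` to already be fully contracted locally (a leaf tensor).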
pub fn intermediate_reduce_tensor_network(
    local_tn: &mut Tensor,
    path: SimplePathRef,
    rank: Rank,
    world: &SimpleCommunicator,
    communication: &Communication,
) {
    debug!(rank, path:serde; "Reducing tensor network (intermediate)");
    assert!(local_tn.is_leaf());

    let mut final_rank = 0;
    for (x, y) in path {
        let receiver = communication.tensor_mapping.rank(*x);
        let sender = communication.tensor_mapping.rank(*y);
        final_rank = receiver;
        if receiver == rank {
            debug!(sender; "Start receiving tensor");
            let received_tensor = receive_tensor(sender, world);
            debug!(sender; "Finish receiving tensor");

            let tensor_network =
                Tensor::new_composite(vec![std::mem::take(local_tn), received_tensor]);

            let result = contract_tensor_network(tensor_network, &ContractionPath::single(0, 1));
            *local_tn = result;
        }
        if sender == rank {
            debug!(receiver; "Start sending tensor");
            send_tensor(local_tn, receiver, world);
            debug!(receiver; "Finish sending tensor");
        }
    }

    if final_rank != 0 {
        debug!(rank, final_rank; "Final rank is not 0");
        if rank == 0 {
            debug!(sender = final_rank; "Receiving final tensor");
            let received_tensor = receive_tensor(final_rank, world);
            *local_tn = received_tensor;
        }
        if rank == final_rank {
            debug!(receiver = 0; "Sending final tensor");
            send_tensor(local_tn, 0, world);
        }
    }
    debug!("Reduced tensor network");
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::path;

    #[test]
    fn test_tensor_mapping() {
        let path = path![
            {
                (0, [(0, 2), (0, 1)]),
                (1, [(0, 1), (0, 1)]),
                (2, [(0, 1)])
            },
            (0, 2),
            (0, 1)
        ];

        let tensor_mapping = get_tensor_mapping(&path, 4);

        assert_eq!(tensor_mapping.len(), 3);
        assert_eq!(tensor_mapping.rank(0), 0);
        assert_eq!(tensor_mapping.rank(1), 2);
        assert_eq!(tensor_mapping.rank(2), 1);
        assert_eq!(tensor_mapping.tensor(0), Some(0));
        assert_eq!(tensor_mapping.tensor(1), Some(2));
        assert_eq!(tensor_mapping.tensor(2), Some(1));
        assert_eq!(tensor_mapping.tensor(3), None);
    }
}