1use log::debug;
2use mpi::topology::{Process, SimpleCommunicator};
3use mpi::traits::{BufferMut, Communicator, Destination, Root, Source};
4use mpi::Rank;
5
6use crate::contractionpath::{ContractionPath, SimplePath, SimplePathRef};
7use crate::mpi::mpi_types::{MessageBinaryBlob, RankTensorMapping};
8use crate::mpi::serialization::{deserialize, deserialize_tensor, serialize, serialize_tensor};
9use crate::tensornetwork::contraction::contract_tensor_network;
10use crate::tensornetwork::tensor::Tensor;
11
12fn broadcast_vec<T>(data: &mut Vec<T>, root: &Process)
15where
16 T: Clone + Default,
17 Vec<T>: BufferMut,
18{
19 let mut len = if root.is_self() { data.len() } else { 0 };
21 root.broadcast_into(&mut len);
22
23 if !root.is_self() {
25 data.resize(len, Default::default());
26 }
27 root.broadcast_into(data);
28}
29
30pub fn broadcast_path(path: &mut SimplePath, root: &Process) {
33 let mut data = if root.is_self() {
35 serialize(&path)
36 } else {
37 Default::default()
38 };
39
40 broadcast_vec(&mut data, root);
42
43 if !root.is_self() {
45 *path = deserialize(&data);
46 }
47
48 debug!(path:serde; "Received broadcasted path");
49}
50
51pub fn broadcast_serializing<T>(data: T, root: &Process) -> T
53where
54 T: serde::Serialize + serde::de::DeserializeOwned + Clone,
55{
56 let mut raw_value = if root.is_self() {
57 serialize(&data)
58 } else {
59 Default::default()
60 };
61
62 broadcast_vec(&mut raw_value, root);
63
64 if root.is_self() {
65 data
66 } else {
67 deserialize(&raw_value)
68 }
69}
70
71fn send_tensor(tensor: &Tensor, receiver: Rank, world: &SimpleCommunicator) {
73 let data = serialize_tensor(tensor);
74 world.process_at_rank(receiver).send(&data);
75}
76
77fn receive_tensor(sender: Rank, world: &SimpleCommunicator) -> Tensor {
79 let (data, _status) = world
81 .process_at_rank(sender)
82 .receive_vec::<MessageBinaryBlob>();
83
84 deserialize_tensor(&data)
85}
86
87fn get_tensor_mapping(path: &ContractionPath, size: Rank) -> RankTensorMapping {
90 let mut tensor_mapping = RankTensorMapping::with_capacity(size as usize);
91
92 let Some((final_tensor, _)) = path.toplevel.last() else {
93 tensor_mapping.insert(0, 0);
95 return tensor_mapping;
96 };
97
98 let mut used_ranks = 1;
100 for index in path.nested.keys() {
101 if index == final_tensor {
102 tensor_mapping.insert(0, *index);
104 } else {
105 tensor_mapping.insert(used_ranks, *index);
107 used_ranks += 1;
108 }
109 }
110 assert!(
111 used_ranks <= size,
112 "Not enough MPI ranks available, got {size} but need {used_ranks}!"
113 );
114 tensor_mapping
115}
116
117pub struct Communication {
119 tensor_mapping: RankTensorMapping,
122}
123
124pub fn scatter_tensor_network(
126 r_tn: &Tensor,
127 path: &ContractionPath,
128 rank: Rank,
129 size: Rank,
130 world: &SimpleCommunicator,
131) -> (Tensor, ContractionPath, Communication) {
132 debug!(rank, size; "Scattering tensor network");
133 let root = world.process_at_rank(0);
134
135 let tensor_mapping = if rank == 0 {
137 get_tensor_mapping(path, size)
138 } else {
139 Default::default()
140 };
141
142 let tensor_mapping = broadcast_serializing(tensor_mapping, &root);
144 let is_tensor_owner = tensor_mapping.tensor(rank).is_some();
145 debug!(tensor_mapping:serde, is_tensor_owner; "Scattered organizational data");
146
147 let local_path = if rank == 0 {
149 debug!("Sending local paths");
150 let mut local_path = None;
151 for (i, local) in &path.nested {
152 let target_rank = tensor_mapping.rank(*i);
153 if target_rank == 0 {
154 local_path = Some(local.clone());
156 continue;
157 }
158
159 world.process_at_rank(target_rank).send(&serialize(&local));
160 }
161 local_path.unwrap()
162 } else if is_tensor_owner {
163 debug!("Receiving local path");
164 let (raw_path, _status) = world.process_at_rank(0).receive_vec::<u8>();
165 deserialize(&raw_path)
166 } else {
167 Default::default()
168 };
169
170 let local_tn = if rank == 0 {
172 debug!("Sending tensors");
173 let mut local_tn = None;
174 for &(target_rank, tensor_index) in &tensor_mapping {
175 let tensor = r_tn.tensor(tensor_index);
176 if target_rank == 0 {
177 local_tn = Some(tensor.clone());
179 continue;
180 }
181
182 send_tensor(tensor, target_rank, world);
183 }
184 local_tn.unwrap()
185 } else if is_tensor_owner {
186 debug!("Receiving tensor");
187 receive_tensor(0, world)
188 } else {
189 Default::default()
190 };
191 debug!("Scattered tensor network");
192
193 (local_tn, local_path, Communication { tensor_mapping })
195}
196
197pub fn intermediate_reduce_tensor_network(
200 local_tn: &mut Tensor,
201 path: SimplePathRef,
202 rank: Rank,
203 world: &SimpleCommunicator,
204 communication: &Communication,
205) {
206 debug!(rank, path:serde; "Reducing tensor network (intermediate)");
207 assert!(local_tn.is_leaf());
208
209 let mut final_rank = 0;
210 for (x, y) in path {
211 let receiver = communication.tensor_mapping.rank(*x);
212 let sender = communication.tensor_mapping.rank(*y);
213 final_rank = receiver;
214 if receiver == rank {
215 debug!(sender; "Start receiving tensor");
217 let received_tensor = receive_tensor(sender, world);
218 debug!(sender; "Finish receiving tensor");
219
220 let tensor_network =
222 Tensor::new_composite(vec![std::mem::take(local_tn), received_tensor]);
223
224 let result = contract_tensor_network(tensor_network, &ContractionPath::single(0, 1));
226 *local_tn = result;
227 }
228 if sender == rank {
229 debug!(receiver; "Start sending tensor");
230 send_tensor(local_tn, receiver, world);
231 debug!(receiver; "Finish sending tensor");
232 }
233 }
234
235 if final_rank != 0 {
237 debug!(rank, final_rank; "Final rank is not 0");
238 if rank == 0 {
239 debug!(sender = final_rank; "Receiving final tensor");
240 let received_tensor = receive_tensor(final_rank, world);
241 *local_tn = received_tensor;
242 }
243 if rank == final_rank {
244 debug!(receiver = 0; "Sending final tensor");
245 send_tensor(local_tn, 0, world);
246 }
247 }
248 debug!("Reduced tensor network");
249}
250
251#[cfg(test)]
252mod tests {
253 use super::*;
254
255 use crate::path;
256
257 #[test]
258 fn test_tensor_mapping() {
259 let path = path![
260 {
261 (0, [(0, 2), (0, 1)]),
262 (1, [(0, 1), (0, 1)]),
263 (2, [(0, 1)])
264 },
265 (0, 2),
266 (0, 1)
267 ];
268
269 let tensor_mapping = get_tensor_mapping(&path, 4);
270
271 assert_eq!(tensor_mapping.len(), 3);
272 assert_eq!(tensor_mapping.rank(0), 0);
273 assert_eq!(tensor_mapping.rank(1), 2);
274 assert_eq!(tensor_mapping.rank(2), 1);
275 assert_eq!(tensor_mapping.tensor(0), Some(0));
276 assert_eq!(tensor_mapping.tensor(1), Some(2));
277 assert_eq!(tensor_mapping.tensor(2), Some(1));
278 assert_eq!(tensor_mapping.tensor(3), None);
279 }
280}