1use std::{
4 iter::zip,
5 time::{Duration, Instant},
6};
7
8use itertools::Itertools;
9use ordered_float::NotNan;
10use rand::{rngs::StdRng, Rng, SeedableRng};
11use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
12use rustc_hash::FxHashSet;
13
14use crate::{
15 contractionpath::{
16 communication_schemes::CommunicationScheme,
17 contraction_cost::{compute_memory_requirements, contract_size_tensors_exact},
18 paths::{
19 cotengrust::{Cotengrust, OptMethod},
20 FindPath,
21 },
22 repartitioning::compute_solution,
23 SimplePath,
24 },
25 tensornetwork::tensor::Tensor,
26};
27
/// Score used to rank candidate solutions; `NotNan` so scores have a total
/// order (lower is better — see `min_by_key` in the optimizer loop).
type ScoreType = NotNan<f64>;

/// Number of independent annealing trial chains run in parallel per step.
const PROCESSING_THREADS: usize = 48;
/// Model interface for the simulated-annealing optimizer: how to perturb a
/// candidate solution and how to score it (lower scores are better).
pub trait OptModel<'a>: Sync + Send {
    /// Candidate solution representation; cloned freely across trial threads.
    type SolutionType: Clone + Sync + Send;

    /// Produces a randomly perturbed neighbor of `current_solution`.
    fn generate_trial_solution<R: Rng>(
        &self,
        current_solution: Self::SolutionType,
        rng: &mut R,
    ) -> Self::SolutionType;

    /// Scores `solution`; the optimizer minimizes this value.
    fn evaluate<R: Rng>(&self, solution: &Self::SolutionType, rng: &mut R) -> ScoreType;
}
49
/// Parallel simulated-annealing optimizer with a wall-clock-based,
/// log-linear temperature schedule.
#[derive(Clone, Copy)]
pub struct SimulatedAnnealingOptimizer {
    /// Number of independent trial chains evaluated in parallel per step.
    n_trials: usize,
    /// Wall-clock budget for the whole optimization.
    max_time: Duration,
    /// Total step budget, split (ceiling division) across the trial chains.
    n_steps: usize,
    /// Number of consecutive non-improving steps before restarting from the
    /// best-known solution.
    restart_iter: usize,
    /// Temperature at the start of the schedule.
    initial_temperature: f64,
    /// Temperature at the end of the schedule.
    final_temperature: f64,
}
67
/// Linearly interpolates between `start` (at `t = 0`) and `end` (at `t = 1`),
/// using a single fused multiply-add so only one rounding step occurs.
#[inline]
fn linear_interpolation(start: f64, end: f64, t: f64) -> f64 {
    t.mul_add(end - start, start)
}
75
76impl<'a> SimulatedAnnealingOptimizer {
77 #[allow(clippy::too_many_arguments)]
83 fn optimize_with_temperature<M, R>(
84 &self,
85 model: &M,
86 initial_solution: M::SolutionType,
87 rng: &mut R,
88 ) -> (M::SolutionType, ScoreType)
89 where
90 M: OptModel<'a>,
91 R: Rng,
92 {
93 let mut current_score = model.evaluate(&initial_solution, rng);
94 let mut current_solution = initial_solution;
95 let mut best_solution = current_solution.clone();
96 let mut best_score = current_score;
97 let mut last_improvement = 0;
98 let steps_per_thread = self.n_steps.div_ceil(self.n_trials);
99
100 let log_start = self.initial_temperature.log2();
101 let log_end = self.final_temperature.log2();
102 let total_seconds = self.max_time.as_secs_f64();
103 let mut temperature = self.initial_temperature;
104 let mut rngs = (0..self.n_trials)
105 .map(|_| StdRng::seed_from_u64(rng.random()))
106 .collect_vec();
107 let end_time = Instant::now() + self.max_time;
108 loop {
109 let (_, trial_solution, trial_score) = rngs
111 .par_iter_mut()
112 .enumerate()
113 .map(|(index, rng)| {
114 let mut trial_score = current_score;
115 let mut trial_solution = current_solution.clone();
116 for _ in 0..steps_per_thread {
117 let solution = model.generate_trial_solution(trial_solution.clone(), rng);
118 let score = model.evaluate(&solution, rng);
119
120 let diff = (score / trial_score).log2();
121 let acceptance_probability = (-diff / temperature).exp();
122 let random_value = rng.random();
123
124 if acceptance_probability >= random_value {
125 trial_solution = solution;
126 trial_score = score;
127 }
128 }
129 (index, trial_solution, trial_score)
130 })
131 .min_by_key(|(index, _, score)| (*score, *index))
132 .unwrap();
133
134 current_score = trial_score;
135 current_solution = trial_solution;
136
137 if current_score < best_score {
139 best_solution = current_solution.clone();
140 best_score = current_score;
141 last_improvement = 0;
142 }
143
144 last_improvement += 1;
145
146 if last_improvement == self.restart_iter {
148 current_solution = best_solution.clone();
149 current_score = best_score;
150 }
151
152 let now = Instant::now();
154 if now > end_time {
155 break;
157 }
158 let remaining_time = (end_time - now).as_secs_f64();
159 let progress = 1.0 - remaining_time / total_seconds;
160 temperature = linear_interpolation(log_start, log_end, progress).exp2();
161 }
162
163 (best_solution, best_score)
164 }
165}
166
/// Annealing model that moves a single randomly chosen leaf tensor to a
/// uniformly random other partition per trial.
pub struct NaivePartitioningModel<'a> {
    /// Tensor network whose leaf tensors are being partitioned.
    pub tensor: &'a Tensor,
    /// Number of available partitions.
    pub num_partitions: usize,
    /// Scheme used when costing inter-partition communication.
    pub communication_scheme: CommunicationScheme,
    /// Optional memory cap; solutions exceeding it score infinity.
    pub memory_limit: Option<f64>,
}
174
175impl<'a> OptModel<'a> for NaivePartitioningModel<'a> {
176 type SolutionType = Vec<usize>;
177
178 fn generate_trial_solution<R: Rng>(
179 &self,
180 mut current_solution: Self::SolutionType,
181 rng: &mut R,
182 ) -> Self::SolutionType {
183 let tensor_index = rng.random_range(0..current_solution.len());
184 let current_partition = current_solution[tensor_index];
185 let new_partition = loop {
186 let b = rng.random_range(0..self.num_partitions);
187 if b != current_partition {
188 break b;
189 }
190 };
191 current_solution[tensor_index] = new_partition;
192 current_solution
193 }
194
195 fn evaluate<R: Rng>(&self, solution: &Self::SolutionType, rng: &mut R) -> ScoreType {
196 let (partitioned_tn, path, parallel_cost, _) =
198 compute_solution(self.tensor, solution, self.communication_scheme, Some(rng));
199
200 let mem = compute_memory_requirements(
202 partitioned_tn.tensors(),
203 &path,
204 contract_size_tensors_exact,
205 );
206
207 if self.memory_limit.is_some_and(|limit| mem > limit) {
209 unsafe { NotNan::new_unchecked(f64::INFINITY) }
210 } else {
211 NotNan::new(parallel_cost).unwrap()
212 }
213 }
214}
215
/// Annealing model that moves a whole intermediate tensor (a contraction
/// subtree) to a uniformly random other partition per trial, recomputing the
/// contraction paths of the two affected partitions.
pub struct NaiveIntermediatePartitioningModel<'a> {
    /// Tensor network whose leaf tensors are being partitioned.
    pub tensor: &'a Tensor,
    /// Number of available partitions.
    pub num_partitions: usize,
    /// Scheme used when costing inter-partition communication.
    pub communication_scheme: CommunicationScheme,
    /// Optional memory cap; solutions exceeding it score infinity.
    pub memory_limit: Option<f64>,
}
223
224impl<'a> OptModel<'a> for NaiveIntermediatePartitioningModel<'a> {
225 type SolutionType = (Vec<usize>, Vec<SimplePath>);
226
227 fn generate_trial_solution<R: Rng>(
228 &self,
229 current_solution: Self::SolutionType,
230 rng: &mut R,
231 ) -> Self::SolutionType {
232 let (mut partitioning, mut contraction_paths) = current_solution;
233
234 let viable_partitions = contraction_paths
236 .iter()
237 .enumerate()
238 .filter_map(|(contraction_id, contraction)| {
239 if contraction.len() >= 3 {
240 Some(contraction_id)
241 } else {
242 None
243 }
244 })
245 .collect_vec();
246
247 if viable_partitions.is_empty() {
248 return (partitioning, contraction_paths);
250 }
251 let trial = rng.random_range(0..viable_partitions.len());
252 let source_partition = viable_partitions[trial];
253
254 let pair_index = rng.random_range(0..contraction_paths[source_partition].len() - 1);
256 let (i, j) = contraction_paths[source_partition][pair_index];
257 let mut tensor_leaves = FxHashSet::from_iter([i, j]);
258
259 for (i, j) in contraction_paths[source_partition]
261 .iter()
262 .take(pair_index)
263 .rev()
264 {
265 if tensor_leaves.contains(i) {
266 tensor_leaves.insert(*j);
267 }
268 }
269
270 let mut shifted_indices = Vec::with_capacity(tensor_leaves.len());
271 for (partition_tensor_index, (i, _partition)) in partitioning
272 .iter()
273 .enumerate()
274 .filter(|(_, partition)| *partition == &source_partition)
275 .enumerate()
276 {
277 if tensor_leaves.contains(&partition_tensor_index) {
278 shifted_indices.push(i);
279 }
280 }
281
282 let target_partition = loop {
284 let b = rng.random_range(0..self.num_partitions);
285 if b != source_partition {
286 break b;
287 }
288 };
289
290 for index in shifted_indices {
292 partitioning[index] = target_partition;
293 }
294
295 let mut from_tensor = Tensor::default();
297 let mut to_tensor = Tensor::default();
298 for (partition_index, tensor) in zip(&partitioning, self.tensor.tensors()) {
299 if *partition_index == source_partition {
300 from_tensor.push_tensor(tensor.clone());
301 } else if *partition_index == target_partition {
302 to_tensor.push_tensor(tensor.clone());
303 }
304 }
305
306 let mut from_opt = Cotengrust::new(&from_tensor, OptMethod::Greedy);
307 from_opt.find_path();
308 let from_path = from_opt.get_best_replace_path();
309 contraction_paths[source_partition] = from_path.into_simple();
310
311 let mut to_opt = Cotengrust::new(&to_tensor, OptMethod::Greedy);
312 to_opt.find_path();
313 let to_path = to_opt.get_best_replace_path();
314 contraction_paths[target_partition] = to_path.into_simple();
315
316 (partitioning, contraction_paths)
317 }
318
319 fn evaluate<R: Rng>(&self, solution: &Self::SolutionType, rng: &mut R) -> ScoreType {
320 let (partitioned_tn, path, parallel_cost, _) = compute_solution(
322 self.tensor,
323 &solution.0,
324 self.communication_scheme,
325 Some(rng),
326 );
327
328 let mem = compute_memory_requirements(
330 partitioned_tn.tensors(),
331 &path,
332 contract_size_tensors_exact,
333 );
334
335 if self.memory_limit.is_some_and(|limit| mem > limit) {
337 unsafe { NotNan::new_unchecked(f64::INFINITY) }
338 } else {
339 NotNan::new(parallel_cost).unwrap()
340 }
341 }
342}
343
/// Annealing model that moves a single leaf tensor to the greedily best
/// target partition (smallest resulting size increase) per trial.
pub struct LeafPartitioningModel<'a> {
    /// Tensor network whose leaf tensors are being partitioned.
    pub tensor: &'a Tensor,
    /// Scheme used when costing inter-partition communication.
    pub communication_scheme: CommunicationScheme,
    /// Optional memory cap; solutions exceeding it score infinity.
    pub memory_limit: Option<f64>,
}
351
352impl<'a> OptModel<'a> for LeafPartitioningModel<'a> {
353 type SolutionType = (Vec<usize>, Vec<Tensor>);
354
355 fn generate_trial_solution<R: Rng>(
356 &self,
357 current_solution: Self::SolutionType,
358 rng: &mut R,
359 ) -> Self::SolutionType {
360 let (mut partitioning, mut partition_tensors) = current_solution;
361 let tensor_index = rng.random_range(0..partitioning.len());
362 let shifted_tensor = self.tensor.tensor(tensor_index);
363 let source_partition = partitioning[tensor_index];
364
365 let (new_partition, _) = partition_tensors
366 .iter()
367 .enumerate()
368 .filter_map(|(i, partition_tensor)| {
369 if i != source_partition {
370 Some((
371 i,
372 (shifted_tensor ^ partition_tensor).size() - partition_tensor.size(),
373 ))
374 } else {
375 None
377 }
378 })
379 .min_by(|a, b| a.1.total_cmp(&b.1))
380 .unwrap();
381
382 partitioning[tensor_index] = new_partition;
383 partition_tensors[source_partition] ^= shifted_tensor;
384 partition_tensors[new_partition] ^= shifted_tensor;
385 (partitioning, partition_tensors)
386 }
387
388 fn evaluate<R: Rng>(&self, solution: &Self::SolutionType, rng: &mut R) -> ScoreType {
389 let (partitioned_tn, path, parallel_cost, _) = compute_solution(
391 self.tensor,
392 &solution.0,
393 self.communication_scheme,
394 Some(rng),
395 );
396
397 let mem = compute_memory_requirements(
399 partitioned_tn.tensors(),
400 &path,
401 contract_size_tensors_exact,
402 );
403
404 if self.memory_limit.is_some_and(|limit| mem > limit) {
406 unsafe { NotNan::new_unchecked(f64::INFINITY) }
407 } else {
408 NotNan::new(parallel_cost).unwrap()
409 }
410 }
411}
412
/// Annealing model that moves a whole intermediate tensor (a contraction
/// subtree) to the greedily best target partition, keeping aggregate
/// partition tensors and per-partition contraction paths in sync.
pub struct IntermediatePartitioningModel<'a> {
    /// Tensor network whose leaf tensors are being partitioned.
    pub tensor: &'a Tensor,
    /// Scheme used when costing inter-partition communication.
    pub communication_scheme: CommunicationScheme,
    /// Optional memory cap; solutions exceeding it score infinity.
    pub memory_limit: Option<f64>,
}
420
421impl<'a> OptModel<'a> for IntermediatePartitioningModel<'a> {
422 type SolutionType = (Vec<usize>, Vec<Tensor>, Vec<SimplePath>);
423
424 fn generate_trial_solution<R: Rng>(
425 &self,
426 current_solution: Self::SolutionType,
427 rng: &mut R,
428 ) -> Self::SolutionType {
429 let (mut partitioning, mut partition_tensors, mut contraction_paths) = current_solution;
430
431 let viable_partitions = contraction_paths
433 .iter()
434 .enumerate()
435 .filter_map(|(contraction_id, contraction)| {
436 if contraction.len() >= 3 {
437 Some(contraction_id)
438 } else {
439 None
440 }
441 })
442 .collect_vec();
443
444 if viable_partitions.is_empty() {
445 return (partitioning, partition_tensors, contraction_paths);
447 }
448 let trial = rng.random_range(0..viable_partitions.len());
449 let source_partition = viable_partitions[trial];
450
451 let pair_index = rng.random_range(0..contraction_paths[source_partition].len() - 1);
453 let (i, j) = contraction_paths[source_partition][pair_index];
454 let mut tensor_leaves = FxHashSet::from_iter([i, j]);
455
456 for (i, j) in contraction_paths[source_partition]
458 .iter()
459 .take(pair_index)
460 .rev()
461 {
462 if tensor_leaves.contains(i) {
463 tensor_leaves.insert(*j);
464 }
465 }
466
467 let mut shifted_tensor = Tensor::default();
468 let mut shifted_indices = Vec::with_capacity(tensor_leaves.len());
469 for (partition_tensor_index, (i, _partition)) in partitioning
470 .iter()
471 .enumerate()
472 .filter(|(_, partition)| *partition == &source_partition)
473 .enumerate()
474 {
475 if tensor_leaves.contains(&partition_tensor_index) {
476 shifted_tensor ^= self.tensor.tensor(i);
477 shifted_indices.push(i);
478 }
479 }
480
481 let (target_partition, _) = partition_tensors
484 .iter()
485 .enumerate()
486 .filter_map(|(i, partition_tensor)| {
487 if i != source_partition {
488 Some((
489 i,
490 (&shifted_tensor ^ partition_tensor).size() - partition_tensor.size(),
491 ))
492 } else {
493 None
495 }
496 })
497 .min_by(|a, b| a.1.total_cmp(&b.1))
498 .unwrap();
499
500 for index in shifted_indices {
502 partitioning[index] = target_partition;
503 }
504
505 partition_tensors[source_partition] ^= &shifted_tensor;
507 partition_tensors[target_partition] ^= &shifted_tensor;
508
509 let mut from_tensor = Tensor::default();
511 let mut to_tensor = Tensor::default();
512 for (partition_index, tensor) in zip(&partitioning, self.tensor.tensors()) {
513 if *partition_index == source_partition {
514 from_tensor.push_tensor(tensor.clone());
515 } else if *partition_index == target_partition {
516 to_tensor.push_tensor(tensor.clone());
517 }
518 }
519
520 let mut from_opt = Cotengrust::new(&from_tensor, OptMethod::Greedy);
521 from_opt.find_path();
522 let from_path = from_opt.get_best_replace_path();
523 contraction_paths[source_partition] = from_path.into_simple();
524
525 let mut to_opt = Cotengrust::new(&to_tensor, OptMethod::Greedy);
526 to_opt.find_path();
527 let to_path = to_opt.get_best_replace_path();
528 contraction_paths[target_partition] = to_path.into_simple();
529
530 (partitioning, partition_tensors, contraction_paths)
531 }
532
533 fn evaluate<R: Rng>(&self, solution: &Self::SolutionType, rng: &mut R) -> ScoreType {
534 let (partitioned_tn, path, parallel_cost, _) = compute_solution(
536 self.tensor,
537 &solution.0,
538 self.communication_scheme,
539 Some(rng),
540 );
541
542 let mem = compute_memory_requirements(
544 partitioned_tn.tensors(),
545 &path,
546 contract_size_tensors_exact,
547 );
548
549 if self.memory_limit.is_some_and(|limit| mem > limit) {
551 unsafe { NotNan::new_unchecked(f64::INFINITY) }
552 } else {
553 NotNan::new(parallel_cost).unwrap()
554 }
555 }
556}
557
558pub fn balance_partitions<'a, R, M>(
560 model: M,
561 initial_solution: M::SolutionType,
562 rng: &mut R,
563 max_time: Duration,
564) -> (M::SolutionType, ScoreType)
565where
566 R: Rng,
567 M: OptModel<'a>,
568{
569 let optimizer = SimulatedAnnealingOptimizer {
570 n_trials: PROCESSING_THREADS,
571 max_time,
572 n_steps: PROCESSING_THREADS * 10,
573 restart_iter: 50,
574 initial_temperature: 2.0,
575 final_temperature: 0.05,
576 };
577 optimizer.optimize_with_temperature::<M, _>(&model, initial_solution, rng)
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    use float_cmp::assert_approx_eq;

    // Interpolation at interior points and at both endpoints (t = 0 and t = 1).
    #[test]
    fn simple_linear_interpolation() {
        assert_approx_eq!(f64, linear_interpolation(0., 6., 0.5), 3.0);
        assert_approx_eq!(f64, linear_interpolation(-1.0, 4.0, 0.2), 0.0);
        assert_approx_eq!(f64, linear_interpolation(-7.0, -6.0, 0.0), -7.0);
        assert_approx_eq!(f64, linear_interpolation(3.0, 5.0, 1.0), 5.0);
    }

    // Annealing a tiny 4-tensor network: t1/t3 share legs {0, 1} and t2/t4
    // share legs {2, 3}, so the balanced result should interleave the pairs
    // ([0, 1, 0, 1] or its relabeling [1, 0, 1, 0]) regardless of which
    // partition label ends up first.
    #[test]
    fn small_leaf_partitioning() {
        let t1 = Tensor::new_from_const(vec![0, 1], 2);
        let t2 = Tensor::new_from_const(vec![2, 3], 2);
        let t3 = Tensor::new_from_const(vec![0, 1, 4], 2);
        let t4 = Tensor::new_from_const(vec![2, 3, 4], 2);
        let tn = Tensor::new_composite(vec![t1.clone(), t2.clone(), t3.clone(), t4.clone()]);
        let tn1 = Tensor::new_composite(vec![t1, t2]);
        let tn2 = Tensor::new_composite(vec![t3, t4]);
        let initial_partitioning = vec![0, 0, 1, 1];
        let initial_partitions = vec![tn1, tn2];
        let mut rng = StdRng::seed_from_u64(42);

        let ((partitioning, _partitions), _) = balance_partitions(
            LeafPartitioningModel {
                tensor: &tn,
                communication_scheme: CommunicationScheme::Greedy,
                memory_limit: None,
            },
            (initial_partitioning, initial_partitions),
            &mut rng,
            Duration::from_secs(2),
        );
        // Partition labels are arbitrary; accept either labeling of the
        // expected interleaved assignment.
        let ref_partitioning = if partitioning[0] == 0 {
            [0, 1, 0, 1]
        } else {
            [1, 0, 1, 0]
        };
        assert_eq!(partitioning, ref_partitioning);
    }
}
625}