tnc/contractionpath/repartitioning/
genetic.rs1use genetic_algorithm::{
4 crossover::CrossoverUniform,
5 fitness::{Fitness, FitnessChromosome, FitnessOrdering, FitnessValue},
6 genotype::{Genotype, RangeGenotype},
7 mutate::MutateSingleGene,
8 select::SelectTournament,
9 strategy::{evolve::Evolve, Strategy},
10};
11use ordered_float::NotNan;
12use rand::rngs::StdRng;
13
14use crate::{
15 contractionpath::{
16 communication_schemes::CommunicationScheme,
17 contraction_cost::{compute_memory_requirements, contract_size_tensors_exact},
18 repartitioning::compute_solution,
19 },
20 tensornetwork::tensor::Tensor,
21};
22
23#[derive(Clone, Debug)]
24struct PartitioningFitness<'a> {
25 tensor: &'a Tensor,
26 communication_scheme: CommunicationScheme,
27 memory_limit: Option<f64>,
28}
29
30impl PartitioningFitness<'_> {
31 fn calculate_fitness(&self, partitioning: &[usize]) -> NotNan<f64> {
32 let (partitioned_tn, path, cost, _) =
34 compute_solution::<StdRng>(self.tensor, partitioning, self.communication_scheme, None);
35
36 let mem = compute_memory_requirements(
38 partitioned_tn.tensors(),
39 &path,
40 contract_size_tensors_exact,
41 );
42
43 let score = if self.memory_limit.is_some_and(|limit| mem > limit) {
45 f64::INFINITY
46 } else {
47 cost
48 };
49 NotNan::new(score).unwrap()
50 }
51}
52
53impl Fitness for PartitioningFitness<'_> {
54 type Genotype = RangeGenotype<usize>;
55
56 fn calculate_for_chromosome(
57 &mut self,
58 chromosome: &FitnessChromosome<Self>,
59 _genotype: &Self::Genotype,
60 ) -> Option<FitnessValue> {
61 Some(self.calculate_fitness(&chromosome.genes))
62 }
63}
64
65pub fn balance_partitions(
68 tensor: &Tensor,
69 num_partitions: usize,
70 initial_partitioning: &[usize],
71 communication_scheme: CommunicationScheme,
72 memory_limit: Option<f64>,
73) -> (Vec<usize>, f64) {
74 let num_tensors = initial_partitioning.len();
79
80 let genotype = RangeGenotype::builder()
81 .with_genes_size(num_tensors)
82 .with_allele_range(0..=num_partitions - 1)
83 .with_seed_genes_list(vec![initial_partitioning.to_vec()])
84 .build()
85 .unwrap();
86
87 let fitness = PartitioningFitness {
88 tensor,
89 communication_scheme,
90 memory_limit,
91 };
92
93 let evolve = Evolve::builder()
94 .with_genotype(genotype)
95 .with_target_population_size(100)
96 .with_max_stale_generations(100)
97 .with_fitness(fitness)
98 .with_fitness_ordering(FitnessOrdering::Minimize)
99 .with_mutate(MutateSingleGene::new(0.2))
100 .with_crossover(CrossoverUniform::new(1.0, 1.0))
101 .with_select(SelectTournament::new(1.0, 0.02, 4))
102 .with_par_fitness(true)
104 .with_rng_seed_from_u64(0)
105 .call()
106 .unwrap();
107
108 evolve
109 .best_genes_and_fitness_score()
110 .map(|(partitioning, score)| (partitioning, score.into_inner()))
111 .unwrap()
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_partitioning() {
        // Tensors 0/2 and 1/3 have overlapping index lists. The expected
        // optimum groups each of those pairs together, which differs from the
        // starting split {0, 1} / {2, 3}.
        let network = Tensor::new_composite(vec![
            Tensor::new_from_const(vec![0, 1], 2),
            Tensor::new_from_const(vec![2, 3], 2),
            Tensor::new_from_const(vec![0, 1, 4], 2),
            Tensor::new_from_const(vec![2, 3, 4], 2),
        ]);
        let start = vec![0, 0, 1, 1];

        let (partitioning, _cost) = balance_partitions(
            &network,
            2,
            &start,
            CommunicationScheme::RandomGreedy,
            None,
        );

        // Partition labels are interchangeable, so accept either labelling of
        // the split {0, 2} / {1, 3}, keyed on where tensor 0 ended up.
        let expected = match partitioning[0] {
            0 => [0, 1, 0, 1],
            _ => [1, 0, 1, 0],
        };
        assert_eq!(partitioning, expected);
    }
}