tnc/contractionpath/repartitioning/
genetic.rs

1//! Repartitioning using a genetic algorithm.
2
3use genetic_algorithm::{
4    crossover::CrossoverUniform,
5    fitness::{Fitness, FitnessChromosome, FitnessOrdering, FitnessValue},
6    genotype::{Genotype, RangeGenotype},
7    mutate::MutateSingleGene,
8    select::SelectTournament,
9    strategy::{evolve::Evolve, Strategy},
10};
11use ordered_float::NotNan;
12use rand::rngs::StdRng;
13
14use crate::{
15    contractionpath::{
16        communication_schemes::CommunicationScheme,
17        contraction_cost::{compute_memory_requirements, contract_size_tensors_exact},
18        repartitioning::compute_solution,
19    },
20    tensornetwork::tensor::Tensor,
21};
22
23#[derive(Clone, Debug)]
24struct PartitioningFitness<'a> {
25    tensor: &'a Tensor,
26    communication_scheme: CommunicationScheme,
27    memory_limit: Option<f64>,
28}
29
30impl PartitioningFitness<'_> {
31    fn calculate_fitness(&self, partitioning: &[usize]) -> NotNan<f64> {
32        // Construct the tensor network and contraction path from the partitioning
33        let (partitioned_tn, path, cost, _) =
34            compute_solution::<StdRng>(self.tensor, partitioning, self.communication_scheme, None);
35
36        // Compute memory usage
37        let mem = compute_memory_requirements(
38            partitioned_tn.tensors(),
39            &path,
40            contract_size_tensors_exact,
41        );
42
43        // If the memory limit is exceeded, return infinity
44        let score = if self.memory_limit.is_some_and(|limit| mem > limit) {
45            f64::INFINITY
46        } else {
47            cost
48        };
49        NotNan::new(score).unwrap()
50    }
51}
52
53impl Fitness for PartitioningFitness<'_> {
54    type Genotype = RangeGenotype<usize>;
55
56    fn calculate_for_chromosome(
57        &mut self,
58        chromosome: &FitnessChromosome<Self>,
59        _genotype: &Self::Genotype,
60    ) -> Option<FitnessValue> {
61        Some(self.calculate_fitness(&chromosome.genes))
62    }
63}
64
65/// Balances partitions using a genetic algorithm. Finds the partitioning that reduces
66/// the total contraction cost.
67pub fn balance_partitions(
68    tensor: &Tensor,
69    num_partitions: usize,
70    initial_partitioning: &[usize],
71    communication_scheme: CommunicationScheme,
72    memory_limit: Option<f64>,
73) -> (Vec<usize>, f64) {
74    // Chromosomes: Possible partitions, e.g. [0, 1, 0, 2, 2, 1, 0, 0, 1, 1]
75    // Genes: tensor (in vector)
76    // Alleles: partition id
77
78    let num_tensors = initial_partitioning.len();
79
80    let genotype = RangeGenotype::builder()
81        .with_genes_size(num_tensors)
82        .with_allele_range(0..=num_partitions - 1)
83        .with_seed_genes_list(vec![initial_partitioning.to_vec()])
84        .build()
85        .unwrap();
86
87    let fitness = PartitioningFitness {
88        tensor,
89        communication_scheme,
90        memory_limit,
91    };
92
93    let evolve = Evolve::builder()
94        .with_genotype(genotype)
95        .with_target_population_size(100)
96        .with_max_stale_generations(100)
97        .with_fitness(fitness)
98        .with_fitness_ordering(FitnessOrdering::Minimize)
99        .with_mutate(MutateSingleGene::new(0.2))
100        .with_crossover(CrossoverUniform::new(1.0, 1.0))
101        .with_select(SelectTournament::new(1.0, 0.02, 4))
102        // .with_reporter(EvolveReporterDuration::new())
103        .with_par_fitness(true)
104        .with_rng_seed_from_u64(0)
105        .call()
106        .unwrap();
107
108    evolve
109        .best_genes_and_fitness_score()
110        .map(|(partitioning, score)| (partitioning, score.into_inner()))
111        .unwrap()
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_partitioning() {
        // Four tensors: the first and third share legs 0 and 1, the second and
        // fourth share legs 2 and 3, and the last two share leg 4.
        let tensors = vec![
            Tensor::new_from_const(vec![0, 1], 2),
            Tensor::new_from_const(vec![2, 3], 2),
            Tensor::new_from_const(vec![0, 1, 4], 2),
            Tensor::new_from_const(vec![2, 3, 4], 2),
        ];
        let tn = Tensor::new_composite(tensors);
        let initial_partitioning = vec![0, 0, 1, 1];

        let (partitioning, _) = balance_partitions(
            &tn,
            2,
            &initial_partitioning,
            CommunicationScheme::RandomGreedy,
            None,
        );

        // Partition ids are only meaningful up to relabelling, so align the
        // expected grouping with whatever label the first tensor received.
        let expected = match partitioning[0] {
            0 => [0, 1, 0, 1],
            _ => [1, 0, 1, 0],
        };
        assert_eq!(partitioning, expected);
    }
}