Skip to main content

tnc/builders/
tensorgeneration.rs

1use itertools::Itertools;
2use ndarray::Dim;
3use num_complex::Complex64;
4use rand::distr::Uniform;
5use rand::Rng;
6
7use crate::tensornetwork::tensordata::{DataTensor, TensorData};
8
9/// Generates random sparse [`DataTensor`] object.
10/// Fills in sparse tensor based on `sparsity` value (defaults to `0.5`).
11///
12/// # Examples
13/// ```
14/// # use tnc::builders::tensorgeneration::random_sparse_tensor_data_with_rng;
15/// let shape = vec![5, 4, 3];
16/// random_sparse_tensor_data_with_rng(&shape, None, &mut rand::thread_rng());
17/// ```
18pub fn random_sparse_tensor_data_with_rng<R>(
19    dims: &[usize],
20    sparsity: Option<f32>,
21    rng: &mut R,
22) -> TensorData
23where
24    R: Rng,
25{
26    let sparsity = if let Some(sparsity) = sparsity {
27        assert!((0.0..=1.0).contains(&sparsity));
28        sparsity
29    } else {
30        0.5
31    };
32
33    let ranges = dims
34        .iter()
35        .map(|i| Uniform::new(0, *i).unwrap())
36        .collect_vec();
37    let size = dims.iter().product::<usize>();
38    let mut tensor = DataTensor::zeros(dims);
39
40    let mut nnz = 0;
41    while (nnz as f32 / size as f32) < sparsity {
42        let loc = ranges.iter().map(|r| rng.sample(r)).collect_vec();
43        let val = Complex64::new(rng.random(), rng.random());
44        let dim = Dim(loc);
45        let elem = tensor.get_mut(dim).unwrap();
46        if *elem != Complex64::ZERO {
47            continue; // Skip if the location is already non-zero
48        }
49        *elem = val;
50        nnz += 1;
51    }
52
53    TensorData::Matrix(tensor)
54}
55
56/// Generates random sparse [`DataTensor`] object.
57/// Fills in sparse tensor based on `sparsity` value (defaults to `0.5`). Uses the thread-local random number generator.
58///
59/// # Examples
60/// ```
61/// # use tnc::builders::tensorgeneration::random_sparse_tensor_data;
62/// let shape = vec![5,4,3];
63/// let r_tensor = random_sparse_tensor_data(&shape, None);
64/// ```
65#[must_use]
66pub fn random_sparse_tensor_data(shape: &[usize], sparsity: Option<f32>) -> TensorData {
67    random_sparse_tensor_data_with_rng(shape, sparsity, &mut rand::rng())
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape shared by all tests (60 entries in total).
    const SHAPE: [usize; 3] = [5, 4, 3];

    /// Total number of entries in [`SHAPE`].
    fn total_elements() -> usize {
        SHAPE.iter().product()
    }

    /// Unwraps the `Matrix` variant and counts its non-zero entries.
    fn count_non_zero(tensor_data: &TensorData) -> usize {
        let TensorData::Matrix(tensor) = tensor_data else {
            panic!("Expected TensorData::Matrix variant");
        };
        tensor.iter().filter(|&&x| x != Complex64::ZERO).count()
    }

    #[test]
    fn test_random_sparse_tensor_data() {
        let sparsity = 0.3;
        let tensor_data = random_sparse_tensor_data(&SHAPE, Some(sparsity));
        let non_zero = count_non_zero(&tensor_data);
        let total = total_elements();
        let ratio = |count: usize| count as f32 / total as f32;

        // The generator stops at the first count that reaches the requested ratio.
        // The result must therefore reach it, while one entry fewer must not.
        assert!(
            ratio(non_zero) >= sparsity,
            "Expected a fill ratio of at least {sparsity}, but got {}",
            ratio(non_zero),
        );
        assert!(
            non_zero > 0 && ratio(non_zero - 1) < sparsity,
            "Expected a minimal fill for {sparsity}, but got {non_zero} of {total} entries",
        );
    }

    #[test]
    fn test_default_sparsity_fills_half() {
        let tensor_data = random_sparse_tensor_data(&SHAPE, None);
        assert_eq!(count_non_zero(&tensor_data), total_elements() / 2);
    }

    #[test]
    fn test_full_density_fills_every_entry() {
        let tensor_data = random_sparse_tensor_data(&SHAPE, Some(1.0));
        assert_eq!(count_non_zero(&tensor_data), total_elements());
    }

    #[test]
    fn test_zero_density_leaves_tensor_empty() {
        let tensor_data = random_sparse_tensor_data(&SHAPE, Some(0.0));
        assert_eq!(count_non_zero(&tensor_data), 0);
    }

    #[test]
    #[should_panic]
    fn test_out_of_range_sparsity_panics() {
        let _ = random_sparse_tensor_data(&SHAPE, Some(1.5));
    }
}