tnc/builders/
tensorgeneration.rs1use itertools::Itertools;
2use ndarray::Dim;
3use num_complex::Complex64;
4use rand::distr::Uniform;
5use rand::Rng;
6
7use crate::tensornetwork::tensordata::{DataTensor, TensorData};
8
9pub fn random_sparse_tensor_data_with_rng<R>(
19 dims: &[usize],
20 sparsity: Option<f32>,
21 rng: &mut R,
22) -> TensorData
23where
24 R: Rng,
25{
26 let sparsity = if let Some(sparsity) = sparsity {
27 assert!((0.0..=1.0).contains(&sparsity));
28 sparsity
29 } else {
30 0.5
31 };
32
33 let ranges = dims
34 .iter()
35 .map(|i| Uniform::new(0, *i).unwrap())
36 .collect_vec();
37 let size = dims.iter().product::<usize>();
38 let mut tensor = DataTensor::zeros(dims);
39
40 let mut nnz = 0;
41 while (nnz as f32 / size as f32) < sparsity {
42 let loc = ranges.iter().map(|r| rng.sample(r)).collect_vec();
43 let val = Complex64::new(rng.random(), rng.random());
44 let dim = Dim(loc);
45 let elem = tensor.get_mut(dim).unwrap();
46 if *elem != Complex64::ZERO {
47 continue; }
49 *elem = val;
50 nnz += 1;
51 }
52
53 TensorData::Matrix(tensor)
54}
55
56#[must_use]
66pub fn random_sparse_tensor_data(shape: &[usize], sparsity: Option<f32>) -> TensorData {
67 random_sparse_tensor_data_with_rng(shape, sparsity, &mut rand::rng())
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the `Matrix` variant and counts the entries that differ from zero.
    fn count_nonzero(tensor_data: TensorData) -> usize {
        let TensorData::Matrix(tensor) = tensor_data else {
            panic!("Expected TensorData::Matrix variant");
        };
        tensor.iter().filter(|&&x| x != Complex64::ZERO).count()
    }

    #[test]
    fn test_random_sparse_tensor_data() {
        let shape = [5, 4, 3];
        let sparsity = 0.3;
        let total_elements = shape.iter().product::<usize>();
        let non_zero_elements = count_nonzero(random_sparse_tensor_data(&shape, Some(sparsity)));
        let actual_sparsity = non_zero_elements as f32 / total_elements as f32;
        assert!(
            actual_sparsity >= sparsity,
            "Expected a fill of at least {sparsity}, but got {actual_sparsity}",
        );
        // Generation stops as soon as the target is met, so one fewer non-zero
        // entry must fall short of it.
        let one_fewer = (non_zero_elements - 1) as f32 / total_elements as f32;
        assert!(
            one_fewer < sparsity,
            "Expected the smallest fill reaching {sparsity}, but got {actual_sparsity}",
        );
    }

    #[test]
    fn zero_sparsity_gives_all_zero_tensor() {
        assert_eq!(count_nonzero(random_sparse_tensor_data(&[3, 4], Some(0.0))), 0);
    }

    #[test]
    fn full_sparsity_fills_every_element() {
        let shape = [3, 4];
        let filled = count_nonzero(random_sparse_tensor_data(&shape, Some(1.0)));
        assert_eq!(filled, shape.iter().product::<usize>());
    }

    #[test]
    fn default_sparsity_is_one_half() {
        // 20 elements at the default fill of 0.5 -> exactly 10 non-zeros.
        assert_eq!(count_nonzero(random_sparse_tensor_data(&[4, 5], None)), 10);
    }

    #[test]
    #[should_panic]
    fn out_of_range_sparsity_panics() {
        let _ = random_sparse_tensor_data(&[2, 2], Some(1.5));
    }
}