// tnc/tensornetwork/tensordata.rs

1use std::path::PathBuf;
2
3use approx::AbsDiffEq;
4use ndarray::ArrayD;
5use num_complex::Complex64;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    gates::{load_gate, load_gate_adjoint, matrix_adjoint_inplace},
10    io::hdf5::load_data,
11};
12
13pub type DataTensor = ArrayD<Complex64>;
14
15/// The data of a tensor.
16#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub enum TensorData {
18    /// This is for composite tensors that have not been contracted yet, as well as
19    /// empty tensors in general.
20    #[default]
21    Uncontracted,
22    /// The data is loaded from a HDF5 file.
23    File((PathBuf, bool)),
24    /// A quantum gate. The name must be registered in the gates module.
25    Gate((String, Vec<f64>, bool)),
26    /// A raw vec of complex numbers.
27    Matrix(DataTensor),
28}
29
30impl TensorData {
31    /// Creates a new tensor from raw (flat) data.
32    #[must_use]
33    pub fn new_from_data(dimensions: &[usize], data: Vec<Complex64>) -> Self {
34        Self::Matrix(ArrayD::from_shape_vec(dimensions, data).unwrap())
35    }
36
37    /// Consumes the tensor data and returns the contained tensor.
38    pub fn into_data(self) -> DataTensor {
39        match self {
40            TensorData::Uncontracted => panic!("Cannot convert uncontracted tensor to data"),
41            TensorData::File((filename, adjoint)) => {
42                let mut data = load_data(filename).unwrap();
43                if adjoint {
44                    matrix_adjoint_inplace(&mut data);
45                }
46                data
47            }
48            TensorData::Gate((gatename, angles, adjoint)) => {
49                if adjoint {
50                    load_gate_adjoint(&gatename, &angles)
51                } else {
52                    load_gate(&gatename, &angles)
53                }
54            }
55            TensorData::Matrix(tensor) => tensor,
56        }
57    }
58
59    /// Returns the adjoint of this data.
60    pub fn adjoint(self) -> Self {
61        match self {
62            TensorData::Uncontracted => TensorData::Uncontracted,
63            TensorData::File((filename, adjoint)) => TensorData::File((filename, !adjoint)),
64            TensorData::Gate((name, params, adjoint)) => TensorData::Gate((name, params, !adjoint)),
65            TensorData::Matrix(mut tensor) => {
66                matrix_adjoint_inplace(&mut tensor);
67                TensorData::Matrix(tensor)
68            }
69        }
70    }
71}
72
73impl AbsDiffEq for TensorData {
74    type Epsilon = f64;
75
76    fn default_epsilon() -> Self::Epsilon {
77        f64::EPSILON
78    }
79
80    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
81        match (self, other) {
82            (TensorData::File(l0), TensorData::File(r0)) => l0 == r0,
83            (
84                TensorData::Gate((name_l, angles_l, adjoint_l)),
85                TensorData::Gate((name_r, angles_r, adjoint_r)),
86            ) => {
87                name_l == name_r
88                    && adjoint_l == adjoint_r
89                    && angles_l
90                        .iter()
91                        .zip(angles_r)
92                        .all(|(l, r)| f64::abs_diff_eq(l, r, epsilon))
93            }
94            (TensorData::Matrix(l0), TensorData::Matrix(r0)) => {
95                DataTensor::abs_diff_eq(l0, r0, epsilon)
96            }
97            (TensorData::Uncontracted, TensorData::Uncontracted) => true,
98            _ => false,
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_abs_diff_ne};

    use super::*;

    /// Builds a gate with the given name, angles and adjoint flag.
    fn gate(name: &str, angles: Vec<f64>, adjoint: bool) -> TensorData {
        TensorData::Gate((String::from(name), angles, adjoint))
    }

    #[test]
    fn gates_eq_within_epsilon() {
        let g1 = gate("u", vec![1.4, 2.0, -3.0], false);
        let g2 = gate("u", vec![1.4, 2.0 + 1e-12, -3.0], false);
        assert_abs_diff_eq!(&g1, &g2, epsilon = 1e-10);
    }

    #[test]
    fn gates_eq_different_name() {
        // Gate names are compared exactly, including case.
        let g1 = gate("cx", vec![], false);
        let g2 = gate("CX", vec![], false);
        assert_abs_diff_ne!(&g1, &g2);
    }

    #[test]
    fn gates_eq_adjoint() {
        let g1 = gate("h", vec![], false);
        let g2 = gate("h", vec![], true);
        assert_abs_diff_ne!(&g1, &g2);
    }

    #[test]
    fn gates_eq_different_angles() {
        let g1 = gate("u", vec![1.4, 2.0, -3.0], false);
        let g2 = gate("u", vec![1.4, -2.0, -3.0], false);
        assert_abs_diff_ne!(&g1, &g2);
    }

    #[test]
    fn eq_different_data() {
        let g1 = gate("u", vec![1.4, 2.0, -3.0], false);
        let g2 = TensorData::new_from_data(&[], vec![Complex64::ONE]);
        assert_abs_diff_ne!(&g1, &g2);
    }

    #[test]
    fn matrices_eq_within_epsilon() {
        let m1 = TensorData::new_from_data(&[2], vec![Complex64::ONE, Complex64::new(0.0, 1.0)]);
        let m2 = TensorData::new_from_data(
            &[2],
            vec![Complex64::ONE, Complex64::new(0.0, 1.0 + 1e-12)],
        );
        assert_abs_diff_eq!(&m1, &m2, epsilon = 1e-10);
    }
}