tnc/tensornetwork/
tensordata.rs

1use std::path::PathBuf;
2
3use float_cmp::{ApproxEq, F64Margin};
4use num_complex::Complex64;
5use serde::{Deserialize, Serialize};
6use tetra::{Layout, Tensor as DataTensor};
7
8use crate::{
9    gates::{load_gate, load_gate_adjoint, matrix_adjoint_inplace},
10    hdf5::load_data,
11};
12
/// The data of a tensor.
///
/// Data is either held in memory ([`TensorData::Matrix`]) or described lazily
/// ([`TensorData::File`], [`TensorData::Gate`]) and only materialized when
/// [`TensorData::into_data`] is called.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub enum TensorData {
    /// This is for composite tensors that have not been contracted yet, as well as
    /// empty tensors in general. Calling [`TensorData::into_data`] on this variant panics.
    #[default]
    Uncontracted,
    /// The data is loaded from a HDF5 file: `(path, adjoint)`.
    ///
    /// When `adjoint` is `true`, [`TensorData::into_data`] applies `matrix_adjoint_inplace`
    /// to the loaded data; [`TensorData::adjoint`] toggles this flag instead of touching the file.
    File((PathBuf, bool)),
    /// A quantum gate: `(name, angles, adjoint)`. The name must be registered in the gates module.
    ///
    /// `angles` are forwarded to the gate loader. When `adjoint` is `true`,
    /// [`TensorData::into_data`] builds the gate with `load_gate_adjoint` instead of `load_gate`.
    Gate((String, Vec<f64>, bool)),
    /// A tensor held in memory (e.g. built from flat complex data via
    /// [`TensorData::new_from_data`]).
    Matrix(DataTensor),
}
27
28impl TensorData {
29    /// Creates a new tensor from raw (flat) data.
30    #[must_use]
31    pub fn new_from_data(
32        dimensions: &[usize],
33        data: Vec<Complex64>,
34        layout: Option<Layout>,
35    ) -> Self {
36        Self::Matrix(DataTensor::new_from_flat(dimensions, data, layout))
37    }
38
39    /// Consumes the tensor data and returns the contained tensor.
40    pub fn into_data(self) -> DataTensor {
41        match self {
42            TensorData::Uncontracted => panic!("Cannot convert uncontracted tensor to data"),
43            TensorData::File((filename, adjoint)) => {
44                let mut data = load_data(filename).unwrap();
45                if adjoint {
46                    matrix_adjoint_inplace(&mut data);
47                }
48                data
49            }
50            TensorData::Gate((gatename, angles, adjoint)) => {
51                if adjoint {
52                    load_gate_adjoint(&gatename, &angles)
53                } else {
54                    load_gate(&gatename, &angles)
55                }
56            }
57            TensorData::Matrix(tensor) => tensor,
58        }
59    }
60
61    /// Returns the adjoint of this data.
62    pub fn adjoint(self) -> Self {
63        match self {
64            TensorData::Uncontracted => TensorData::Uncontracted,
65            TensorData::File((filename, adjoint)) => TensorData::File((filename, !adjoint)),
66            TensorData::Gate((name, params, adjoint)) => TensorData::Gate((name, params, !adjoint)),
67            TensorData::Matrix(mut tensor) => {
68                matrix_adjoint_inplace(&mut tensor);
69                TensorData::Matrix(tensor)
70            }
71        }
72    }
73}
74
75impl ApproxEq for &TensorData {
76    type Margin = F64Margin;
77
78    fn approx_eq<M: Into<Self::Margin>>(self, other: Self, margin: M) -> bool {
79        let margin = margin.into();
80        match (self, other) {
81            (TensorData::File(l0), TensorData::File(r0)) => l0 == r0,
82            (
83                TensorData::Gate((name_l, angles_l, adjoint_l)),
84                TensorData::Gate((name_r, angles_r, adjoint_r)),
85            ) => name_l == name_r && adjoint_l == adjoint_r && angles_l.approx_eq(angles_r, margin),
86            (TensorData::Matrix(l0), TensorData::Matrix(r0)) => l0.approx_eq(r0, margin),
87            (TensorData::Uncontracted, TensorData::Uncontracted) => true,
88            _ => false,
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use float_cmp::assert_approx_eq;

    use super::*;

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_different_name() {
        let g1 = TensorData::Gate((String::from("cx"), vec![], false));
        let g2 = TensorData::Gate((String::from("CX"), vec![], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_adjoint() {
        let g1 = TensorData::Gate((String::from("h"), vec![], false));
        let g2 = TensorData::Gate((String::from("h"), vec![], true));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_different_angles() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![1.4, -2.0, -3.0], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn eq_different_data() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::new_from_data(&[], vec![Complex64::ONE], None);
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    // Positive cases: without these, an `approx_eq` that always returned `false`
    // would pass every test above.

    #[test]
    fn gates_eq_identical() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    fn gates_eq_angles_within_margin() {
        // 0.1 + 0.2 differs from 0.3 by one ULP, well inside the default margin.
        let g1 = TensorData::Gate((String::from("u"), vec![0.1 + 0.2], false));
        let g2 = TensorData::Gate((String::from("u"), vec![0.3], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    fn uncontracted_eq() {
        let u1 = TensorData::Uncontracted;
        let u2 = TensorData::Uncontracted.adjoint();
        assert_approx_eq!(&TensorData, &u1, &u2);
    }

    #[test]
    fn gate_adjoint_roundtrip() {
        let gate = TensorData::Gate((String::from("h"), vec![], false));
        let adjointed = gate.clone().adjoint();
        assert!(matches!(&adjointed, TensorData::Gate((_, _, true))));
        let restored = adjointed.adjoint();
        assert_approx_eq!(&TensorData, &gate, &restored);
    }

    #[test]
    fn file_adjoint_toggles_flag() {
        let file = TensorData::File((PathBuf::from("data.hdf5"), false));
        let adjointed = file.clone().adjoint();
        assert!(matches!(&adjointed, TensorData::File((_, true))));
        let restored = adjointed.adjoint();
        assert_approx_eq!(&TensorData, &file, &restored);
    }
}
130}