tnc/tensornetwork/
tensordata.rs1use std::path::PathBuf;
2
3use approx::AbsDiffEq;
4use ndarray::ArrayD;
5use num_complex::Complex64;
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 gates::{load_gate, load_gate_adjoint, matrix_adjoint_inplace},
10 io::hdf5::load_data,
11};
12
13pub type DataTensor = ArrayD<Complex64>;
14
15#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub enum TensorData {
18 #[default]
21 Uncontracted,
22 File((PathBuf, bool)),
24 Gate((String, Vec<f64>, bool)),
26 Matrix(DataTensor),
28}
29
30impl TensorData {
31 #[must_use]
33 pub fn new_from_data(dimensions: &[usize], data: Vec<Complex64>) -> Self {
34 Self::Matrix(ArrayD::from_shape_vec(dimensions, data).unwrap())
35 }
36
37 pub fn into_data(self) -> DataTensor {
39 match self {
40 TensorData::Uncontracted => panic!("Cannot convert uncontracted tensor to data"),
41 TensorData::File((filename, adjoint)) => {
42 let mut data = load_data(filename).unwrap();
43 if adjoint {
44 matrix_adjoint_inplace(&mut data);
45 }
46 data
47 }
48 TensorData::Gate((gatename, angles, adjoint)) => {
49 if adjoint {
50 load_gate_adjoint(&gatename, &angles)
51 } else {
52 load_gate(&gatename, &angles)
53 }
54 }
55 TensorData::Matrix(tensor) => tensor,
56 }
57 }
58
59 pub fn adjoint(self) -> Self {
61 match self {
62 TensorData::Uncontracted => TensorData::Uncontracted,
63 TensorData::File((filename, adjoint)) => TensorData::File((filename, !adjoint)),
64 TensorData::Gate((name, params, adjoint)) => TensorData::Gate((name, params, !adjoint)),
65 TensorData::Matrix(mut tensor) => {
66 matrix_adjoint_inplace(&mut tensor);
67 TensorData::Matrix(tensor)
68 }
69 }
70 }
71}
72
73impl AbsDiffEq for TensorData {
74 type Epsilon = f64;
75
76 fn default_epsilon() -> Self::Epsilon {
77 f64::EPSILON
78 }
79
80 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
81 match (self, other) {
82 (TensorData::File(l0), TensorData::File(r0)) => l0 == r0,
83 (
84 TensorData::Gate((name_l, angles_l, adjoint_l)),
85 TensorData::Gate((name_r, angles_r, adjoint_r)),
86 ) => {
87 name_l == name_r
88 && adjoint_l == adjoint_r
89 && angles_l
90 .iter()
91 .zip(angles_r)
92 .all(|(l, r)| f64::abs_diff_eq(l, r, epsilon))
93 }
94 (TensorData::Matrix(l0), TensorData::Matrix(r0)) => {
95 DataTensor::abs_diff_eq(l0, r0, epsilon)
96 }
97 (TensorData::Uncontracted, TensorData::Uncontracted) => true,
98 _ => false,
99 }
100 }
101}
102
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    // --- Inequality: every case below must make the assertion fail. ---

    #[test]
    #[should_panic(expected = "assert_abs_diff_eq!")]
    fn gates_eq_different_name() {
        let g1 = TensorData::Gate((String::from("cx"), vec![], false));
        let g2 = TensorData::Gate((String::from("CX"), vec![], false));
        assert_abs_diff_eq!(&g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assert_abs_diff_eq!")]
    fn gates_eq_adjoint() {
        let g1 = TensorData::Gate((String::from("h"), vec![], false));
        let g2 = TensorData::Gate((String::from("h"), vec![], true));
        assert_abs_diff_eq!(&g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assert_abs_diff_eq!")]
    fn gates_eq_different_angles() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![1.4, -2.0, -3.0], false));
        assert_abs_diff_eq!(&g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assert_abs_diff_eq!")]
    fn eq_different_data() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::new_from_data(&[], vec![Complex64::ONE]);
        assert_abs_diff_eq!(&g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assert_abs_diff_eq!")]
    fn matrices_eq_different_shape() {
        let m1 = TensorData::new_from_data(&[1], vec![Complex64::ONE]);
        let m2 = TensorData::new_from_data(&[], vec![Complex64::ONE]);
        assert_abs_diff_eq!(&m1, &m2);
    }

    // --- Equality: guards against an implementation that always says "no". ---

    #[test]
    fn uncontracted_eq() {
        assert_abs_diff_eq!(&TensorData::Uncontracted, &TensorData::Uncontracted);
    }

    #[test]
    fn gates_eq_within_epsilon() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![1.4, 2.0 + 1e-12, -3.0], false));
        assert_abs_diff_eq!(&g1, &g2, epsilon = 1e-10);
    }

    #[test]
    fn matrices_eq_within_epsilon() {
        let m1 = TensorData::new_from_data(&[2], vec![Complex64::ONE, Complex64::new(0.0, 1.0)]);
        let m2 = TensorData::new_from_data(
            &[2],
            vec![Complex64::new(1.0 + 1e-12, 0.0), Complex64::new(0.0, 1.0)],
        );
        assert_abs_diff_eq!(&m1, &m2, epsilon = 1e-10);
    }

    #[test]
    fn files_eq() {
        let f1 = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        let f2 = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        assert_abs_diff_eq!(&f1, &f2);
    }

    // --- Lazy adjoint bookkeeping. ---

    #[test]
    fn gate_adjoint_twice_is_identity() {
        let g = TensorData::Gate((String::from("rx"), vec![0.5], false));
        assert_eq!(g.clone().adjoint().adjoint(), g);
    }

    #[test]
    fn file_adjoint_toggles_flag() {
        let f = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        assert_eq!(
            f.adjoint(),
            TensorData::File((PathBuf::from("tensor.hdf5"), true))
        );
    }

    #[test]
    fn uncontracted_adjoint_is_uncontracted() {
        assert_eq!(TensorData::Uncontracted.adjoint(), TensorData::Uncontracted);
    }

    // --- Panicking paths. ---

    #[test]
    #[should_panic(expected = "Cannot convert uncontracted tensor to data")]
    fn into_data_uncontracted_panics() {
        let _ = TensorData::Uncontracted.into_data();
    }

    #[test]
    #[should_panic]
    fn new_from_data_shape_mismatch_panics() {
        let _ = TensorData::new_from_data(&[2, 2], vec![Complex64::ONE]);
    }
}