1use std::path::PathBuf;
2
3use float_cmp::{ApproxEq, F64Margin};
4use num_complex::Complex64;
5use serde::{Deserialize, Serialize};
6use tetra::{Layout, Tensor as DataTensor};
7
8use crate::{
9 gates::{load_gate, load_gate_adjoint, matrix_adjoint_inplace},
10 hdf5::load_data,
11};
12
13#[derive(Default, Debug, Clone, Serialize, Deserialize)]
15pub enum TensorData {
16 #[default]
19 Uncontracted,
20 File((PathBuf, bool)),
22 Gate((String, Vec<f64>, bool)),
24 Matrix(DataTensor),
26}
27
28impl TensorData {
29 #[must_use]
31 pub fn new_from_data(
32 dimensions: &[usize],
33 data: Vec<Complex64>,
34 layout: Option<Layout>,
35 ) -> Self {
36 Self::Matrix(DataTensor::new_from_flat(dimensions, data, layout))
37 }
38
39 pub fn into_data(self) -> DataTensor {
41 match self {
42 TensorData::Uncontracted => panic!("Cannot convert uncontracted tensor to data"),
43 TensorData::File((filename, adjoint)) => {
44 let mut data = load_data(filename).unwrap();
45 if adjoint {
46 matrix_adjoint_inplace(&mut data);
47 }
48 data
49 }
50 TensorData::Gate((gatename, angles, adjoint)) => {
51 if adjoint {
52 load_gate_adjoint(&gatename, &angles)
53 } else {
54 load_gate(&gatename, &angles)
55 }
56 }
57 TensorData::Matrix(tensor) => tensor,
58 }
59 }
60
61 pub fn adjoint(self) -> Self {
63 match self {
64 TensorData::Uncontracted => TensorData::Uncontracted,
65 TensorData::File((filename, adjoint)) => TensorData::File((filename, !adjoint)),
66 TensorData::Gate((name, params, adjoint)) => TensorData::Gate((name, params, !adjoint)),
67 TensorData::Matrix(mut tensor) => {
68 matrix_adjoint_inplace(&mut tensor);
69 TensorData::Matrix(tensor)
70 }
71 }
72 }
73}
74
75impl ApproxEq for &TensorData {
76 type Margin = F64Margin;
77
78 fn approx_eq<M: Into<Self::Margin>>(self, other: Self, margin: M) -> bool {
79 let margin = margin.into();
80 match (self, other) {
81 (TensorData::File(l0), TensorData::File(r0)) => l0 == r0,
82 (
83 TensorData::Gate((name_l, angles_l, adjoint_l)),
84 TensorData::Gate((name_r, angles_r, adjoint_r)),
85 ) => name_l == name_r && adjoint_l == adjoint_r && angles_l.approx_eq(angles_r, margin),
86 (TensorData::Matrix(l0), TensorData::Matrix(r0)) => l0.approx_eq(r0, margin),
87 (TensorData::Uncontracted, TensorData::Uncontracted) => true,
88 _ => false,
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    use float_cmp::assert_approx_eq;

    use super::*;

    #[test]
    fn uncontracted_eq() {
        assert_approx_eq!(
            &TensorData,
            &TensorData::Uncontracted,
            &TensorData::Uncontracted
        );
    }

    #[test]
    fn files_eq() {
        let f1 = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        let f2 = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        assert_approx_eq!(&TensorData, &f1, &f2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn files_eq_adjoint() {
        let f1 = TensorData::File((PathBuf::from("tensor.hdf5"), false));
        let f2 = TensorData::File((PathBuf::from("tensor.hdf5"), true));
        assert_approx_eq!(&TensorData, &f1, &f2);
    }

    #[test]
    fn gates_eq_within_margin() {
        // 0.1 + 0.2 is not bit-identical to 0.3, but lies within the default margin,
        // so this checks that gate angles are compared approximately, not exactly.
        let g1 = TensorData::Gate((String::from("u"), vec![0.1 + 0.2, 2.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![0.3, 2.0], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    fn gate_adjoint_twice_is_identity() {
        let gate = TensorData::Gate((String::from("h"), vec![], false));
        let round_trip = gate.clone().adjoint().adjoint();
        assert_approx_eq!(&TensorData, &gate, &round_trip);
    }

    #[test]
    fn matrices_eq() {
        let m1 = TensorData::new_from_data(&[], vec![Complex64::ONE], None);
        let m2 = TensorData::new_from_data(&[], vec![Complex64::ONE], None);
        assert_approx_eq!(&TensorData, &m1, &m2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_different_name() {
        let g1 = TensorData::Gate((String::from("cx"), vec![], false));
        let g2 = TensorData::Gate((String::from("CX"), vec![], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_adjoint() {
        let g1 = TensorData::Gate((String::from("h"), vec![], false));
        let g2 = TensorData::Gate((String::from("h"), vec![], true));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn gates_eq_different_angles() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::Gate((String::from("u"), vec![1.4, -2.0, -3.0], false));
        assert_approx_eq!(&TensorData, &g1, &g2);
    }

    #[test]
    #[should_panic(expected = "assertion failed: `(left approx_eq right)`")]
    fn eq_different_data() {
        let g1 = TensorData::Gate((String::from("u"), vec![1.4, 2.0, -3.0], false));
        let g2 = TensorData::new_from_data(&[], vec![Complex64::ONE], None);
        assert_approx_eq!(&TensorData, &g1, &g2);
    }
}