1use std::path::Path;
18
19use hdf5_metno::{File, Result};
20use ndarray::Array;
21use num_complex::Complex64;
22use tetra::Tensor as DataTensor;
23
24use crate::tensornetwork::tensor::Tensor;
25use crate::tensornetwork::tensordata::TensorData;
26
27pub fn load_tensor<P>(filename: P) -> Result<Tensor>
29where
30 P: AsRef<Path>,
31{
32 let file = File::open(filename)?;
33 read_tensor(&file)
34}
35
36pub fn load_data<P>(filename: P) -> Result<DataTensor>
38where
39 P: AsRef<Path>,
40{
41 let file = File::open(filename)?;
42 read_data(&file)
43}
44
45pub fn store_data<P>(filename: P, tensor: &DataTensor) -> Result<()>
47where
48 P: AsRef<Path>,
49{
50 let file = File::create(filename)?;
51 write_data(&file, tensor)
52}
53
54fn read_tensor(file: &File) -> Result<Tensor> {
55 let gr = file.group("/tensors")?;
56 let tensor_names = gr.member_names()?;
57
58 let out_tensor = gr.dataset("-1")?;
60 let out_tensor_bids = out_tensor.attr("bids")?;
61 let out_bond_ids = out_tensor_bids.read_1d::<usize>()?;
62
63 let mut new_tensor_network = Tensor::default();
64
65 for tensor_name in tensor_names {
66 if tensor_name == "-1" {
67 continue;
68 }
69 let tensor = gr.dataset(&tensor_name)?;
70 let bond_ids = tensor.attr("bids").unwrap().read_1d::<usize>()?;
71 let tensor_dataset = gr.dataset(&tensor_name).unwrap().read_dyn::<Complex64>()?;
72 let tensor_shape = tensor_dataset.shape().to_vec();
73 let bond_dims = tensor_shape.iter().map(|s| *s as u64).collect();
74 let mut new_tensor = Tensor::new(bond_ids.to_vec(), bond_dims);
75 let (data, offset) = tensor_dataset.into_raw_vec_and_offset();
76 assert_eq!(offset, Some(0));
77 new_tensor.set_tensor_data(TensorData::Matrix(DataTensor::new_from_flat(
78 &tensor_shape,
79 data,
80 None,
81 )));
82 new_tensor_network.push_tensor(new_tensor);
83 }
84 new_tensor_network.set_legs(out_bond_ids.to_vec());
85
86 Ok(new_tensor_network)
87}
88
89fn read_data(file: &File) -> Result<DataTensor> {
90 let gr = file.group("/tensors")?;
91 let tensor_name = gr.member_names()?;
92
93 let tensor_dataset = gr
94 .dataset(&tensor_name[0])
95 .unwrap()
96 .read_dyn::<Complex64>()?;
97 let tensor_shape = tensor_dataset.shape().to_vec();
98 let (data, offset) = tensor_dataset.into_raw_vec_and_offset();
99 assert_eq!(offset, Some(0));
100 Ok(DataTensor::new_from_flat(&tensor_shape, data, None))
101}
102
103fn write_data(file: &File, tensor: &DataTensor) -> Result<()> {
104 let gr = file.create_group("/tensors")?;
105 let data = tensor.elements().into_owned();
106 let shape = tensor.shape();
107 let data = Array::from_shape_vec(shape, data)?;
108 let tensor_dataset = gr.new_dataset_builder().with_data(&data);
109 tensor_dataset.create("-1")?;
110 file.flush()
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use float_cmp::assert_approx_eq;
    use hdf5_metno::{AttributeBuilder, File, Result};
    use ndarray::array;
    use num_complex::Complex64;
    use rand::{
        distr::{Alphanumeric, SampleString},
        rng,
    };
    use tetra::Tensor as DataTensor;

    use crate::tensornetwork::tensor::Tensor;
    use crate::tensornetwork::tensordata::TensorData;

    /// Row-major elements of the 2x2 complex matrix shared by the fixtures
    /// and reference values below.
    fn sample_values() -> Vec<Complex64> {
        vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(0.0, 1.0),
        ]
    }

    /// Creates an HDF5 file held in memory.
    ///
    /// The file uses the core driver with `core_filebacked(false)` and a
    /// random 8-character name.
    fn new_in_memory_file() -> Result<File> {
        let random_filename = Alphanumeric.sample_string(&mut rng(), 8);
        File::with_options()
            .with_access_plist(|p| p.core_filebacked(false))
            .create(random_filename)
    }

    /// Builds a file in the layout `read_tensor` expects.
    ///
    /// It contains an empty `-1` dataset whose `bids` attribute is `[0, 1]`,
    /// and one 2x2 tensor `0` with bond ids `[0, 1]`.
    fn create_hdf5_tensor() -> Result<File> {
        let new_file = new_in_memory_file()?;
        let tensor_group = new_file.create_group("/tensors")?;

        let out_dataset = tensor_group
            .new_dataset_builder()
            .empty::<Complex64>()
            .create("-1")?;
        let out_bids = array![0, 1];
        AttributeBuilder::new(&out_dataset)
            .with_data(&out_bids)
            .create("bids")?;

        let data = Array::from_shape_vec((2, 2), sample_values())?;
        let tensor_dataset = tensor_group
            .new_dataset_builder()
            .with_data(&data)
            .create("0")?;
        let tensor_bids = array![0, 1];
        AttributeBuilder::new(&tensor_dataset)
            .with_data(&tensor_bids)
            .create("bids")?;

        new_file.flush()?;
        Ok(new_file)
    }

    /// Builds a file in the layout `read_data` expects: a single 2x2
    /// complex dataset `-1` under `/tensors`.
    fn create_hdf5_data() -> Result<File> {
        let new_file = new_in_memory_file()?;
        let tensor_group = new_file.create_group("/tensors")?;
        let data = Array::from_shape_vec((2, 2), sample_values())?;
        tensor_group
            .new_dataset_builder()
            .with_data(&data)
            .create("-1")?;
        new_file.flush()?;
        Ok(new_file)
    }

    #[test]
    fn test_load_data() {
        let file = create_hdf5_data().unwrap();
        let tensor_data = read_data(&file).unwrap();

        let ref_data = sample_values();
        let elements = tensor_data.elements();
        // `zip` stops at the shorter input, so compare lengths explicitly to
        // catch missing or extra elements.
        assert_eq!(elements.len(), ref_data.len());
        for (u, v) in zip(ref_data.iter(), elements.iter()) {
            assert_approx_eq!(f64, u.re, v.re, epsilon = 1e-8);
            assert_approx_eq!(f64, u.im, v.im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_load_tensor() {
        let file = create_hdf5_tensor().unwrap();
        let tensor = read_tensor(&file).unwrap();

        let mut ref_tn = Tensor::default();
        let mut ref_tensor = Tensor::new(vec![0, 1], vec![2, 2]);
        ref_tensor.set_tensor_data(TensorData::new_from_data(&[2, 2], sample_values(), None));
        ref_tn.push_tensor(ref_tensor);
        ref_tn.set_legs(vec![0, 1]);
        assert_approx_eq!(&Tensor, &tensor, &ref_tn);
    }

    #[test]
    fn test_write_read() {
        let file = new_in_memory_file().unwrap();
        let data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, -2.0),
            Complex64::new(-3.0, 0.0),
            Complex64::new(-2.0, -1.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.5, 2.0),
        ];
        let tensor = DataTensor::new_from_flat(&[2, 3], data, None);

        write_data(&file, &tensor).unwrap();
        let read = read_data(&file).unwrap();

        assert_approx_eq!(&DataTensor, &tensor, &read);
    }
}