1use std::path::Path;
18
19use hdf5_metno::{File, Result};
20use num_complex::Complex64;
21
22use crate::tensornetwork::tensor::Tensor;
23use crate::tensornetwork::tensordata::{DataTensor, TensorData};
24
25pub fn load_tensor<P>(filename: P) -> Result<Tensor>
27where
28 P: AsRef<Path>,
29{
30 let file = File::open(filename)?;
31 read_tensor(&file)
32}
33
34pub fn load_data<P>(filename: P) -> Result<DataTensor>
36where
37 P: AsRef<Path>,
38{
39 let file = File::open(filename)?;
40 read_data(&file)
41}
42
43pub fn store_data<P>(filename: P, tensor: &DataTensor) -> Result<()>
45where
46 P: AsRef<Path>,
47{
48 let file = File::create(filename)?;
49 write_data(&file, tensor)
50}
51
52fn read_tensor(file: &File) -> Result<Tensor> {
53 let gr = file.group("/tensors")?;
54 let tensor_names = gr.member_names()?;
55
56 let out_tensor = gr.dataset("-1")?;
58 let out_tensor_bids = out_tensor.attr("bids")?;
59 let out_bond_ids = out_tensor_bids.read_1d::<usize>()?;
60
61 let mut new_tensor_network = Tensor::default();
62
63 for tensor_name in tensor_names {
64 if tensor_name == "-1" {
65 continue;
66 }
67 let tensor = gr.dataset(&tensor_name)?;
68 let bond_ids = tensor.attr("bids").unwrap().read_1d::<usize>()?;
69 let tensor_dataset = gr.dataset(&tensor_name).unwrap().read_dyn::<Complex64>()?;
70 let tensor_shape = tensor_dataset.shape();
71 let bond_dims = tensor_shape.iter().map(|s| *s as u64).collect();
72 let mut new_tensor = Tensor::new(bond_ids.to_vec(), bond_dims);
73 new_tensor.set_tensor_data(TensorData::Matrix(tensor_dataset));
74 new_tensor_network.push_tensor(new_tensor);
75 }
76 new_tensor_network.set_legs(out_bond_ids.to_vec());
77
78 Ok(new_tensor_network)
79}
80
81fn read_data(file: &File) -> Result<DataTensor> {
82 let gr = file.group("/tensors")?;
83 let tensor_name = gr.member_names()?;
84
85 let tensor_dataset = gr
86 .dataset(&tensor_name[0])
87 .unwrap()
88 .read_dyn::<Complex64>()?;
89 Ok(tensor_dataset)
90}
91
92fn write_data(file: &File, tensor: &DataTensor) -> Result<()> {
93 let gr = file.create_group("/tensors")?;
94 let tensor_dataset = gr.new_dataset_builder().with_data(tensor);
95 tensor_dataset.create("-1")?;
96 file.flush()
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use approx::assert_abs_diff_eq;
    use hdf5_metno::{AttributeBuilder, File, Result};
    use ndarray::{array, Array2};
    use num_complex::Complex64;
    use rand::{
        distr::{Alphanumeric, SampleString},
        rng,
    };

    use crate::tensornetwork::tensor::Tensor;
    use crate::tensornetwork::tensordata::TensorData;

    /// Creates an HDF5 file with a random 8-character name.
    ///
    /// The file uses the core (in-memory) driver with file backing disabled,
    /// so the tests do not leave files on disk.
    fn new_in_memory_file() -> Result<File> {
        let random_filename = Alphanumeric.sample_string(&mut rng(), 8);
        File::with_options()
            .with_access_plist(|p| p.core_filebacked(false))
            .create(random_filename)
    }

    /// Builds a tensor-network file with a single 2x2 child tensor.
    ///
    /// The `tensors` group contains two datasets:
    /// - `"-1"`: empty, with `bids = [0, 1]`, giving the output legs;
    /// - `"0"`: a 2x2 child tensor with `bids = [0, 1]`.
    fn create_hdf5_tensor() -> Result<File> {
        let new_file = new_in_memory_file()?;
        let tensor_group = new_file.create_group("./tensors")?;
        let dataset_builder = tensor_group.new_dataset_builder();
        let dataset = dataset_builder.empty::<Complex64>().create("-1")?;
        let attribute = AttributeBuilder::new(&dataset);
        let bid = array![0, 1];
        let attribute = attribute.with_data(&bid);
        attribute.create("bids")?;

        let data = Array2::<Complex64>::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 2.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 1.0),
            ],
        )?
        .into_dyn();
        let dataset_builder2 = tensor_group.new_dataset_builder();
        let dataset_data_builder2 = dataset_builder2.with_data(&data);
        let dataset2 = dataset_data_builder2.create("0")?;
        let attribute2 = AttributeBuilder::new(&dataset2);
        let bid2 = array![0, 1];
        let attribute2 = attribute2.with_data(&bid2);
        attribute2.create("bids")?;

        new_file.flush()?;
        Ok(new_file)
    }

    /// Builds a data file whose `tensors` group holds one 2x2 dataset `"-1"`.
    fn create_hdf5_data() -> Result<File> {
        let new_file = new_in_memory_file()?;
        let tensor_group = new_file.create_group("./tensors")?;
        let dataset_builder = tensor_group.new_dataset_builder();
        let data = Array2::<Complex64>::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 2.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 1.0),
            ],
        )?
        .into_dyn();
        let dataset_data_builder = dataset_builder.with_data(&data);
        dataset_data_builder.create("-1")?;
        new_file.flush()?;
        Ok(new_file)
    }

    #[test]
    fn test_load_data() {
        let file = create_hdf5_data().unwrap();
        let tensor_data = read_data(&file).unwrap();

        let ref_data = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        // `zip` stops at the shorter iterator, so check the element count
        // first; otherwise a truncated read would pass unnoticed.
        assert_eq!(tensor_data.len(), ref_data.len());
        for (u, v) in zip(ref_data.iter(), tensor_data.flatten().iter()) {
            assert_abs_diff_eq!(u.re, v.re, epsilon = 1e-8);
            assert_abs_diff_eq!(u.im, v.im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_load_tensor() {
        let file = create_hdf5_tensor().unwrap();
        let tensor = read_tensor(&file).unwrap();

        let mut ref_tn = Tensor::default();
        let mut ref_tensor = Tensor::new(vec![0, 1], vec![2, 2]);
        ref_tensor.set_tensor_data(TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 2.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 1.0),
            ],
        ));
        ref_tn.push_tensor(ref_tensor);
        ref_tn.set_legs(vec![0, 1]);
        assert_abs_diff_eq!(&tensor, &ref_tn);
    }

    #[test]
    fn test_write_read() {
        let file = new_in_memory_file().unwrap();
        let data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, -2.0),
            Complex64::new(-3.0, 0.0),
            Complex64::new(-2.0, -1.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.5, 2.0),
        ];
        let tensor = Array2::<Complex64>::from_shape_vec((2, 3), data)
            .unwrap()
            .into_dyn();

        write_data(&file, &tensor).unwrap();
        let read = read_data(&file).unwrap();

        assert_abs_diff_eq!(&tensor, &read);
    }
}