Skip to main content

tnc/io/
hdf5.rs

1//! Import and export of tensors or tensor networks as HDF5 files.
2//!
3//! The files follow this structure:
4//! ```text
5//! tensors/
6//!     tensor: n-dimensional dataset
7//!         attrs:
8//!             - bids
9//!             - tids
10//! ```
//! There is a single `tensors/` group containing multiple tensor datasets. Each
//! `tensor` is a flattened tensor with dimensions `shape`. The `tid` is the unique
//! positive integer used to identify each tensor, with the output tensor, identified
//! by `-1`, containing output bond dimensions and no tensor data. The `bids` are a
//! list of integers corresponding to the bond ids in each tensor.
16
17use std::path::Path;
18
19use hdf5_metno::{File, Result};
20use num_complex::Complex64;
21
22use crate::tensornetwork::tensor::Tensor;
23use crate::tensornetwork::tensordata::{DataTensor, TensorData};
24
25/// Loads a tensor network from a HDF5 file.
26pub fn load_tensor<P>(filename: P) -> Result<Tensor>
27where
28    P: AsRef<Path>,
29{
30    let file = File::open(filename)?;
31    read_tensor(&file)
32}
33
34/// Loads a single tensor from a HDF5 file.
35pub fn load_data<P>(filename: P) -> Result<DataTensor>
36where
37    P: AsRef<Path>,
38{
39    let file = File::open(filename)?;
40    read_data(&file)
41}
42
43/// Stores a single tensor in a HDF5 file.
44pub fn store_data<P>(filename: P, tensor: &DataTensor) -> Result<()>
45where
46    P: AsRef<Path>,
47{
48    let file = File::create(filename)?;
49    write_data(&file, tensor)
50}
51
52fn read_tensor(file: &File) -> Result<Tensor> {
53    let gr = file.group("/tensors")?;
54    let tensor_names = gr.member_names()?;
55
56    // Outuput tensor is always labelled as -1
57    let out_tensor = gr.dataset("-1")?;
58    let out_tensor_bids = out_tensor.attr("bids")?;
59    let out_bond_ids = out_tensor_bids.read_1d::<usize>()?;
60
61    let mut new_tensor_network = Tensor::default();
62
63    for tensor_name in tensor_names {
64        if tensor_name == "-1" {
65            continue;
66        }
67        let tensor = gr.dataset(&tensor_name)?;
68        let bond_ids = tensor.attr("bids").unwrap().read_1d::<usize>()?;
69        let tensor_dataset = gr.dataset(&tensor_name).unwrap().read_dyn::<Complex64>()?;
70        let tensor_shape = tensor_dataset.shape();
71        let bond_dims = tensor_shape.iter().map(|s| *s as u64).collect();
72        let mut new_tensor = Tensor::new(bond_ids.to_vec(), bond_dims);
73        new_tensor.set_tensor_data(TensorData::Matrix(tensor_dataset));
74        new_tensor_network.push_tensor(new_tensor);
75    }
76    new_tensor_network.set_legs(out_bond_ids.to_vec());
77
78    Ok(new_tensor_network)
79}
80
81fn read_data(file: &File) -> Result<DataTensor> {
82    let gr = file.group("/tensors")?;
83    let tensor_name = gr.member_names()?;
84
85    let tensor_dataset = gr
86        .dataset(&tensor_name[0])
87        .unwrap()
88        .read_dyn::<Complex64>()?;
89    Ok(tensor_dataset)
90}
91
92fn write_data(file: &File, tensor: &DataTensor) -> Result<()> {
93    let gr = file.create_group("/tensors")?;
94    let tensor_dataset = gr.new_dataset_builder().with_data(tensor);
95    tensor_dataset.create("-1")?;
96    file.flush()
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::zip;

    use approx::assert_abs_diff_eq;
    use hdf5_metno::{AttributeBuilder, File, Result};
    use ndarray::{array, Array2};
    use num_complex::Complex64;
    use rand::{
        distr::{Alphanumeric, SampleString},
        rng,
    };

    use crate::tensornetwork::tensor::Tensor;
    use crate::tensornetwork::tensordata::TensorData;

    /// Entries of the 2x2 sample tensor shared by the fixtures and the
    /// reference values of the tests.
    fn sample_values() -> Vec<Complex64> {
        vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 2.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(0.0, 1.0),
        ]
    }

    /// Creates a new HDF5 file in memory under a random 8-character name.
    /// This method is taken from the hdf5 crate integration tests:
    /// <https://github.com/aldanor/hdf5-rust/blob/694e900972fbf5ffbdd1a2294f57a2cc3a91c994/hdf5/tests/common/util.rs#L7>.
    fn new_in_memory_file() -> Result<File> {
        let name = Alphanumeric.sample_string(&mut rng(), 8);
        File::with_options()
            .with_access_plist(|plist| plist.core_filebacked(false))
            .create(name)
    }

    /// Builds a file holding an empty output dataset `-1` and one 2x2 tensor
    /// dataset `0`, each with a `bids` attribute of `[0, 1]`.
    fn create_hdf5_tensor() -> Result<File> {
        let file = new_in_memory_file()?;
        let group = file.create_group("./tensors")?;

        // Output tensor: an empty dataset that only carries bond ids.
        let output = group
            .new_dataset_builder()
            .empty::<Complex64>()
            .create("-1")?;
        let output_bids = array![0, 1];
        AttributeBuilder::new(&output)
            .with_data(&output_bids)
            .create("bids")?;

        // Data tensor: the 2x2 sample matrix plus its bond ids.
        let matrix = Array2::<Complex64>::from_shape_vec((2, 2), sample_values())?.into_dyn();
        let tensor = group
            .new_dataset_builder()
            .with_data(&matrix)
            .create("0")?;
        let tensor_bids = array![0, 1];
        AttributeBuilder::new(&tensor)
            .with_data(&tensor_bids)
            .create("bids")?;

        file.flush()?;
        Ok(file)
    }

    /// Builds a file whose `tensors` group holds only the 2x2 sample matrix,
    /// stored as the dataset `-1`.
    fn create_hdf5_data() -> Result<File> {
        let file = new_in_memory_file()?;
        let group = file.create_group("./tensors")?;
        let matrix = Array2::<Complex64>::from_shape_vec((2, 2), sample_values())?.into_dyn();
        group
            .new_dataset_builder()
            .with_data(&matrix)
            .create("-1")?;
        file.flush()?;
        Ok(file)
    }

    #[test]
    fn test_load_data() {
        let file = create_hdf5_data().unwrap();
        let loaded = read_data(&file).unwrap();

        let expected = sample_values();
        let actual = loaded.flatten();
        for (want, got) in zip(expected.iter(), actual.iter()) {
            assert_abs_diff_eq!(want.re, got.re, epsilon = 1e-8);
            assert_abs_diff_eq!(want.im, got.im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_load_tensor() {
        let file = create_hdf5_tensor().unwrap();
        let loaded = read_tensor(&file).unwrap();

        let mut expected_child = Tensor::new(vec![0, 1], vec![2, 2]);
        expected_child.set_tensor_data(TensorData::new_from_data(&[2, 2], sample_values()));

        let mut expected = Tensor::default();
        expected.push_tensor(expected_child);
        expected.set_legs(vec![0, 1]);

        assert_abs_diff_eq!(&loaded, &expected);
    }

    #[test]
    fn test_write_read() {
        let file = new_in_memory_file().unwrap();
        let values = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, -2.0),
            Complex64::new(-3.0, 0.0),
            Complex64::new(-2.0, -1.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.5, 2.0),
        ];
        let tensor = Array2::<Complex64>::from_shape_vec((2, 3), values)
            .unwrap()
            .into_dyn();

        write_data(&file, &tensor).unwrap();
        let read_back = read_data(&file).unwrap();

        assert_abs_diff_eq!(&tensor, &read_back);
    }
}