Skip to main content

tnc/io/qasm/
qasm_importer.rs

1use crate::builders::circuit_builder::Circuit;
2use crate::io::qasm::{
3    ast::Visitor, circuit_creator::CircuitCreator, expression_folder::ExpressionFolder,
4    gate_inliner::GateInliner, include_resolver::expand_includes, parser::parse,
5};
6
7/// Creates a [`Circuit`] from OpenQASM2 code.
8///
9/// All gates are inlined up to the known gates defined in [`crate::gates`]. All
10/// qubits are initialized to zero. Note that not all QASM instructions are
11/// supported, such as `measure` or `if`.
12pub fn import_qasm<S>(code: S) -> Circuit
13where
14    S: Into<String>,
15{
16    // Expand all includes
17    let mut full_code = code.into();
18    expand_includes(&mut full_code);
19
20    // Parse to AST
21    let mut program = parse(&full_code);
22
23    // Simplify expressions (not strictly needed)
24    let mut expression_folder = ExpressionFolder;
25    expression_folder.visit_program(&mut program);
26
27    // Inline gate calls
28    let mut inliner = GateInliner::default();
29    inliner.inline_program(&mut program);
30
31    // Simplify expressions after inline (needed)
32    let mut expression_folder = ExpressionFolder;
33    expression_folder.visit_program(&mut program);
34
35    // Create the circuit
36    let circuit_creator = CircuitCreator;
37    circuit_creator.create_circuit(&program)
38}
39
#[cfg(test)]
mod tests {
    use super::*;

    use std::f64::consts::FRAC_1_SQRT_2;

    use approx::assert_abs_diff_eq;
    use num_complex::Complex64;

    use crate::{
        builders::circuit_builder::Permutor,
        contractionpath::ContractionPath,
        tensornetwork::{
            contraction::contract_tensor_network,
            tensor::{EdgeIndex, Tensor, TensorIndex},
            tensordata::TensorData,
        },
    };

    /// Returns whether `edge_id` is a leg shared by the tensors `t1_id` and
    /// `t2_id` of the network `tn`.
    fn edge_connects(
        edge_id: EdgeIndex,
        t1_id: TensorIndex,
        t2_id: TensorIndex,
        tn: &Tensor,
    ) -> bool {
        let overlap = tn.tensor(t1_id) & tn.tensor(t2_id);
        overlap.legs().contains(&edge_id)
    }

    /// Returns whether `edge_id` is an open edge of tensor `t1_id`: it is one of
    /// that tensor's legs and no other tensor of `tn` has it as a leg.
    fn is_open_edge_of(edge_id: EdgeIndex, t1_id: TensorIndex, tn: &Tensor) -> bool {
        tn.tensor(t1_id).legs().contains(&edge_id)
            && tn
                .tensors()
                .iter()
                .enumerate()
                .all(|(tensor_id, tensor)| tensor_id == t1_id || !tensor.legs().contains(&edge_id))
    }

    /// A tensor of the network paired with its index in `tn.tensors()`.
    struct IdTensor<'a> {
        id: TensorIndex,
        tensor: &'a Tensor,
    }

    /// Splits the tensors of `tn` by leg count into kets (1 leg), single-qubit
    /// gates (2 legs) and two-qubit gates (4 legs), returned in that order.
    ///
    /// # Panics
    /// Panics if any tensor has a different number of legs.
    fn get_quantum_tensors(
        tn: &Tensor,
    ) -> (Vec<IdTensor<'_>>, Vec<IdTensor<'_>>, Vec<IdTensor<'_>>) {
        let mut kets = Vec::new();
        let mut single_qubit_gates = Vec::new();
        let mut two_qubit_gates = Vec::new();
        for (id, tensor) in tn.tensors().iter().enumerate() {
            let legs = tensor.legs().len();
            match legs {
                1 => kets.push(IdTensor { id, tensor }),
                2 => single_qubit_gates.push(IdTensor { id, tensor }),
                4 => two_qubit_gates.push(IdTensor { id, tensor }),
                _ => panic!("Tensor with unexpected leg count {legs} in quantum tensor network"),
            }
        }
        (kets, single_qubit_gates, two_qubit_gates)
    }

    #[test]
    fn bell_tensornetwork_construction() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        h q[0];
        cx q[0], q[1];
        ";
        let circuit = import_qasm(code);
        let (tn, _) = circuit.into_statevector_network();

        let (kets, single_qubit_gates, two_qubit_gates) = get_quantum_tensors(&tn);
        let [k0, k1] = kets.as_slice() else {
            panic!("expected exactly 2 ket tensors, found {}", kets.len())
        };
        let [h] = single_qubit_gates.as_slice() else {
            panic!(
                "expected exactly 1 single-qubit gate tensor, found {}",
                single_qubit_gates.len()
            )
        };
        let [cx] = two_qubit_gates.as_slice() else {
            panic!(
                "expected exactly 1 two-qubit gate tensor, found {}",
                two_qubit_gates.len()
            )
        };

        // Find out which ket is the first/top qubit (the one whose leg feeds the
        // H gate's input leg) and which is the second/bottom qubit. Compare edge
        // ids with edge ids: a ket's tensor id is not an edge id.
        let h_input_edge = h.tensor.legs()[1];
        let (first_qubit, second_qubit) = if k0.tensor.legs()[0] == h_input_edge {
            (k0, k1)
        } else if k1.tensor.legs()[0] == h_input_edge {
            (k1, k0)
        } else {
            panic!("H gate tensor not connected to any ket tensor");
        };

        // Check edges
        let fq_to_h_id = first_qubit.tensor.legs()[0];
        assert_eq!(h.tensor.legs()[1], fq_to_h_id);
        assert!(edge_connects(fq_to_h_id, first_qubit.id, h.id, &tn));

        let sq_to_cx_t_id = second_qubit.tensor.legs()[0];
        assert_eq!(cx.tensor.legs()[3], sq_to_cx_t_id);
        assert!(edge_connects(sq_to_cx_t_id, second_qubit.id, cx.id, &tn));

        let h_to_cx_c_id = h.tensor.legs()[0];
        assert_eq!(cx.tensor.legs()[2], h_to_cx_c_id);
        assert!(edge_connects(h_to_cx_c_id, h.id, cx.id, &tn));

        let cx_c_to_open_id = cx.tensor.legs()[0];
        assert!(is_open_edge_of(cx_c_to_open_id, cx.id, &tn));

        let cx_t_to_open_id = cx.tensor.legs()[1];
        assert!(is_open_edge_of(cx_t_to_open_id, cx.id, &tn));
    }

    /// Contracts the tensor network with an arbitrary contraction order (every
    /// tensor is contracted into tensor 0), then returns the tensor data
    /// permuted by `perm`.
    fn contract_tn(tn: Tensor, perm: &Permutor) -> TensorData {
        let opt_path =
            ContractionPath::simple((1..tn.tensors().len()).map(|tid| (0, tid)).collect());
        let tn = contract_tensor_network(tn, &opt_path);
        let mut tn = perm.apply(tn);
        std::mem::take(&mut tn.tensordata)
    }

    #[test]
    fn bell_contract() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        h q[0];
        cx q[0], q[1];
        ";
        let circuit = import_qasm(code);
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::new(FRAC_1_SQRT_2, 0.),
                Complex64::ZERO,
                Complex64::ZERO,
                Complex64::new(FRAC_1_SQRT_2, 0.),
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }

    #[test]
    fn custom_swap() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        gate myswap a, b {
            cx a, b;
            cx b, a;
            cx a, b;
        }
        x q[0];
        myswap q[1], q[0];
        ";
        let circuit = import_qasm(code);
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::ZERO,
                Complex64::ONE,
                Complex64::ZERO,
                Complex64::ZERO,
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }

    /// Three-qubit circuit with distinct rotation angles on each qubit, used to
    /// check that statevector / amplitude ordering is correct.
    fn odd_test_circuit() -> Circuit {
        // Test with odd numbers to check the order of the statevector is correct
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[3];
        rx(0.5) q[0];
        rx(0.2) q[1];
        rx(0.3) q[2];
        cx q[0], q[1];
        cx q[1], q[2];";
        import_qasm(code)
    }

    #[test]
    fn statevector_order() {
        let circuit = odd_test_circuit();
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2, 2],
            vec![
                Complex64::new(0.953246407214305, 0.0),
                Complex64::new(0.0, -0.14406910361762032),
                Complex64::new(-0.014455126269118733, 0.0),
                Complex64::new(0.0, -0.09564366568448116),
                Complex64::new(-0.024421837348497916, 0.0),
                Complex64::new(0.0, 0.0036909997130494475),
                Complex64::new(-0.03678688170631573, 0.0),
                Complex64::new(0.0, -0.24340376901515096),
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }

    #[test]
    fn statevector_order_two_fixed_qubits() {
        let circuit = odd_test_circuit();
        // 1*0 should get a vec with amplitudes |100> and |110>
        let (tn, perm) = circuit.into_amplitude_network("1*0");
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2],
            vec![
                Complex64::new(-0.024421837348497916, 0.0),
                Complex64::new(-0.03678688170631573, 0.0),
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }

    #[test]
    fn statevector_order_one_fixed_qubit() {
        let circuit = odd_test_circuit();
        // *1* should get a vec with amplitudes |010>, |011>, |110>, |111>
        let (tn, perm) = circuit.into_amplitude_network("*1*");
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::new(-0.014455126269118733, 0.0),
                Complex64::new(0.0, -0.09564366568448116),
                Complex64::new(-0.03678688170631573, 0.0),
                Complex64::new(0.0, -0.24340376901515096),
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }

    #[test]
    fn gate_order() {
        // Ensures that the cx gate legs are in the correct order
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        creg c[1];
        u2(0,0) q[0];
        u2(-pi,-pi) q[1];
        cx q[0],q[1];
        u2(-pi,-pi) q[0];";

        let circuit = import_qasm(code);
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::ZERO,
                Complex64::ZERO,
                Complex64::new(-FRAC_1_SQRT_2, 0.0),
                Complex64::new(FRAC_1_SQRT_2, 0.0),
            ],
        );
        assert_abs_diff_eq!(&resulting_state, &expected);
    }
}