tnc/qasm/qasm_importer.rs

1use crate::builders::circuit_builder::Circuit;
2use crate::qasm::{
3    ast::Visitor, circuit_creator::CircuitCreator, expression_folder::ExpressionFolder,
4    gate_inliner::GateInliner, include_resolver::expand_includes, parser::parse,
5};
6
7/// Creates a [`Circuit`] from OpenQASM2 code.
8///
9/// All gates are inlined up to the known gates defined in [`crate::gates`]. All
10/// qubits are initialized to zero. Note that not all QASM instructions are
11/// supported, such as `measure` or `if`.
12pub fn import_qasm<S>(code: S) -> Circuit
13where
14    S: Into<String>,
15{
16    // Expand all includes
17    let mut full_code = code.into();
18    expand_includes(&mut full_code);
19
20    // Parse to AST
21    let mut program = parse(&full_code);
22
23    // Simplify expressions (not strictly needed)
24    let mut expression_folder = ExpressionFolder;
25    expression_folder.visit_program(&mut program);
26
27    // Inline gate calls
28    let mut inliner = GateInliner::default();
29    inliner.inline_program(&mut program);
30
31    // Simplify expressions after inline (needed)
32    let mut expression_folder = ExpressionFolder;
33    expression_folder.visit_program(&mut program);
34
35    // Create the circuit
36    let circuit_creator = CircuitCreator;
37    circuit_creator.create_circuit(&program)
38}
39
#[cfg(test)]
mod tests {
    use super::*;

    use std::f64::consts::FRAC_1_SQRT_2;

    use float_cmp::assert_approx_eq;
    use num_complex::{c64, Complex64};

    use crate::{
        builders::circuit_builder::Permutor,
        contractionpath::ContractionPath,
        tensornetwork::{
            contraction::contract_tensor_network,
            tensor::{EdgeIndex, Tensor, TensorIndex},
            tensordata::TensorData,
        },
    };

    /// Returns whether the edge `edge_id` connects the tensors `t1_id` and
    /// `t2_id` of `tn`, i.e. whether it is a leg of their `&`-overlap.
    fn edge_connects(
        edge_id: EdgeIndex,
        t1_id: TensorIndex,
        t2_id: TensorIndex,
        tn: &Tensor,
    ) -> bool {
        let overlap = tn.tensor(t1_id) & tn.tensor(t2_id);
        overlap.legs().contains(&edge_id)
    }

    /// Returns whether `edge_id` is an open edge of tensor `t1_id` in `tn`:
    /// it must be a leg of that tensor and of no other tensor in the network.
    fn is_open_edge_of(edge_id: EdgeIndex, t1_id: TensorIndex, tn: &Tensor) -> bool {
        tn.tensor(t1_id).legs().contains(&edge_id)
            && tn
                .tensors()
                .iter()
                .enumerate()
                .all(|(tensor_id, tensor)| {
                    tensor_id == t1_id || !tensor.legs().contains(&edge_id)
                })
    }

    /// A tensor of a network paired with its index in that network.
    struct IdTensor<'a> {
        id: TensorIndex,
        tensor: &'a Tensor,
    }

    /// Splits the tensors of `tn` by leg count into kets (1 leg), single-qubit
    /// gates (2 legs) and two-qubit gates (4 legs), keeping their indices.
    ///
    /// # Panics
    /// Panics if any tensor has a different number of legs.
    fn get_quantum_tensors(
        tn: &Tensor,
    ) -> (Vec<IdTensor<'_>>, Vec<IdTensor<'_>>, Vec<IdTensor<'_>>) {
        let mut kets = Vec::new();
        let mut single_qubit_gates = Vec::new();
        let mut two_qubit_gates = Vec::new();
        for (id, tensor) in tn.tensors().iter().enumerate() {
            let id_tensor = IdTensor { id, tensor };
            match tensor.legs().len() {
                1 => kets.push(id_tensor),
                2 => single_qubit_gates.push(id_tensor),
                4 => two_qubit_gates.push(id_tensor),
                legs => {
                    panic!("Tensor with unexpected leg count {legs} in quantum tensor network")
                }
            }
        }
        (kets, single_qubit_gates, two_qubit_gates)
    }

    /// Checks the wiring of the statevector network built for a Bell circuit.
    #[test]
    fn bell_tensornetwork_construction() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        h q[0];
        cx q[0], q[1];
        ";
        let circuit = import_qasm(code);
        let (tn, _) = circuit.into_statevector_network();

        let (kets, single_qubit_gates, two_qubit_gates) = get_quantum_tensors(&tn);
        let [k0, k1] = kets.as_slice() else {
            panic!("expected 2 kets, found {}", kets.len())
        };
        let [h] = single_qubit_gates.as_slice() else {
            panic!(
                "expected 1 single-qubit gate, found {}",
                single_qubit_gates.len()
            )
        };
        let [cx] = two_qubit_gates.as_slice() else {
            panic!("expected 1 two-qubit gate, found {}", two_qubit_gates.len())
        };

        // The first/top qubit is the ket whose leg is the H gate's input edge;
        // the other ket is the second/bottom qubit. Edge ids must be compared
        // with edge ids, never with tensor ids.
        let h_input_edge = h.tensor.legs()[0];
        let (first_qubit, second_qubit) = if k0.tensor.legs()[0] == h_input_edge {
            (k0, k1)
        } else if k1.tensor.legs()[0] == h_input_edge {
            (k1, k0)
        } else {
            panic!("H gate tensor not connected to any ket tensor");
        };

        // Check edges
        let fq_to_h_id = first_qubit.tensor.legs()[0];
        assert_eq!(h.tensor.legs()[0], fq_to_h_id);
        assert!(edge_connects(fq_to_h_id, first_qubit.id, h.id, &tn));

        let sq_to_cx_t_id = second_qubit.tensor.legs()[0];
        assert_eq!(cx.tensor.legs()[1], sq_to_cx_t_id);
        assert!(edge_connects(sq_to_cx_t_id, second_qubit.id, cx.id, &tn));

        let h_to_cx_c_id = h.tensor.legs()[1];
        assert_eq!(cx.tensor.legs()[0], h_to_cx_c_id);
        assert!(edge_connects(h_to_cx_c_id, h.id, cx.id, &tn));

        let cx_c_to_open_id = cx.tensor.legs()[2];
        assert!(is_open_edge_of(cx_c_to_open_id, cx.id, &tn));

        let cx_t_to_open_id = cx.tensor.legs()[3];
        assert!(is_open_edge_of(cx_t_to_open_id, cx.id, &tn));
    }

    /// Contracts `tn` along the simple path `(0, 1), (0, 2), …` (any order
    /// yields the same result), applies `perm`, and returns the tensor data.
    fn contract_tn(tn: Tensor, perm: &Permutor) -> TensorData {
        let opt_path =
            ContractionPath::simple((1..tn.tensors().len()).map(|tid| (0, tid)).collect());
        let tn = contract_tensor_network(tn, &opt_path);
        let mut tn = perm.apply(tn);
        std::mem::take(&mut tn.tensordata)
    }

    #[test]
    fn bell_contract() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        h q[0];
        cx q[0], q[1];
        ";
        let circuit = import_qasm(code);
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                c64(FRAC_1_SQRT_2, 0.),
                Complex64::ZERO,
                Complex64::ZERO,
                c64(FRAC_1_SQRT_2, 0.),
            ],
            None,
        );
        assert_approx_eq!(&TensorData, &resulting_state, &expected);
    }

    #[test]
    fn custom_swap() {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[2];
        gate myswap a, b {
            cx a, b;
            cx b, a;
            cx a, b;
        }
        x q[0];
        myswap q[1], q[0];
        ";
        let circuit = import_qasm(code);
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::ZERO,
                Complex64::ONE,
                Complex64::ZERO,
                Complex64::ZERO,
            ],
            None,
        );
        assert_approx_eq!(&TensorData, &resulting_state, &expected);
    }

    /// Three-qubit circuit with distinct rotation angles, so that a wrong qubit
    /// order in the resulting state vector is detectable.
    fn odd_test_circuit() -> Circuit {
        let code = "OPENQASM 2.0;
        include \"qelib1.inc\";
        qreg q[3];
        rx(0.5) q[0];
        rx(0.2) q[1];
        rx(0.3) q[2];
        cx q[0], q[1];
        cx q[1], q[2];";
        import_qasm(code)
    }

    #[test]
    fn statevector_order() {
        let circuit = odd_test_circuit();
        let (tn, perm) = circuit.into_statevector_network();
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2, 2],
            vec![
                Complex64::new(0.953246407214305, 0.0),
                Complex64::new(0.0, -0.14406910361762032),
                Complex64::new(-0.014455126269118733, 0.0),
                Complex64::new(0.0, -0.09564366568448116),
                Complex64::new(-0.024421837348497916, 0.0),
                Complex64::new(0.0, 0.0036909997130494475),
                Complex64::new(-0.03678688170631573, 0.0),
                Complex64::new(0.0, -0.24340376901515096),
            ],
            None,
        );
        assert_approx_eq!(&TensorData, &resulting_state, &expected);
    }

    #[test]
    fn statevector_order_two_fixed_qubits() {
        let circuit = odd_test_circuit();
        // 1*0 should get a vec with amplitudes |100> and |110>
        let (tn, perm) = circuit.into_amplitude_network("1*0");
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2],
            vec![
                Complex64::new(-0.024421837348497916, 0.0),
                Complex64::new(-0.03678688170631573, 0.0),
            ],
            None,
        );
        assert_approx_eq!(&TensorData, &resulting_state, &expected);
    }

    #[test]
    fn statevector_order_one_fixed_qubit() {
        let circuit = odd_test_circuit();
        // *1* should get a vec with amplitudes |010>, |011>, |110>, |111>
        let (tn, perm) = circuit.into_amplitude_network("*1*");
        let resulting_state = contract_tn(tn, &perm);

        let expected = TensorData::new_from_data(
            &[2, 2],
            vec![
                Complex64::new(-0.014455126269118733, 0.0),
                Complex64::new(0.0, -0.09564366568448116),
                Complex64::new(-0.03678688170631573, 0.0),
                Complex64::new(0.0, -0.24340376901515096),
            ],
            None,
        );
        assert_approx_eq!(&TensorData, &resulting_state, &expected);
    }
}