Skip to main content

tnc/
gates.rs

//! Implementations of various quantum gates for use in tensor networks, together
//! with a global, name-based gate registry.
2
3use std::{
4    borrow::Borrow,
5    f64::consts::FRAC_1_SQRT_2,
6    hash::{Hash, Hasher},
7    sync::{LazyLock, RwLock},
8};
9
10use itertools::Itertools;
11use ndarray::array;
12use num_complex::Complex64;
13use rustc_hash::FxHashSet;
14
15use crate::{tensornetwork::tensordata::DataTensor, utils::traits::ConjugateInPlace};
16
17static GATES: LazyLock<RwLock<FxHashSet<Box<dyn Gate>>>> = LazyLock::new(|| {
18    let mut gates = FxHashSet::default();
19    gates.insert(Box::new(X) as _);
20    gates.insert(Box::new(Y) as _);
21    gates.insert(Box::new(Z) as _);
22    gates.insert(Box::new(H) as _);
23    gates.insert(Box::new(T) as _);
24    gates.insert(Box::new(U) as _);
25    gates.insert(Box::new(Sx) as _);
26    gates.insert(Box::new(Sy) as _);
27    gates.insert(Box::new(Sz) as _);
28    gates.insert(Box::new(Rx) as _);
29    gates.insert(Box::new(Ry) as _);
30    gates.insert(Box::new(Rz) as _);
31    gates.insert(Box::new(Cx) as _);
32    gates.insert(Box::new(Cz) as _);
33    gates.insert(Box::new(Swap) as _);
34    gates.insert(Box::new(Cp) as _);
35    gates.insert(Box::new(Iswap) as _);
36    gates.insert(Box::new(Fsim) as _);
37    RwLock::new(gates)
38});
39
40/// Registers a gate definition to resolve a gate name to a gate implementation.
41pub fn register_gate(gate: Box<dyn Gate>) {
42    assert!(
43        gate.name().to_ascii_lowercase() == gate.name(),
44        "Gate name must be lowercase."
45    );
46    GATES.write().unwrap().insert(gate);
47}
48
49/// Computes the gate matrix for the given gate and angles.
50#[must_use]
51pub fn load_gate(gate: &str, angles: &[f64]) -> DataTensor {
52    let gates = &GATES.read().unwrap();
53    let gate = gates
54        .get(gate)
55        .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
56    gate.compute(angles)
57}
58
59/// Computes the adjoint of the gate matrix for the given gate and angles.
60#[must_use]
61pub fn load_gate_adjoint(gate: &str, angles: &[f64]) -> DataTensor {
62    let gates = &GATES.read().unwrap();
63    let gate = gates
64        .get(gate)
65        .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
66    gate.adjoint(angles)
67}
68
69/// Returns whether the given gate is known.
70#[must_use]
71pub fn is_gate_known(gate: &str) -> bool {
72    let gates = &GATES.read().unwrap();
73    gates.contains(gate)
74}
75
76/// Helper method to compute the transpose of a matrix-like data tensor in-place. The
77/// data tensor can be of shape `(2^n, 2^n)`, or also be split in `2n` dimensions of
78/// size `2`, like `(2,2,2,...)`. In the second case, the transpose is computed by
79/// swapping the first half of the dimensions with the second half.
80///
81/// For example, both `(8,8)` or `(2,2,2,2,2,2)` are okay. If given `(2,2,2,2,2,2)`,
82/// the permutation applied will be `(3,4,5,0,1,2)`.
83fn matrix_transpose_inplace(data: &mut DataTensor) {
84    if data.ndim() > 0 {
85        assert!(data.ndim().is_power_of_two());
86        let half = data.ndim() / 2;
87        let perm = (half..data.ndim()).chain(0..half).collect_vec();
88        let used_data = std::mem::take(data);
89        *data = used_data.permuted_axes(perm);
90    }
91}
92
93/// Helper method to compute the adjoint (conjugate transpose) of a matrix-like data
94/// tensor in-place. The data tensor can be of shape `(2^n, 2^n)`, or also be split
95/// in `2n` dimensions of size `2`, like `(2,2,2,...)`.
96///
97/// For example, both `(8,8)` or `(2,2,2,2,2,2)` are okay.
98pub(crate) fn matrix_adjoint_inplace(data: &mut DataTensor) {
99    matrix_transpose_inplace(data);
100    data.conjugate();
101}
102
/// Converts `angles` into a fixed-size array holding exactly `N` angles.
///
/// # Panics
/// Panics if `angles` does not contain exactly `N` elements.
#[inline]
fn unpack_angles<const N: usize>(angles: &[f64]) -> [f64; N] {
    match <[f64; N]>::try_from(angles) {
        Ok(unpacked) => unpacked,
        Err(_) => panic!("Expected {N} angles, but got {}.", angles.len()),
    }
}
110
111/// A quantum gate.
112pub trait Gate: Send + Sync {
113    /// Returns the name of the gate.
114    fn name(&self) -> &str;
115
116    /// Computes the gate matrix with the given angles.
117    fn compute(&self, angles: &[f64]) -> DataTensor;
118
119    /// Computes the adjoint of the gate matrix with the given angles. If not
120    /// overridden, this computes the conjugate transpose of the gate matrix.
121    fn adjoint(&self, angles: &[f64]) -> DataTensor {
122        let mut matrix = self.compute(angles);
123        matrix_adjoint_inplace(&mut matrix);
124        matrix
125    }
126}
127
128impl PartialEq for dyn Gate {
129    fn eq(&self, other: &Self) -> bool {
130        self.name() == other.name()
131    }
132}
133
134impl Eq for dyn Gate {}
135
136impl Hash for dyn Gate {
137    fn hash<H: Hasher>(&self, state: &mut H) {
138        self.name().hash(state);
139    }
140}
141
142/// This allows us to use a `&str` as a key in a `HashSet` of gates.
143impl Borrow<str> for Box<dyn Gate> {
144    fn borrow(&self) -> &str {
145        self.name()
146    }
147}
148
149/// The Pauli-X gate.
150struct X;
151impl Gate for X {
152    fn name(&self) -> &str {
153        "x"
154    }
155
156    fn compute(&self, angles: &[f64]) -> DataTensor {
157        let [] = unpack_angles(angles);
158        let z = Complex64::ZERO;
159        let o = Complex64::ONE;
160        array![[z, o], [o, z],].into_dyn()
161    }
162
163    fn adjoint(&self, angles: &[f64]) -> DataTensor {
164        // self-adjoint
165        self.compute(angles)
166    }
167}
168
169/// The Pauli-Y gate.
170struct Y;
171impl Gate for Y {
172    fn name(&self) -> &str {
173        "y"
174    }
175
176    fn compute(&self, angles: &[f64]) -> DataTensor {
177        let [] = unpack_angles(angles);
178        let z = Complex64::ZERO;
179        let i = Complex64::I;
180        array![[z, -i], [i, z]].into_dyn()
181    }
182
183    fn adjoint(&self, angles: &[f64]) -> DataTensor {
184        // self-adjoint
185        self.compute(angles)
186    }
187}
188
189/// The Pauli-Z gate.
190struct Z;
191impl Gate for Z {
192    fn name(&self) -> &str {
193        "z"
194    }
195
196    fn compute(&self, angles: &[f64]) -> DataTensor {
197        let [] = unpack_angles(angles);
198        let z = Complex64::ZERO;
199        let o = Complex64::ONE;
200        array![[o, z], [z, -o]].into_dyn()
201    }
202
203    fn adjoint(&self, angles: &[f64]) -> DataTensor {
204        // self-adjoint
205        self.compute(angles)
206    }
207}
208
209/// The Hadamard gate.
210struct H;
211impl Gate for H {
212    fn name(&self) -> &str {
213        "h"
214    }
215
216    fn compute(&self, angles: &[f64]) -> DataTensor {
217        let [] = unpack_angles(angles);
218        let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
219        array![[h, h], [h, -h]].into_dyn()
220    }
221
222    fn adjoint(&self, angles: &[f64]) -> DataTensor {
223        // self-adjoint
224        self.compute(angles)
225    }
226}
227
228/// The T gate.
229struct T;
230impl Gate for T {
231    fn name(&self) -> &str {
232        "t"
233    }
234
235    fn compute(&self, angles: &[f64]) -> DataTensor {
236        let [] = unpack_angles(angles);
237        let z = Complex64::ZERO;
238        let o = Complex64::ONE;
239        let t = Complex64::new(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
240        array![[o, z], [z, t]].into_dyn()
241    }
242
243    fn adjoint(&self, angles: &[f64]) -> DataTensor {
244        // symmetric
245        let mut matrix = self.compute(angles);
246        matrix.conjugate();
247        matrix
248    }
249}
250
251/// The U gate with three parameters, following the [OpenQASM 3.0 specification](https://openqasm.com/language/gates.html#built-in-gates).
252struct U;
253impl Gate for U {
254    fn name(&self) -> &str {
255        "u"
256    }
257
258    fn compute(&self, angles: &[f64]) -> DataTensor {
259        let [theta, phi, lambda] = unpack_angles(angles);
260        let (sin, cos) = (theta / 2.0).sin_cos();
261        array![
262            [
263                Complex64::new(cos, 0.0),
264                -(Complex64::I * lambda).exp() * sin
265            ],
266            [
267                (Complex64::I * phi).exp() * sin,
268                (Complex64::I * (phi + lambda)).exp() * cos
269            ]
270        ]
271        .into_dyn()
272    }
273
274    fn adjoint(&self, angles: &[f64]) -> DataTensor {
275        // This explicit implementation is ~30% faster
276        let [theta, phi, lambda] = unpack_angles(angles);
277        let (sin, cos) = (theta / 2.0).sin_cos();
278        array![
279            [Complex64::new(cos, 0.0), (Complex64::I * -phi).exp() * sin],
280            [
281                -(Complex64::I * -lambda).exp() * sin,
282                (Complex64::I * -(phi + lambda)).exp() * cos
283            ]
284        ]
285        .into_dyn()
286    }
287}
288
289/// The square-root of X gate.
290struct Sx;
291impl Gate for Sx {
292    fn name(&self) -> &str {
293        "sx"
294    }
295
296    fn compute(&self, angles: &[f64]) -> DataTensor {
297        let [] = unpack_angles(angles);
298        let a = Complex64::new(0.5, 0.5);
299        let b = Complex64::new(0.5, -0.5);
300        array![[a, b], [b, a]].into_dyn()
301    }
302
303    fn adjoint(&self, angles: &[f64]) -> DataTensor {
304        // symmetric
305        let mut matrix = self.compute(angles);
306        matrix.conjugate();
307        matrix
308    }
309}
310
311/// The square-root of Y gate.
312struct Sy;
313impl Gate for Sy {
314    fn name(&self) -> &str {
315        "sy"
316    }
317
318    fn compute(&self, angles: &[f64]) -> DataTensor {
319        let [] = unpack_angles(angles);
320        let a = Complex64::new(0.5, 0.5);
321        let b = Complex64::new(-0.5, -0.5);
322        array![[a, b], [a, a]].into_dyn()
323    }
324}
325
326/// The square-root of Z gate.
327struct Sz;
328impl Gate for Sz {
329    fn name(&self) -> &str {
330        "sz"
331    }
332
333    fn compute(&self, angles: &[f64]) -> DataTensor {
334        let [] = unpack_angles(angles);
335        let z = Complex64::ZERO;
336        let o = Complex64::ONE;
337        let i = Complex64::I;
338        array![[o, z], [z, i]].into_dyn()
339    }
340
341    fn adjoint(&self, angles: &[f64]) -> DataTensor {
342        // symmetric
343        let mut matrix = self.compute(angles);
344        matrix.conjugate();
345        matrix
346    }
347}
348
349/// Rotation by angle along X axis.
350struct Rx;
351impl Gate for Rx {
352    fn name(&self) -> &str {
353        "rx"
354    }
355
356    fn compute(&self, angles: &[f64]) -> DataTensor {
357        let [theta] = unpack_angles(angles);
358        let (sin, cos) = (theta / 2.0).sin_cos();
359        let o = Complex64::ONE;
360        let i = Complex64::I;
361        array![[o * cos, -i * sin], [-i * sin, o * cos]].into_dyn()
362    }
363
364    fn adjoint(&self, angles: &[f64]) -> DataTensor {
365        // symmetric
366        self.compute(&angles.iter().map(|&x| -x).collect_vec())
367    }
368}
369
370/// Rotation by angle along Y axis.
371struct Ry;
372impl Gate for Ry {
373    fn name(&self) -> &str {
374        "ry"
375    }
376
377    fn compute(&self, angles: &[f64]) -> DataTensor {
378        let [theta] = unpack_angles(angles);
379        let (sin, cos) = (theta / 2.0).sin_cos();
380        let o = Complex64::ONE;
381
382        array![[o * cos, -o * sin], [o * sin, o * cos]].into_dyn()
383    }
384
385    fn adjoint(&self, angles: &[f64]) -> DataTensor {
386        // symmetric
387        self.compute(&angles.iter().map(|&x| -x).collect_vec())
388    }
389}
390
391/// Rotation by angle along Z axis.
392struct Rz;
393impl Gate for Rz {
394    fn name(&self) -> &str {
395        "rz"
396    }
397
398    fn compute(&self, angles: &[f64]) -> DataTensor {
399        let [theta] = unpack_angles(angles);
400        let z = Complex64::ZERO;
401        let i = Complex64::I;
402
403        array![[(-i * theta / 2.0).exp(), z], [z, (i * theta / 2.0).exp()]].into_dyn()
404    }
405
406    fn adjoint(&self, angles: &[f64]) -> DataTensor {
407        // symmetric
408        self.compute(&angles.iter().map(|&x| -x).collect_vec())
409    }
410}
411
412/// The controlled-X gate.
413struct Cx;
414impl Gate for Cx {
415    fn name(&self) -> &str {
416        "cx"
417    }
418
419    fn compute(&self, angles: &[f64]) -> DataTensor {
420        let [] = unpack_angles(angles);
421        let z = Complex64::ZERO;
422        let o = Complex64::ONE;
423        array![[o, z, z, z], [z, o, z, z], [z, z, z, o], [z, z, o, z]]
424            .into_shape_with_order((2, 2, 2, 2))
425            .unwrap()
426            .into_dyn()
427    }
428
429    fn adjoint(&self, angles: &[f64]) -> DataTensor {
430        // self-adjoint
431        self.compute(angles)
432    }
433}
434
435/// The controlled-Z gate.
436struct Cz;
437impl Gate for Cz {
438    fn name(&self) -> &str {
439        "cz"
440    }
441
442    fn compute(&self, angles: &[f64]) -> DataTensor {
443        let [] = unpack_angles(angles);
444        let z = Complex64::ZERO;
445        let o = Complex64::ONE;
446        array![[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, -o]]
447            .into_shape_with_order((2, 2, 2, 2))
448            .unwrap()
449            .into_dyn()
450    }
451
452    fn adjoint(&self, angles: &[f64]) -> DataTensor {
453        // self-adjoint
454        self.compute(angles)
455    }
456}
457
458/// The SWAP gate.
459struct Swap;
460impl Gate for Swap {
461    fn name(&self) -> &str {
462        "swap"
463    }
464
465    fn compute(&self, angles: &[f64]) -> DataTensor {
466        let [] = unpack_angles(angles);
467        let z = Complex64::ZERO;
468        let o = Complex64::ONE;
469        array![[o, z, z, z], [z, z, o, z], [z, o, z, z], [z, z, z, o]]
470            .into_shape_with_order((2, 2, 2, 2))
471            .unwrap()
472            .into_dyn()
473    }
474
475    fn adjoint(&self, angles: &[f64]) -> DataTensor {
476        // self-adjoint
477        self.compute(angles)
478    }
479}
480
481/// The controlled Phase gate.
482struct Cp;
483impl Gate for Cp {
484    fn name(&self) -> &str {
485        "cp"
486    }
487
488    fn compute(&self, angles: &[f64]) -> DataTensor {
489        let [theta] = unpack_angles(angles);
490        let z = Complex64::ZERO;
491        let o = Complex64::ONE;
492        let e = (Complex64::I * theta).exp();
493        array![[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, e]]
494            .into_shape_with_order((2, 2, 2, 2))
495            .unwrap()
496            .into_dyn()
497    }
498
499    fn adjoint(&self, angles: &[f64]) -> DataTensor {
500        // symmetric
501        self.compute(&angles.iter().map(|&x| -x).collect_vec())
502    }
503}
504
505/// The iSWAP gate.
506struct Iswap;
507impl Gate for Iswap {
508    fn name(&self) -> &str {
509        "iswap"
510    }
511
512    fn compute(&self, angles: &[f64]) -> DataTensor {
513        let [] = unpack_angles(angles);
514        let z = Complex64::ZERO;
515        let o = Complex64::ONE;
516        let i = Complex64::I;
517        array![[o, z, z, z], [z, z, i, z], [z, i, z, z], [z, z, z, o]]
518            .into_shape_with_order((2, 2, 2, 2))
519            .unwrap()
520            .into_dyn()
521    }
522
523    fn adjoint(&self, angles: &[f64]) -> DataTensor {
524        // symmetric
525        let mut matrix = self.compute(angles);
526        matrix.conjugate();
527        matrix
528    }
529}
530
531/// The FSIM gate, as described e.g. [here](https://quantumai.google/reference/python/cirq/FSimGate).
532struct Fsim;
533impl Gate for Fsim {
534    fn name(&self) -> &str {
535        "fsim"
536    }
537
538    fn compute(&self, angles: &[f64]) -> DataTensor {
539        let [theta, phi] = unpack_angles(angles);
540        let z = Complex64::ZERO;
541        let o = Complex64::ONE;
542        let a = Complex64::new(theta.cos(), 0.0);
543        let b = Complex64::new(0.0, -theta.sin());
544        let c = Complex64::new(0.0, -phi).exp();
545        array![[o, z, z, z], [z, a, b, z], [z, b, a, z], [z, z, z, c]]
546            .into_shape_with_order((2, 2, 2, 2))
547            .unwrap()
548            .into_dyn()
549    }
550
551    fn adjoint(&self, angles: &[f64]) -> DataTensor {
552        // symmetric
553        self.compute(&angles.iter().map(|&x| -x).collect_vec())
554    }
555}
556
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use approx::assert_abs_diff_eq;
    use rand::{distr::Uniform, prelude::Distribution, rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use super::*;

    /// Angle counts of the parameterized built-in gates; every gate not listed
    /// here is treated as taking no angles.
    const ANGLE_COUNTS: [(&str, usize); 6] = [
        ("u", 3),
        ("rx", 1),
        ("ry", 1),
        ("rz", 1),
        ("cp", 1),
        ("fsim", 2),
    ];

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown() {
        let _ = load_gate("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown_adjoint() {
        let _ = load_gate_adjoint("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Expected 0 angles, but got 2.")]
    fn too_many_angles() {
        let [] = unpack_angles(&[2.0, 4.0]);
    }

    /// Every registered gate's `adjoint` (specialized or default) must agree
    /// with the generic conjugate transpose of `compute` for random angles.
    #[test]
    fn test_custom_adjoint_impls() {
        let angle_counts: FxHashMap<&str, usize> = ANGLE_COUNTS.into_iter().collect();
        let distribution = Uniform::new(-PI, PI).unwrap();
        let mut samples = distribution.sample_iter(StdRng::seed_from_u64(42));

        let registry = GATES.read().unwrap();
        for gate in registry.iter() {
            let count = angle_counts.get(gate.name()).copied().unwrap_or(0);
            let angles: Vec<f64> = samples.by_ref().take(count).collect();

            let specialized = gate.adjoint(&angles);
            let mut generic = gate.compute(&angles);
            matrix_adjoint_inplace(&mut generic);
            assert_abs_diff_eq!(&specialized, &generic);
        }
    }
}