tnc/
gates.rs

1//! Implementation of various quantum gates to be used in tensor networks.
2
3use std::{
4    borrow::Borrow,
5    f64::consts::FRAC_1_SQRT_2,
6    hash::{Hash, Hasher},
7    sync::{LazyLock, RwLock},
8};
9
10use itertools::Itertools;
11use num_complex::Complex64;
12use permutation::Permutation;
13use rustc_hash::FxHashSet;
14use tetra::Tensor as DataTensor;
15
16static GATES: LazyLock<RwLock<FxHashSet<Box<dyn Gate>>>> = LazyLock::new(|| {
17    let mut gates = FxHashSet::default();
18    gates.insert(Box::new(X) as _);
19    gates.insert(Box::new(Y) as _);
20    gates.insert(Box::new(Z) as _);
21    gates.insert(Box::new(H) as _);
22    gates.insert(Box::new(T) as _);
23    gates.insert(Box::new(U) as _);
24    gates.insert(Box::new(Sx) as _);
25    gates.insert(Box::new(Sy) as _);
26    gates.insert(Box::new(Sz) as _);
27    gates.insert(Box::new(Rx) as _);
28    gates.insert(Box::new(Ry) as _);
29    gates.insert(Box::new(Rz) as _);
30    gates.insert(Box::new(Cx) as _);
31    gates.insert(Box::new(Cz) as _);
32    gates.insert(Box::new(Cp) as _);
33    gates.insert(Box::new(Iswap) as _);
34    gates.insert(Box::new(Fsim) as _);
35    RwLock::new(gates)
36});
37
38/// Registers a gate definition to resolve a gate name to a gate implementation.
39pub fn register_gate(gate: Box<dyn Gate>) {
40    assert!(
41        gate.name().to_ascii_lowercase() == gate.name(),
42        "Gate name must be lowercase."
43    );
44    GATES.write().unwrap().insert(gate);
45}
46
47/// Computes the gate matrix for the given gate and angles.
48#[must_use]
49pub fn load_gate(gate: &str, angles: &[f64]) -> DataTensor {
50    let gates = &GATES.read().unwrap();
51    let gate = gates
52        .get(gate)
53        .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
54    gate.compute(angles)
55}
56
57/// Computes the adjoint of the gate matrix for the given gate and angles.
58#[must_use]
59pub fn load_gate_adjoint(gate: &str, angles: &[f64]) -> DataTensor {
60    let gates = &GATES.read().unwrap();
61    let gate = gates
62        .get(gate)
63        .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
64    gate.adjoint(angles)
65}
66
67/// Returns whether the given gate is known.
68#[must_use]
69pub fn is_gate_known(gate: &str) -> bool {
70    let gates = &GATES.read().unwrap();
71    gates.contains(gate)
72}
73
74/// Helper method to compute the transpose of a matrix-like data tensor in-place. The
75/// data tensor can be of shape `(2^n, 2^n)`, or also be split in `2n` dimensions of
76/// size `2`, like `(2,2,2,...)`. In the second case, the transpose is computed by
77/// swapping the first half of the dimensions with the second half.
78///
79/// For example, both `(8,8)` or `(2,2,2,2,2,2)` are okay. If given `(2,2,2,2,2,2)`,
80/// the permutation applied will be `(3,4,5,0,1,2)`.
81fn matrix_transpose_inplace(data: &mut DataTensor) {
82    if data.ndim() > 0 {
83        assert!(data.ndim().is_power_of_two());
84        let half = data.ndim() / 2;
85        let perm = (half..data.ndim()).chain(0..half).collect_vec();
86        data.transpose(&Permutation::oneline(perm));
87    }
88}
89
90/// Helper method to compute the adjoint (conjugate transpose) of a matrix-like data
91/// tensor in-place. The data tensor can be of shape `(2^n, 2^n)`, or also be split
92/// in `2n` dimensions of size `2`, like `(2,2,2,...)`.
93///
94/// For example, both `(8,8)` or `(2,2,2,2,2,2)` are okay.
95pub(crate) fn matrix_adjoint_inplace(data: &mut DataTensor) {
96    matrix_transpose_inplace(data);
97    data.conjugate();
98}
99
/// Converts the slice of angles into an array of exactly `N` angles.
///
/// # Panics
/// Panics if the slice does not contain exactly `N` elements.
#[inline]
fn unpack_angles<const N: usize>(angles: &[f64]) -> [f64; N] {
    match <[f64; N]>::try_from(angles) {
        Ok(unpacked) => unpacked,
        Err(_) => panic!("Expected {N} angles, but got {}.", angles.len()),
    }
}
107
108/// A quantum gate.
109pub trait Gate: Send + Sync {
110    /// Returns the name of the gate.
111    fn name(&self) -> &str;
112
113    /// Computes the gate matrix with the given angles.
114    fn compute(&self, angles: &[f64]) -> DataTensor;
115
116    /// Computes the adjoint of the gate matrix with the given angles. If not
117    /// overridden, this computes the conjugate transpose of the gate matrix.
118    fn adjoint(&self, angles: &[f64]) -> DataTensor {
119        let mut matrix = self.compute(angles);
120        matrix_adjoint_inplace(&mut matrix);
121        matrix
122    }
123}
124
125impl PartialEq for dyn Gate {
126    fn eq(&self, other: &Self) -> bool {
127        self.name() == other.name()
128    }
129}
130
131impl Eq for dyn Gate {}
132
133impl Hash for dyn Gate {
134    fn hash<H: Hasher>(&self, state: &mut H) {
135        self.name().hash(state);
136    }
137}
138
139/// This allows us to use a `&str` as a key in a `HashSet` of gates.
140impl Borrow<str> for Box<dyn Gate> {
141    fn borrow(&self) -> &str {
142        self.name()
143    }
144}
145
146/// The Pauli-X gate.
147struct X;
148impl Gate for X {
149    fn name(&self) -> &str {
150        "x"
151    }
152
153    fn compute(&self, angles: &[f64]) -> DataTensor {
154        let [] = unpack_angles(angles);
155        let z = Complex64::ZERO;
156        let o = Complex64::ONE;
157        #[rustfmt::skip]
158        let data = vec![
159            z, o,
160            o, z,
161        ];
162        DataTensor::new_from_flat(&[2, 2], data, None)
163    }
164
165    fn adjoint(&self, angles: &[f64]) -> DataTensor {
166        // self-adjoint
167        self.compute(angles)
168    }
169}
170
171/// The Pauli-Y gate.
172struct Y;
173impl Gate for Y {
174    fn name(&self) -> &str {
175        "y"
176    }
177
178    fn compute(&self, angles: &[f64]) -> DataTensor {
179        let [] = unpack_angles(angles);
180        let z = Complex64::ZERO;
181        let i = Complex64::I;
182        #[rustfmt::skip]
183        let data = vec![
184            z, -i,
185            i,  z,
186        ];
187        DataTensor::new_from_flat(&[2, 2], data, None)
188    }
189
190    fn adjoint(&self, angles: &[f64]) -> DataTensor {
191        // self-adjoint
192        self.compute(angles)
193    }
194}
195
196/// The Pauli-Z gate.
197struct Z;
198impl Gate for Z {
199    fn name(&self) -> &str {
200        "z"
201    }
202
203    fn compute(&self, angles: &[f64]) -> DataTensor {
204        let [] = unpack_angles(angles);
205        let z = Complex64::ZERO;
206        let o = Complex64::ONE;
207        #[rustfmt::skip]
208        let data = vec![
209            o,  z,
210            z, -o,
211        ];
212        DataTensor::new_from_flat(&[2, 2], data, None)
213    }
214
215    fn adjoint(&self, angles: &[f64]) -> DataTensor {
216        // self-adjoint
217        self.compute(angles)
218    }
219}
220
221/// The Hadamard gate.
222struct H;
223impl Gate for H {
224    fn name(&self) -> &str {
225        "h"
226    }
227
228    fn compute(&self, angles: &[f64]) -> DataTensor {
229        let [] = unpack_angles(angles);
230        let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
231        #[rustfmt::skip]
232        let data = vec![
233            h,  h,
234            h, -h,
235        ];
236        DataTensor::new_from_flat(&[2, 2], data, None)
237    }
238
239    fn adjoint(&self, angles: &[f64]) -> DataTensor {
240        // self-adjoint
241        self.compute(angles)
242    }
243}
244
245/// The T gate.
246struct T;
247impl Gate for T {
248    fn name(&self) -> &str {
249        "t"
250    }
251
252    fn compute(&self, angles: &[f64]) -> DataTensor {
253        let [] = unpack_angles(angles);
254        let z = Complex64::ZERO;
255        let o = Complex64::ONE;
256        let t = Complex64::new(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
257        #[rustfmt::skip]
258        let data = vec![
259            o, z,
260            z, t,
261        ];
262        DataTensor::new_from_flat(&[2, 2], data, None)
263    }
264
265    fn adjoint(&self, angles: &[f64]) -> DataTensor {
266        // symmetric
267        let mut matrix = self.compute(angles);
268        matrix.conjugate();
269        matrix
270    }
271}
272
273/// The U gate with three parameters, following the [OpenQASM 3.0 specification](https://openqasm.com/language/gates.html#built-in-gates).
274struct U;
275impl Gate for U {
276    fn name(&self) -> &str {
277        "u"
278    }
279
280    fn compute(&self, angles: &[f64]) -> DataTensor {
281        let [theta, phi, lambda] = unpack_angles(angles);
282        let (sin, cos) = (theta / 2.0).sin_cos();
283        let data = vec![
284            Complex64::new(cos, 0.0),
285            -(Complex64::I * lambda).exp() * sin,
286            (Complex64::I * phi).exp() * sin,
287            (Complex64::I * (phi + lambda)).exp() * cos,
288        ];
289        DataTensor::new_from_flat(&[2, 2], data, None)
290    }
291
292    fn adjoint(&self, angles: &[f64]) -> DataTensor {
293        // This explicit implementation is ~30% faster
294        let [theta, phi, lambda] = unpack_angles(angles);
295        let (sin, cos) = (theta / 2.0).sin_cos();
296        let data = vec![
297            Complex64::new(cos, 0.0),
298            (Complex64::I * -phi).exp() * sin,
299            -(Complex64::I * -lambda).exp() * sin,
300            (Complex64::I * -(phi + lambda)).exp() * cos,
301        ];
302        DataTensor::new_from_flat(&[2, 2], data, None)
303    }
304}
305
306/// The square-root of X gate.
307struct Sx;
308impl Gate for Sx {
309    fn name(&self) -> &str {
310        "sx"
311    }
312
313    fn compute(&self, angles: &[f64]) -> DataTensor {
314        let [] = unpack_angles(angles);
315        let a = Complex64::new(0.5, 0.5);
316        let b = Complex64::new(0.5, -0.5);
317        #[rustfmt::skip]
318        let data = vec![
319            a, b,
320            b, a,
321        ];
322        DataTensor::new_from_flat(&[2, 2], data, None)
323    }
324
325    fn adjoint(&self, angles: &[f64]) -> DataTensor {
326        // symmetric
327        let mut matrix = self.compute(angles);
328        matrix.conjugate();
329        matrix
330    }
331}
332
333/// The square-root of Y gate.
334struct Sy;
335impl Gate for Sy {
336    fn name(&self) -> &str {
337        "sy"
338    }
339
340    fn compute(&self, angles: &[f64]) -> DataTensor {
341        let [] = unpack_angles(angles);
342        let a = Complex64::new(0.5, 0.5);
343        let b = Complex64::new(-0.5, -0.5);
344        #[rustfmt::skip]
345        let data = vec![
346            a, b,
347            a, a,
348        ];
349        DataTensor::new_from_flat(&[2, 2], data, None)
350    }
351}
352
353/// The square-root of Z gate.
354struct Sz;
355impl Gate for Sz {
356    fn name(&self) -> &str {
357        "sz"
358    }
359
360    fn compute(&self, angles: &[f64]) -> DataTensor {
361        let [] = unpack_angles(angles);
362        let z = Complex64::ZERO;
363        let o = Complex64::ONE;
364        let i = Complex64::I;
365        #[rustfmt::skip]
366        let data = vec![
367            o, z,
368            z, i,
369        ];
370        DataTensor::new_from_flat(&[2, 2], data, None)
371    }
372
373    fn adjoint(&self, angles: &[f64]) -> DataTensor {
374        // symmetric
375        let mut matrix = self.compute(angles);
376        matrix.conjugate();
377        matrix
378    }
379}
380
381/// Rotation by angle along X axis.
382struct Rx;
383impl Gate for Rx {
384    fn name(&self) -> &str {
385        "rx"
386    }
387
388    fn compute(&self, angles: &[f64]) -> DataTensor {
389        let [theta] = unpack_angles(angles);
390        let (sin, cos) = (theta / 2.0).sin_cos();
391        let o = Complex64::ONE;
392        let i = Complex64::I;
393        #[rustfmt::skip]
394        let data = vec![
395            o*cos, -i*sin,
396            -i*sin, o*cos,
397        ];
398        DataTensor::new_from_flat(&[2, 2], data, None)
399    }
400
401    fn adjoint(&self, angles: &[f64]) -> DataTensor {
402        // symmetric
403        self.compute(&angles.iter().map(|&x| -x).collect_vec())
404    }
405}
406
407/// Rotation by angle along Y axis.
408struct Ry;
409impl Gate for Ry {
410    fn name(&self) -> &str {
411        "ry"
412    }
413
414    fn compute(&self, angles: &[f64]) -> DataTensor {
415        let [theta] = unpack_angles(angles);
416        let (sin, cos) = (theta / 2.0).sin_cos();
417        let o = Complex64::ONE;
418
419        #[rustfmt::skip]
420        let data = vec![
421            o*cos, -o*sin,
422            o*sin, o*cos,
423        ];
424        DataTensor::new_from_flat(&[2, 2], data, None)
425    }
426
427    fn adjoint(&self, angles: &[f64]) -> DataTensor {
428        // symmetric
429        self.compute(&angles.iter().map(|&x| -x).collect_vec())
430    }
431}
432
433/// Rotation by angle along Z axis.
434struct Rz;
435impl Gate for Rz {
436    fn name(&self) -> &str {
437        "rz"
438    }
439
440    fn compute(&self, angles: &[f64]) -> DataTensor {
441        let [theta] = unpack_angles(angles);
442        let z = Complex64::ZERO;
443        let i = Complex64::I;
444
445        #[rustfmt::skip]
446        let data = vec![
447            (-i*theta/2.0).exp(), z,
448            z, (i*theta/2.0).exp(),
449        ];
450        DataTensor::new_from_flat(&[2, 2], data, None)
451    }
452
453    fn adjoint(&self, angles: &[f64]) -> DataTensor {
454        // symmetric
455        self.compute(&angles.iter().map(|&x| -x).collect_vec())
456    }
457}
458
459/// The controlled-X gate.
460struct Cx;
461impl Gate for Cx {
462    fn name(&self) -> &str {
463        "cx"
464    }
465
466    fn compute(&self, angles: &[f64]) -> DataTensor {
467        let [] = unpack_angles(angles);
468        let z = Complex64::ZERO;
469        let o = Complex64::ONE;
470        #[rustfmt::skip]
471        let data = vec![
472            o, z, z, z,
473            z, o, z, z,
474            z, z, z, o,
475            z, z, o, z,
476        ];
477        DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
478    }
479
480    fn adjoint(&self, angles: &[f64]) -> DataTensor {
481        // self-adjoint
482        self.compute(angles)
483    }
484}
485
486/// The controlled-Z gate.
487struct Cz;
488impl Gate for Cz {
489    fn name(&self) -> &str {
490        "cz"
491    }
492
493    fn compute(&self, angles: &[f64]) -> DataTensor {
494        let [] = unpack_angles(angles);
495        let z = Complex64::ZERO;
496        let o = Complex64::ONE;
497        #[rustfmt::skip]
498        let data = vec![
499            o, z, z, z,
500            z, o, z, z,
501            z, z, o, z,
502            z, z, z, -o,
503        ];
504        DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
505    }
506
507    fn adjoint(&self, angles: &[f64]) -> DataTensor {
508        // self-adjoint
509        self.compute(angles)
510    }
511}
512
513/// The controlled Phase gate.
514struct Cp;
515impl Gate for Cp {
516    fn name(&self) -> &str {
517        "cp"
518    }
519
520    fn compute(&self, angles: &[f64]) -> DataTensor {
521        let [theta] = unpack_angles(angles);
522        let z = Complex64::ZERO;
523        let o = Complex64::ONE;
524        let e = (Complex64::I * theta).exp();
525        #[rustfmt::skip]
526        let data = vec![
527            o, z, z, z,
528            z, o, z, z,
529            z, z, o, z,
530            z, z, z, e,
531        ];
532        DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
533    }
534
535    fn adjoint(&self, angles: &[f64]) -> DataTensor {
536        // symmetric
537        self.compute(&angles.iter().map(|&x| -x).collect_vec())
538    }
539}
540
541/// The iSWAP gate.
542struct Iswap;
543impl Gate for Iswap {
544    fn name(&self) -> &str {
545        "iswap"
546    }
547
548    fn compute(&self, angles: &[f64]) -> DataTensor {
549        let [] = unpack_angles(angles);
550        let z = Complex64::ZERO;
551        let o = Complex64::ONE;
552        let i = Complex64::I;
553        #[rustfmt::skip]
554        let data = vec![
555            o, z, z, z,
556            z, z, i, z,
557            z, i, z, z,
558            z, z, z, o,
559        ];
560        DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
561    }
562
563    fn adjoint(&self, angles: &[f64]) -> DataTensor {
564        // symmetric
565        let mut matrix = self.compute(angles);
566        matrix.conjugate();
567        matrix
568    }
569}
570
571/// The FSIM gate, as described e.g. [here](https://quantumai.google/reference/python/cirq/FSimGate).
572struct Fsim;
573impl Gate for Fsim {
574    fn name(&self) -> &str {
575        "fsim"
576    }
577
578    fn compute(&self, angles: &[f64]) -> DataTensor {
579        let [theta, phi] = unpack_angles(angles);
580        let z = Complex64::ZERO;
581        let o = Complex64::ONE;
582        let a = Complex64::new(theta.cos(), 0.0);
583        let b = Complex64::new(0.0, -theta.sin());
584        let c = Complex64::new(0.0, -phi).exp();
585        #[rustfmt::skip]
586        let data = vec![
587            o, z, z, z,
588            z, a, b, z,
589            z, b, a, z,
590            z, z, z, c,
591        ];
592        DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
593    }
594
595    fn adjoint(&self, angles: &[f64]) -> DataTensor {
596        // symmetric
597        self.compute(&angles.iter().map(|&x| -x).collect_vec())
598    }
599}
600
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use float_cmp::assert_approx_eq;
    use rand::{distr::Uniform, prelude::Distribution, rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use super::*;

    /// Loading a gate that was never registered must panic.
    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown() {
        let _ = load_gate("foo", &[]);
    }

    /// Loading the adjoint of a gate that was never registered must panic.
    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown_adjoint() {
        let _ = load_gate_adjoint("foo", &[]);
    }

    /// Passing more angles than requested must panic.
    #[test]
    #[should_panic(expected = "Expected 0 angles, but got 2.")]
    fn too_many_angles() {
        let [] = unpack_angles(&[2.0, 4.0]);
    }

    /// The adjoint returned by each registered gate must agree with the
    /// conjugate transpose of its computed matrix.
    #[test]
    fn test_custom_adjoint_impls() {
        // Number of angles of the parametrized gates; all others take none.
        let gate_params = FxHashMap::from_iter([
            ("u", 3),
            ("rx", 1),
            ("ry", 1),
            ("rz", 1),
            ("cp", 1),
            ("fsim", 2),
        ]);
        let rng = StdRng::seed_from_u64(42);
        let dist = Uniform::new(-PI, PI).unwrap();
        let mut samples = dist.sample_iter(rng);

        let registry = GATES.read().unwrap();
        for gate in registry.iter() {
            let param_count = gate_params.get(gate.name()).copied().unwrap_or(0);
            let params: Vec<f64> = samples.by_ref().take(param_count).collect();

            let specialized_adjoint = gate.adjoint(&params);

            let mut general_adjoint = gate.compute(&params);
            matrix_adjoint_inplace(&mut general_adjoint);

            assert_approx_eq!(&DataTensor, &specialized_adjoint, &general_adjoint);
        }
    }
}