1use std::{
4 borrow::Borrow,
5 f64::consts::FRAC_1_SQRT_2,
6 hash::{Hash, Hasher},
7 sync::{LazyLock, RwLock},
8};
9
10use itertools::Itertools;
11use num_complex::Complex64;
12use permutation::Permutation;
13use rustc_hash::FxHashSet;
14use tetra::Tensor as DataTensor;
15
16static GATES: LazyLock<RwLock<FxHashSet<Box<dyn Gate>>>> = LazyLock::new(|| {
17 let mut gates = FxHashSet::default();
18 gates.insert(Box::new(X) as _);
19 gates.insert(Box::new(Y) as _);
20 gates.insert(Box::new(Z) as _);
21 gates.insert(Box::new(H) as _);
22 gates.insert(Box::new(T) as _);
23 gates.insert(Box::new(U) as _);
24 gates.insert(Box::new(Sx) as _);
25 gates.insert(Box::new(Sy) as _);
26 gates.insert(Box::new(Sz) as _);
27 gates.insert(Box::new(Rx) as _);
28 gates.insert(Box::new(Ry) as _);
29 gates.insert(Box::new(Rz) as _);
30 gates.insert(Box::new(Cx) as _);
31 gates.insert(Box::new(Cz) as _);
32 gates.insert(Box::new(Cp) as _);
33 gates.insert(Box::new(Iswap) as _);
34 gates.insert(Box::new(Fsim) as _);
35 RwLock::new(gates)
36});
37
38pub fn register_gate(gate: Box<dyn Gate>) {
40 assert!(
41 gate.name().to_ascii_lowercase() == gate.name(),
42 "Gate name must be lowercase."
43 );
44 GATES.write().unwrap().insert(gate);
45}
46
47#[must_use]
49pub fn load_gate(gate: &str, angles: &[f64]) -> DataTensor {
50 let gates = &GATES.read().unwrap();
51 let gate = gates
52 .get(gate)
53 .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
54 gate.compute(angles)
55}
56
57#[must_use]
59pub fn load_gate_adjoint(gate: &str, angles: &[f64]) -> DataTensor {
60 let gates = &GATES.read().unwrap();
61 let gate = gates
62 .get(gate)
63 .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
64 gate.adjoint(angles)
65}
66
67#[must_use]
69pub fn is_gate_known(gate: &str) -> bool {
70 let gates = &GATES.read().unwrap();
71 gates.contains(gate)
72}
73
74fn matrix_transpose_inplace(data: &mut DataTensor) {
82 if data.ndim() > 0 {
83 assert!(data.ndim().is_power_of_two());
84 let half = data.ndim() / 2;
85 let perm = (half..data.ndim()).chain(0..half).collect_vec();
86 data.transpose(&Permutation::oneline(perm));
87 }
88}
89
90pub(crate) fn matrix_adjoint_inplace(data: &mut DataTensor) {
96 matrix_transpose_inplace(data);
97 data.conjugate();
98}
99
/// Converts an angle slice into a fixed-size array of exactly `N` angles.
///
/// # Panics
/// Panics if `angles.len() != N`.
#[inline]
fn unpack_angles<const N: usize>(angles: &[f64]) -> [f64; N] {
    let converted: Result<[f64; N], _> = angles.try_into();
    match converted {
        Ok(array) => array,
        Err(_) => panic!("Expected {N} angles, but got {}.", angles.len()),
    }
}
107
108pub trait Gate: Send + Sync {
110 fn name(&self) -> &str;
112
113 fn compute(&self, angles: &[f64]) -> DataTensor;
115
116 fn adjoint(&self, angles: &[f64]) -> DataTensor {
119 let mut matrix = self.compute(angles);
120 matrix_adjoint_inplace(&mut matrix);
121 matrix
122 }
123}
124
125impl PartialEq for dyn Gate {
126 fn eq(&self, other: &Self) -> bool {
127 self.name() == other.name()
128 }
129}
130
131impl Eq for dyn Gate {}
132
133impl Hash for dyn Gate {
134 fn hash<H: Hasher>(&self, state: &mut H) {
135 self.name().hash(state);
136 }
137}
138
139impl Borrow<str> for Box<dyn Gate> {
141 fn borrow(&self) -> &str {
142 self.name()
143 }
144}
145
146struct X;
148impl Gate for X {
149 fn name(&self) -> &str {
150 "x"
151 }
152
153 fn compute(&self, angles: &[f64]) -> DataTensor {
154 let [] = unpack_angles(angles);
155 let z = Complex64::ZERO;
156 let o = Complex64::ONE;
157 #[rustfmt::skip]
158 let data = vec![
159 z, o,
160 o, z,
161 ];
162 DataTensor::new_from_flat(&[2, 2], data, None)
163 }
164
165 fn adjoint(&self, angles: &[f64]) -> DataTensor {
166 self.compute(angles)
168 }
169}
170
171struct Y;
173impl Gate for Y {
174 fn name(&self) -> &str {
175 "y"
176 }
177
178 fn compute(&self, angles: &[f64]) -> DataTensor {
179 let [] = unpack_angles(angles);
180 let z = Complex64::ZERO;
181 let i = Complex64::I;
182 #[rustfmt::skip]
183 let data = vec![
184 z, -i,
185 i, z,
186 ];
187 DataTensor::new_from_flat(&[2, 2], data, None)
188 }
189
190 fn adjoint(&self, angles: &[f64]) -> DataTensor {
191 self.compute(angles)
193 }
194}
195
196struct Z;
198impl Gate for Z {
199 fn name(&self) -> &str {
200 "z"
201 }
202
203 fn compute(&self, angles: &[f64]) -> DataTensor {
204 let [] = unpack_angles(angles);
205 let z = Complex64::ZERO;
206 let o = Complex64::ONE;
207 #[rustfmt::skip]
208 let data = vec![
209 o, z,
210 z, -o,
211 ];
212 DataTensor::new_from_flat(&[2, 2], data, None)
213 }
214
215 fn adjoint(&self, angles: &[f64]) -> DataTensor {
216 self.compute(angles)
218 }
219}
220
221struct H;
223impl Gate for H {
224 fn name(&self) -> &str {
225 "h"
226 }
227
228 fn compute(&self, angles: &[f64]) -> DataTensor {
229 let [] = unpack_angles(angles);
230 let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
231 #[rustfmt::skip]
232 let data = vec![
233 h, h,
234 h, -h,
235 ];
236 DataTensor::new_from_flat(&[2, 2], data, None)
237 }
238
239 fn adjoint(&self, angles: &[f64]) -> DataTensor {
240 self.compute(angles)
242 }
243}
244
245struct T;
247impl Gate for T {
248 fn name(&self) -> &str {
249 "t"
250 }
251
252 fn compute(&self, angles: &[f64]) -> DataTensor {
253 let [] = unpack_angles(angles);
254 let z = Complex64::ZERO;
255 let o = Complex64::ONE;
256 let t = Complex64::new(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
257 #[rustfmt::skip]
258 let data = vec![
259 o, z,
260 z, t,
261 ];
262 DataTensor::new_from_flat(&[2, 2], data, None)
263 }
264
265 fn adjoint(&self, angles: &[f64]) -> DataTensor {
266 let mut matrix = self.compute(angles);
268 matrix.conjugate();
269 matrix
270 }
271}
272
273struct U;
275impl Gate for U {
276 fn name(&self) -> &str {
277 "u"
278 }
279
280 fn compute(&self, angles: &[f64]) -> DataTensor {
281 let [theta, phi, lambda] = unpack_angles(angles);
282 let (sin, cos) = (theta / 2.0).sin_cos();
283 let data = vec![
284 Complex64::new(cos, 0.0),
285 -(Complex64::I * lambda).exp() * sin,
286 (Complex64::I * phi).exp() * sin,
287 (Complex64::I * (phi + lambda)).exp() * cos,
288 ];
289 DataTensor::new_from_flat(&[2, 2], data, None)
290 }
291
292 fn adjoint(&self, angles: &[f64]) -> DataTensor {
293 let [theta, phi, lambda] = unpack_angles(angles);
295 let (sin, cos) = (theta / 2.0).sin_cos();
296 let data = vec![
297 Complex64::new(cos, 0.0),
298 (Complex64::I * -phi).exp() * sin,
299 -(Complex64::I * -lambda).exp() * sin,
300 (Complex64::I * -(phi + lambda)).exp() * cos,
301 ];
302 DataTensor::new_from_flat(&[2, 2], data, None)
303 }
304}
305
306struct Sx;
308impl Gate for Sx {
309 fn name(&self) -> &str {
310 "sx"
311 }
312
313 fn compute(&self, angles: &[f64]) -> DataTensor {
314 let [] = unpack_angles(angles);
315 let a = Complex64::new(0.5, 0.5);
316 let b = Complex64::new(0.5, -0.5);
317 #[rustfmt::skip]
318 let data = vec![
319 a, b,
320 b, a,
321 ];
322 DataTensor::new_from_flat(&[2, 2], data, None)
323 }
324
325 fn adjoint(&self, angles: &[f64]) -> DataTensor {
326 let mut matrix = self.compute(angles);
328 matrix.conjugate();
329 matrix
330 }
331}
332
333struct Sy;
335impl Gate for Sy {
336 fn name(&self) -> &str {
337 "sy"
338 }
339
340 fn compute(&self, angles: &[f64]) -> DataTensor {
341 let [] = unpack_angles(angles);
342 let a = Complex64::new(0.5, 0.5);
343 let b = Complex64::new(-0.5, -0.5);
344 #[rustfmt::skip]
345 let data = vec![
346 a, b,
347 a, a,
348 ];
349 DataTensor::new_from_flat(&[2, 2], data, None)
350 }
351}
352
353struct Sz;
355impl Gate for Sz {
356 fn name(&self) -> &str {
357 "sz"
358 }
359
360 fn compute(&self, angles: &[f64]) -> DataTensor {
361 let [] = unpack_angles(angles);
362 let z = Complex64::ZERO;
363 let o = Complex64::ONE;
364 let i = Complex64::I;
365 #[rustfmt::skip]
366 let data = vec![
367 o, z,
368 z, i,
369 ];
370 DataTensor::new_from_flat(&[2, 2], data, None)
371 }
372
373 fn adjoint(&self, angles: &[f64]) -> DataTensor {
374 let mut matrix = self.compute(angles);
376 matrix.conjugate();
377 matrix
378 }
379}
380
381struct Rx;
383impl Gate for Rx {
384 fn name(&self) -> &str {
385 "rx"
386 }
387
388 fn compute(&self, angles: &[f64]) -> DataTensor {
389 let [theta] = unpack_angles(angles);
390 let (sin, cos) = (theta / 2.0).sin_cos();
391 let o = Complex64::ONE;
392 let i = Complex64::I;
393 #[rustfmt::skip]
394 let data = vec![
395 o*cos, -i*sin,
396 -i*sin, o*cos,
397 ];
398 DataTensor::new_from_flat(&[2, 2], data, None)
399 }
400
401 fn adjoint(&self, angles: &[f64]) -> DataTensor {
402 self.compute(&angles.iter().map(|&x| -x).collect_vec())
404 }
405}
406
407struct Ry;
409impl Gate for Ry {
410 fn name(&self) -> &str {
411 "ry"
412 }
413
414 fn compute(&self, angles: &[f64]) -> DataTensor {
415 let [theta] = unpack_angles(angles);
416 let (sin, cos) = (theta / 2.0).sin_cos();
417 let o = Complex64::ONE;
418
419 #[rustfmt::skip]
420 let data = vec![
421 o*cos, -o*sin,
422 o*sin, o*cos,
423 ];
424 DataTensor::new_from_flat(&[2, 2], data, None)
425 }
426
427 fn adjoint(&self, angles: &[f64]) -> DataTensor {
428 self.compute(&angles.iter().map(|&x| -x).collect_vec())
430 }
431}
432
433struct Rz;
435impl Gate for Rz {
436 fn name(&self) -> &str {
437 "rz"
438 }
439
440 fn compute(&self, angles: &[f64]) -> DataTensor {
441 let [theta] = unpack_angles(angles);
442 let z = Complex64::ZERO;
443 let i = Complex64::I;
444
445 #[rustfmt::skip]
446 let data = vec![
447 (-i*theta/2.0).exp(), z,
448 z, (i*theta/2.0).exp(),
449 ];
450 DataTensor::new_from_flat(&[2, 2], data, None)
451 }
452
453 fn adjoint(&self, angles: &[f64]) -> DataTensor {
454 self.compute(&angles.iter().map(|&x| -x).collect_vec())
456 }
457}
458
459struct Cx;
461impl Gate for Cx {
462 fn name(&self) -> &str {
463 "cx"
464 }
465
466 fn compute(&self, angles: &[f64]) -> DataTensor {
467 let [] = unpack_angles(angles);
468 let z = Complex64::ZERO;
469 let o = Complex64::ONE;
470 #[rustfmt::skip]
471 let data = vec![
472 o, z, z, z,
473 z, o, z, z,
474 z, z, z, o,
475 z, z, o, z,
476 ];
477 DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
478 }
479
480 fn adjoint(&self, angles: &[f64]) -> DataTensor {
481 self.compute(angles)
483 }
484}
485
486struct Cz;
488impl Gate for Cz {
489 fn name(&self) -> &str {
490 "cz"
491 }
492
493 fn compute(&self, angles: &[f64]) -> DataTensor {
494 let [] = unpack_angles(angles);
495 let z = Complex64::ZERO;
496 let o = Complex64::ONE;
497 #[rustfmt::skip]
498 let data = vec![
499 o, z, z, z,
500 z, o, z, z,
501 z, z, o, z,
502 z, z, z, -o,
503 ];
504 DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
505 }
506
507 fn adjoint(&self, angles: &[f64]) -> DataTensor {
508 self.compute(angles)
510 }
511}
512
513struct Cp;
515impl Gate for Cp {
516 fn name(&self) -> &str {
517 "cp"
518 }
519
520 fn compute(&self, angles: &[f64]) -> DataTensor {
521 let [theta] = unpack_angles(angles);
522 let z = Complex64::ZERO;
523 let o = Complex64::ONE;
524 let e = (Complex64::I * theta).exp();
525 #[rustfmt::skip]
526 let data = vec![
527 o, z, z, z,
528 z, o, z, z,
529 z, z, o, z,
530 z, z, z, e,
531 ];
532 DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
533 }
534
535 fn adjoint(&self, angles: &[f64]) -> DataTensor {
536 self.compute(&angles.iter().map(|&x| -x).collect_vec())
538 }
539}
540
541struct Iswap;
543impl Gate for Iswap {
544 fn name(&self) -> &str {
545 "iswap"
546 }
547
548 fn compute(&self, angles: &[f64]) -> DataTensor {
549 let [] = unpack_angles(angles);
550 let z = Complex64::ZERO;
551 let o = Complex64::ONE;
552 let i = Complex64::I;
553 #[rustfmt::skip]
554 let data = vec![
555 o, z, z, z,
556 z, z, i, z,
557 z, i, z, z,
558 z, z, z, o,
559 ];
560 DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
561 }
562
563 fn adjoint(&self, angles: &[f64]) -> DataTensor {
564 let mut matrix = self.compute(angles);
566 matrix.conjugate();
567 matrix
568 }
569}
570
571struct Fsim;
573impl Gate for Fsim {
574 fn name(&self) -> &str {
575 "fsim"
576 }
577
578 fn compute(&self, angles: &[f64]) -> DataTensor {
579 let [theta, phi] = unpack_angles(angles);
580 let z = Complex64::ZERO;
581 let o = Complex64::ONE;
582 let a = Complex64::new(theta.cos(), 0.0);
583 let b = Complex64::new(0.0, -theta.sin());
584 let c = Complex64::new(0.0, -phi).exp();
585 #[rustfmt::skip]
586 let data = vec![
587 o, z, z, z,
588 z, a, b, z,
589 z, b, a, z,
590 z, z, z, c,
591 ];
592 DataTensor::new_from_flat(&[2, 2, 2, 2], data, None)
593 }
594
595 fn adjoint(&self, angles: &[f64]) -> DataTensor {
596 self.compute(&angles.iter().map(|&x| -x).collect_vec())
598 }
599}
600
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use float_cmp::assert_approx_eq;
    use rand::{distr::Uniform, prelude::Distribution, rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use super::*;

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown() {
        let _ = load_gate("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown_adjoint() {
        let _ = load_gate_adjoint("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Expected 0 angles, but got 2.")]
    fn too_many_angles() {
        let [] = unpack_angles(&[2.0, 4.0]);
    }

    /// Checks every registered gate's `adjoint` against the generic
    /// compute-then-conjugate-transpose path, using random angles.
    #[test]
    fn test_custom_adjoint_impls() {
        // Number of angles each parametrized gate expects; all others take none.
        let gate_params = FxHashMap::from_iter([
            ("u", 3),
            ("rx", 1),
            ("ry", 1),
            ("rz", 1),
            ("cp", 1),
            ("fsim", 2),
        ]);
        let rng = StdRng::seed_from_u64(42);
        let dist = Uniform::new(-PI, PI).unwrap();
        let rng_iter = &mut dist.sample_iter(rng);

        for gate in GATES.read().unwrap().iter() {
            let param_count = gate_params.get(gate.name()).copied().unwrap_or_default();
            let params = rng_iter.take(param_count).collect_vec();
            let specialized_adjoint = gate.adjoint(&params);
            let mut matrix = gate.compute(&params);
            matrix_adjoint_inplace(&mut matrix);
            let general_adjoint = matrix;
            assert_approx_eq!(&DataTensor, &specialized_adjoint, &general_adjoint);
        }
    }
}