1use std::{
4 borrow::Borrow,
5 f64::consts::FRAC_1_SQRT_2,
6 hash::{Hash, Hasher},
7 sync::{LazyLock, RwLock},
8};
9
10use itertools::Itertools;
11use ndarray::array;
12use num_complex::Complex64;
13use rustc_hash::FxHashSet;
14
15use crate::{tensornetwork::tensordata::DataTensor, utils::traits::ConjugateInPlace};
16
17static GATES: LazyLock<RwLock<FxHashSet<Box<dyn Gate>>>> = LazyLock::new(|| {
18 let mut gates = FxHashSet::default();
19 gates.insert(Box::new(X) as _);
20 gates.insert(Box::new(Y) as _);
21 gates.insert(Box::new(Z) as _);
22 gates.insert(Box::new(H) as _);
23 gates.insert(Box::new(T) as _);
24 gates.insert(Box::new(U) as _);
25 gates.insert(Box::new(Sx) as _);
26 gates.insert(Box::new(Sy) as _);
27 gates.insert(Box::new(Sz) as _);
28 gates.insert(Box::new(Rx) as _);
29 gates.insert(Box::new(Ry) as _);
30 gates.insert(Box::new(Rz) as _);
31 gates.insert(Box::new(Cx) as _);
32 gates.insert(Box::new(Cz) as _);
33 gates.insert(Box::new(Swap) as _);
34 gates.insert(Box::new(Cp) as _);
35 gates.insert(Box::new(Iswap) as _);
36 gates.insert(Box::new(Fsim) as _);
37 RwLock::new(gates)
38});
39
40pub fn register_gate(gate: Box<dyn Gate>) {
42 assert!(
43 gate.name().to_ascii_lowercase() == gate.name(),
44 "Gate name must be lowercase."
45 );
46 GATES.write().unwrap().insert(gate);
47}
48
49#[must_use]
51pub fn load_gate(gate: &str, angles: &[f64]) -> DataTensor {
52 let gates = &GATES.read().unwrap();
53 let gate = gates
54 .get(gate)
55 .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
56 gate.compute(angles)
57}
58
59#[must_use]
61pub fn load_gate_adjoint(gate: &str, angles: &[f64]) -> DataTensor {
62 let gates = &GATES.read().unwrap();
63 let gate = gates
64 .get(gate)
65 .unwrap_or_else(|| panic!("Gate '{gate}' not found."));
66 gate.adjoint(angles)
67}
68
69#[must_use]
71pub fn is_gate_known(gate: &str) -> bool {
72 let gates = &GATES.read().unwrap();
73 gates.contains(gate)
74}
75
76fn matrix_transpose_inplace(data: &mut DataTensor) {
84 if data.ndim() > 0 {
85 assert!(data.ndim().is_power_of_two());
86 let half = data.ndim() / 2;
87 let perm = (half..data.ndim()).chain(0..half).collect_vec();
88 let used_data = std::mem::take(data);
89 *data = used_data.permuted_axes(perm);
90 }
91}
92
93pub(crate) fn matrix_adjoint_inplace(data: &mut DataTensor) {
99 matrix_transpose_inplace(data);
100 data.conjugate();
101}
102
/// Converts a slice of gate angles into an array of exactly `N` values.
///
/// # Panics
/// Panics with `Expected {N} angles, but got {len}.` when `angles.len() != N`.
#[inline]
fn unpack_angles<const N: usize>(angles: &[f64]) -> [f64; N] {
    match <[f64; N]>::try_from(angles) {
        Ok(unpacked) => unpacked,
        Err(_) => panic!("Expected {N} angles, but got {}.", angles.len()),
    }
}
110
111pub trait Gate: Send + Sync {
113 fn name(&self) -> &str;
115
116 fn compute(&self, angles: &[f64]) -> DataTensor;
118
119 fn adjoint(&self, angles: &[f64]) -> DataTensor {
122 let mut matrix = self.compute(angles);
123 matrix_adjoint_inplace(&mut matrix);
124 matrix
125 }
126}
127
128impl PartialEq for dyn Gate {
129 fn eq(&self, other: &Self) -> bool {
130 self.name() == other.name()
131 }
132}
133
134impl Eq for dyn Gate {}
135
136impl Hash for dyn Gate {
137 fn hash<H: Hasher>(&self, state: &mut H) {
138 self.name().hash(state);
139 }
140}
141
142impl Borrow<str> for Box<dyn Gate> {
144 fn borrow(&self) -> &str {
145 self.name()
146 }
147}
148
149struct X;
151impl Gate for X {
152 fn name(&self) -> &str {
153 "x"
154 }
155
156 fn compute(&self, angles: &[f64]) -> DataTensor {
157 let [] = unpack_angles(angles);
158 let z = Complex64::ZERO;
159 let o = Complex64::ONE;
160 array![[z, o], [o, z],].into_dyn()
161 }
162
163 fn adjoint(&self, angles: &[f64]) -> DataTensor {
164 self.compute(angles)
166 }
167}
168
169struct Y;
171impl Gate for Y {
172 fn name(&self) -> &str {
173 "y"
174 }
175
176 fn compute(&self, angles: &[f64]) -> DataTensor {
177 let [] = unpack_angles(angles);
178 let z = Complex64::ZERO;
179 let i = Complex64::I;
180 array![[z, -i], [i, z]].into_dyn()
181 }
182
183 fn adjoint(&self, angles: &[f64]) -> DataTensor {
184 self.compute(angles)
186 }
187}
188
189struct Z;
191impl Gate for Z {
192 fn name(&self) -> &str {
193 "z"
194 }
195
196 fn compute(&self, angles: &[f64]) -> DataTensor {
197 let [] = unpack_angles(angles);
198 let z = Complex64::ZERO;
199 let o = Complex64::ONE;
200 array![[o, z], [z, -o]].into_dyn()
201 }
202
203 fn adjoint(&self, angles: &[f64]) -> DataTensor {
204 self.compute(angles)
206 }
207}
208
209struct H;
211impl Gate for H {
212 fn name(&self) -> &str {
213 "h"
214 }
215
216 fn compute(&self, angles: &[f64]) -> DataTensor {
217 let [] = unpack_angles(angles);
218 let h = Complex64::new(FRAC_1_SQRT_2, 0.0);
219 array![[h, h], [h, -h]].into_dyn()
220 }
221
222 fn adjoint(&self, angles: &[f64]) -> DataTensor {
223 self.compute(angles)
225 }
226}
227
228struct T;
230impl Gate for T {
231 fn name(&self) -> &str {
232 "t"
233 }
234
235 fn compute(&self, angles: &[f64]) -> DataTensor {
236 let [] = unpack_angles(angles);
237 let z = Complex64::ZERO;
238 let o = Complex64::ONE;
239 let t = Complex64::new(FRAC_1_SQRT_2, FRAC_1_SQRT_2);
240 array![[o, z], [z, t]].into_dyn()
241 }
242
243 fn adjoint(&self, angles: &[f64]) -> DataTensor {
244 let mut matrix = self.compute(angles);
246 matrix.conjugate();
247 matrix
248 }
249}
250
251struct U;
253impl Gate for U {
254 fn name(&self) -> &str {
255 "u"
256 }
257
258 fn compute(&self, angles: &[f64]) -> DataTensor {
259 let [theta, phi, lambda] = unpack_angles(angles);
260 let (sin, cos) = (theta / 2.0).sin_cos();
261 array![
262 [
263 Complex64::new(cos, 0.0),
264 -(Complex64::I * lambda).exp() * sin
265 ],
266 [
267 (Complex64::I * phi).exp() * sin,
268 (Complex64::I * (phi + lambda)).exp() * cos
269 ]
270 ]
271 .into_dyn()
272 }
273
274 fn adjoint(&self, angles: &[f64]) -> DataTensor {
275 let [theta, phi, lambda] = unpack_angles(angles);
277 let (sin, cos) = (theta / 2.0).sin_cos();
278 array![
279 [Complex64::new(cos, 0.0), (Complex64::I * -phi).exp() * sin],
280 [
281 -(Complex64::I * -lambda).exp() * sin,
282 (Complex64::I * -(phi + lambda)).exp() * cos
283 ]
284 ]
285 .into_dyn()
286 }
287}
288
289struct Sx;
291impl Gate for Sx {
292 fn name(&self) -> &str {
293 "sx"
294 }
295
296 fn compute(&self, angles: &[f64]) -> DataTensor {
297 let [] = unpack_angles(angles);
298 let a = Complex64::new(0.5, 0.5);
299 let b = Complex64::new(0.5, -0.5);
300 array![[a, b], [b, a]].into_dyn()
301 }
302
303 fn adjoint(&self, angles: &[f64]) -> DataTensor {
304 let mut matrix = self.compute(angles);
306 matrix.conjugate();
307 matrix
308 }
309}
310
311struct Sy;
313impl Gate for Sy {
314 fn name(&self) -> &str {
315 "sy"
316 }
317
318 fn compute(&self, angles: &[f64]) -> DataTensor {
319 let [] = unpack_angles(angles);
320 let a = Complex64::new(0.5, 0.5);
321 let b = Complex64::new(-0.5, -0.5);
322 array![[a, b], [a, a]].into_dyn()
323 }
324}
325
326struct Sz;
328impl Gate for Sz {
329 fn name(&self) -> &str {
330 "sz"
331 }
332
333 fn compute(&self, angles: &[f64]) -> DataTensor {
334 let [] = unpack_angles(angles);
335 let z = Complex64::ZERO;
336 let o = Complex64::ONE;
337 let i = Complex64::I;
338 array![[o, z], [z, i]].into_dyn()
339 }
340
341 fn adjoint(&self, angles: &[f64]) -> DataTensor {
342 let mut matrix = self.compute(angles);
344 matrix.conjugate();
345 matrix
346 }
347}
348
349struct Rx;
351impl Gate for Rx {
352 fn name(&self) -> &str {
353 "rx"
354 }
355
356 fn compute(&self, angles: &[f64]) -> DataTensor {
357 let [theta] = unpack_angles(angles);
358 let (sin, cos) = (theta / 2.0).sin_cos();
359 let o = Complex64::ONE;
360 let i = Complex64::I;
361 array![[o * cos, -i * sin], [-i * sin, o * cos]].into_dyn()
362 }
363
364 fn adjoint(&self, angles: &[f64]) -> DataTensor {
365 self.compute(&angles.iter().map(|&x| -x).collect_vec())
367 }
368}
369
370struct Ry;
372impl Gate for Ry {
373 fn name(&self) -> &str {
374 "ry"
375 }
376
377 fn compute(&self, angles: &[f64]) -> DataTensor {
378 let [theta] = unpack_angles(angles);
379 let (sin, cos) = (theta / 2.0).sin_cos();
380 let o = Complex64::ONE;
381
382 array![[o * cos, -o * sin], [o * sin, o * cos]].into_dyn()
383 }
384
385 fn adjoint(&self, angles: &[f64]) -> DataTensor {
386 self.compute(&angles.iter().map(|&x| -x).collect_vec())
388 }
389}
390
391struct Rz;
393impl Gate for Rz {
394 fn name(&self) -> &str {
395 "rz"
396 }
397
398 fn compute(&self, angles: &[f64]) -> DataTensor {
399 let [theta] = unpack_angles(angles);
400 let z = Complex64::ZERO;
401 let i = Complex64::I;
402
403 array![[(-i * theta / 2.0).exp(), z], [z, (i * theta / 2.0).exp()]].into_dyn()
404 }
405
406 fn adjoint(&self, angles: &[f64]) -> DataTensor {
407 self.compute(&angles.iter().map(|&x| -x).collect_vec())
409 }
410}
411
412struct Cx;
414impl Gate for Cx {
415 fn name(&self) -> &str {
416 "cx"
417 }
418
419 fn compute(&self, angles: &[f64]) -> DataTensor {
420 let [] = unpack_angles(angles);
421 let z = Complex64::ZERO;
422 let o = Complex64::ONE;
423 array![[o, z, z, z], [z, o, z, z], [z, z, z, o], [z, z, o, z]]
424 .into_shape_with_order((2, 2, 2, 2))
425 .unwrap()
426 .into_dyn()
427 }
428
429 fn adjoint(&self, angles: &[f64]) -> DataTensor {
430 self.compute(angles)
432 }
433}
434
435struct Cz;
437impl Gate for Cz {
438 fn name(&self) -> &str {
439 "cz"
440 }
441
442 fn compute(&self, angles: &[f64]) -> DataTensor {
443 let [] = unpack_angles(angles);
444 let z = Complex64::ZERO;
445 let o = Complex64::ONE;
446 array![[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, -o]]
447 .into_shape_with_order((2, 2, 2, 2))
448 .unwrap()
449 .into_dyn()
450 }
451
452 fn adjoint(&self, angles: &[f64]) -> DataTensor {
453 self.compute(angles)
455 }
456}
457
458struct Swap;
460impl Gate for Swap {
461 fn name(&self) -> &str {
462 "swap"
463 }
464
465 fn compute(&self, angles: &[f64]) -> DataTensor {
466 let [] = unpack_angles(angles);
467 let z = Complex64::ZERO;
468 let o = Complex64::ONE;
469 array![[o, z, z, z], [z, z, o, z], [z, o, z, z], [z, z, z, o]]
470 .into_shape_with_order((2, 2, 2, 2))
471 .unwrap()
472 .into_dyn()
473 }
474
475 fn adjoint(&self, angles: &[f64]) -> DataTensor {
476 self.compute(angles)
478 }
479}
480
481struct Cp;
483impl Gate for Cp {
484 fn name(&self) -> &str {
485 "cp"
486 }
487
488 fn compute(&self, angles: &[f64]) -> DataTensor {
489 let [theta] = unpack_angles(angles);
490 let z = Complex64::ZERO;
491 let o = Complex64::ONE;
492 let e = (Complex64::I * theta).exp();
493 array![[o, z, z, z], [z, o, z, z], [z, z, o, z], [z, z, z, e]]
494 .into_shape_with_order((2, 2, 2, 2))
495 .unwrap()
496 .into_dyn()
497 }
498
499 fn adjoint(&self, angles: &[f64]) -> DataTensor {
500 self.compute(&angles.iter().map(|&x| -x).collect_vec())
502 }
503}
504
505struct Iswap;
507impl Gate for Iswap {
508 fn name(&self) -> &str {
509 "iswap"
510 }
511
512 fn compute(&self, angles: &[f64]) -> DataTensor {
513 let [] = unpack_angles(angles);
514 let z = Complex64::ZERO;
515 let o = Complex64::ONE;
516 let i = Complex64::I;
517 array![[o, z, z, z], [z, z, i, z], [z, i, z, z], [z, z, z, o]]
518 .into_shape_with_order((2, 2, 2, 2))
519 .unwrap()
520 .into_dyn()
521 }
522
523 fn adjoint(&self, angles: &[f64]) -> DataTensor {
524 let mut matrix = self.compute(angles);
526 matrix.conjugate();
527 matrix
528 }
529}
530
531struct Fsim;
533impl Gate for Fsim {
534 fn name(&self) -> &str {
535 "fsim"
536 }
537
538 fn compute(&self, angles: &[f64]) -> DataTensor {
539 let [theta, phi] = unpack_angles(angles);
540 let z = Complex64::ZERO;
541 let o = Complex64::ONE;
542 let a = Complex64::new(theta.cos(), 0.0);
543 let b = Complex64::new(0.0, -theta.sin());
544 let c = Complex64::new(0.0, -phi).exp();
545 array![[o, z, z, z], [z, a, b, z], [z, b, a, z], [z, z, z, c]]
546 .into_shape_with_order((2, 2, 2, 2))
547 .unwrap()
548 .into_dyn()
549 }
550
551 fn adjoint(&self, angles: &[f64]) -> DataTensor {
552 self.compute(&angles.iter().map(|&x| -x).collect_vec())
554 }
555}
556
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use approx::assert_abs_diff_eq;
    use rand::{distr::Uniform, prelude::Distribution, rngs::StdRng, SeedableRng};
    use rustc_hash::FxHashMap;

    use super::*;

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown() {
        let _ = load_gate("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Gate 'foo' not found.")]
    fn load_unknown_adjoint() {
        let _ = load_gate_adjoint("foo", &[]);
    }

    #[test]
    #[should_panic(expected = "Expected 0 angles, but got 2.")]
    fn too_many_angles() {
        let [] = unpack_angles(&[2.0, 4.0]);
    }

    /// Checks every registered gate's `adjoint` (specialized or default) against
    /// the generic transpose-and-conjugate of its `compute` result. Angles are
    /// drawn from a seeded uniform distribution on `[-π, π)`.
    #[test]
    fn test_custom_adjoint_impls() {
        // Angle counts for the parametrized built-in gates; all others take none.
        let gate_params = FxHashMap::from_iter([
            ("u", 3),
            ("rx", 1),
            ("ry", 1),
            ("rz", 1),
            ("cp", 1),
            ("fsim", 2),
        ]);
        let rng = StdRng::seed_from_u64(42);
        let dist = Uniform::new(-PI, PI).unwrap();
        let rng_iter = &mut dist.sample_iter(rng);

        for gate in GATES.read().unwrap().iter() {
            let param_count = gate_params.get(gate.name()).copied().unwrap_or_default();
            let params = rng_iter.take(param_count).collect_vec();
            let specialized_adjoint = gate.adjoint(&params);
            let mut general_adjoint = gate.compute(&params);
            matrix_adjoint_inplace(&mut general_adjoint);
            assert_abs_diff_eq!(&specialized_adjoint, &general_adjoint);
        }
    }
}
609}