Commit c1d534e

bbarbakadze authored and agnesLeroy committed
refactor(gpu): refactor double2 operators to use cuda intrinsics
1 parent 47589ea commit c1d534e

File tree

1 file changed: +32 -44 lines changed

backends/tfhe-cuda-backend/cuda/src/types/complex/operations.cuh

Lines changed: 32 additions & 44 deletions
@@ -4,11 +4,6 @@
 #include <cstdint>
 #include <cstdio>
 
-#define SNT 1
-#define dPI 6.283185307179586231995926937088
-
-using sTorus = int32_t;
-// using Torus = uint32_t;
 using sTorus = int32_t;
 using u32 = uint32_t;
 using i32 = int32_t;
@@ -17,73 +12,66 @@ using i32 = int32_t;
 // Basic double2 operations
 
 __device__ inline double2 conjugate(const double2 num) {
-  double2 res;
-  res.x = num.x;
-  res.y = -num.y;
-  return res;
+  return {num.x, -num.y};
 }
 
 __device__ inline void operator+=(double2 &lh, const double2 rh) {
-  lh.x += rh.x;
-  lh.y += rh.y;
+  lh.x = __dadd_rn(lh.x, rh.x);
+  lh.y = __dadd_rn(lh.y, rh.y);
 }
 
 __device__ inline void operator-=(double2 &lh, const double2 rh) {
-  lh.x -= rh.x;
-  lh.y -= rh.y;
+  lh.x = __dsub_rn(lh.x, rh.x);
+  lh.y = __dsub_rn(lh.y, rh.y);
 }
 
 __device__ inline double2 operator+(const double2 a, const double2 b) {
-  double2 res;
-  res.x = a.x + b.x;
-  res.y = a.y + b.y;
-  return res;
+  return {__dadd_rn(a.x, b.x), __dadd_rn(a.y, b.y)};
 }
 
 __device__ inline double2 operator-(const double2 a, const double2 b) {
-  double2 res;
-  res.x = a.x - b.x;
-  res.y = a.y - b.y;
-  return res;
+  return {__dsub_rn(a.x, b.x), __dsub_rn(a.y, b.y)};
 }
 
+// Fused multiply-add/subtract for complex multiplication
 __device__ inline double2 operator*(const double2 a, const double2 b) {
-  double2 res;
-  res.x = (a.y * -b.y) + (a.x * b.x);
-  res.y = (a.x * b.y) + (a.y * b.x);
-  return res;
+  return {
+      __fma_rn(a.x, b.x,
+               -__dmul_rn(a.y, b.y)), // Real part: a.x * b.x - a.y * b.y
+      __fma_rn(a.x, b.y,
+               __dmul_rn(a.y, b.x)) // Imaginary part: a.x * b.y + a.y * b.x
+  };
 }
 
-__device__ inline double2 operator*(const double2 a, double b) {
-  double2 res;
-  res.x = a.x * b;
-  res.y = a.y * b;
-  return res;
+// Fused complex multiplication assignment (avoiding temporary storage)
+__device__ inline void operator*=(double2 &a, const double2 b) {
+  double real = __fma_rn(a.x, b.x, -__dmul_rn(a.y, b.y));
+  a.y = __fma_rn(
+      a.x, b.y,
+      __dmul_rn(a.y,
+                b.x)); // Update imag first to prevent register reuse issues
+  a.x = real;
 }
 
-__device__ inline void operator*=(double2 &a, const double2 b) {
-  double tmp = a.x;
-  a.x *= b.x;
-  a.x -= a.y * b.y;
-  a.y *= b.x;
-  a.y += b.y * tmp;
+__device__ inline double2 operator*(const double2 a, double b) {
+  return {__dmul_rn(a.x, b), __dmul_rn(a.y, b)};
 }
 
+// Direct multiplication with scalar
 __device__ inline void operator*=(double2 &a, const double b) {
-  a.x *= b;
-  a.y *= b;
+  a.x = __dmul_rn(a.x, b);
+  a.y = __dmul_rn(a.y, b);
 }
 
+// Fused division (could be improved with reciprocal if division is frequent)
 __device__ inline void operator/=(double2 &a, const double b) {
-  a.x /= b;
-  a.y /= b;
+  double inv_b = __drcp_rn(b); // Use reciprocal for faster division
+  a.x = __dmul_rn(a.x, inv_b);
+  a.y = __dmul_rn(a.y, inv_b);
}
 
 __device__ inline double2 operator*(double a, double2 b) {
-  double2 res;
-  res.x = b.x * a;
-  res.y = b.y * a;
-  return res;
+  return {__dmul_rn(b.x, a), __dmul_rn(b.y, a)};
 }
 
 #endif
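
For context, here is a minimal standalone sketch (not part of this commit) of how the intrinsic-based complex product can be checked against the textbook formula (a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x). The file and function names below (check_double2_mul.cu, mul_intrinsic, mul_reference, compare) are hypothetical; the sketch only assumes CUDA's documented double-precision intrinsics __fma_rn and __dmul_rn, both of which round to nearest even.

// check_double2_mul.cu -- hypothetical test harness, not part of the commit
#include <cstdio>
#include <cuda_runtime.h>

// Same formulation as the refactored operator*: one __fma_rn per component,
// with the cross term negated for the real part.
__device__ inline double2 mul_intrinsic(const double2 a, const double2 b) {
  return {__fma_rn(a.x, b.x, -__dmul_rn(a.y, b.y)),
          __fma_rn(a.x, b.y, __dmul_rn(a.y, b.x))};
}

// Plain C++ reference for (a.x + i*a.y) * (b.x + i*b.y).
__device__ inline double2 mul_reference(const double2 a, const double2 b) {
  return {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}

__global__ void compare(double2 a, double2 b, double *max_err) {
  double2 p = mul_intrinsic(a, b);
  double2 q = mul_reference(a, b);
  *max_err = fmax(fabs(p.x - q.x), fabs(p.y - q.y));
}

int main() {
  double *d_err = nullptr;
  cudaMalloc(&d_err, sizeof(double));
  compare<<<1, 1>>>(make_double2(1.5, -2.25), make_double2(0.75, 3.0), d_err);
  double h_err = 0.0;
  cudaMemcpy(&h_err, d_err, sizeof(double), cudaMemcpyDeviceToHost);
  // The fused form computes the product and the final add with a single
  // rounding, so any difference from the reference stays within roughly one
  // rounding step of the result's magnitude.
  printf("max component difference: %g\n", h_err);
  cudaFree(d_err);
  return 0;
}

Writing __fma_rn explicitly keeps the multiply and add fused with a single rounding regardless of the compiler's --fmad setting, which is a common reason to prefer it over letting nvcc decide whether to contract a.x * b.x - a.y * b.y on its own.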
