Skip to content

Commit 2d8d367

Browse files
authored
[CUDA] Enhance Broadcast Codegen for Symbolic Value (#1669)
* Enhance CUDA code generation for BroadcastNode by implementing compile-time constant folding and runtime broadcasting for various lane configurations. Improved handling for 4-bit and 8-bit integer types, ensuring correct replication and type casting in output expressions. This update increases performance and correctness in CUDA kernel generation. * add test * lint fix * fix
1 parent 4084dcd commit 2d8d367

File tree

2 files changed

+46
-26
lines changed

2 files changed

+46
-26
lines changed

src/target/codegen_cuda.cc

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3193,32 +3193,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
31933193
std::ostream &os) { // NOLINT(*)
31943194
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
31953195
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) {
3196-
if (lanes == 4) {
3197-
// make_int8x4
3198-
const int64_t *p = as_const_int(op->value);
3199-
ICHECK(p);
3200-
int64_t v = *p & 0xFF;
3201-
v = (v << 24) | (v << 16) | (v << 8) | v;
3202-
if (op->dtype.is_uint()) {
3203-
os << "(uint)" << v;
3204-
} else {
3205-
os << "(int)" << v;
3206-
}
3207-
return;
3208-
} else if (lanes == 32) {
3209-
// make_int8x32
3210-
const int64_t *p = as_const_int(op->value);
3211-
ICHECK(p);
3212-
int64_t v = *p & 0xFF;
3213-
v = (v << 24) | (v << 16) | (v << 8) | v;
3214-
if (op->dtype.is_uint()) {
3215-
os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v
3216-
<< ")";
3217-
} else {
3218-
os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v
3219-
<< ")";
3196+
const int64_t *p = as_const_int(op->value);
3197+
if (p) {
3198+
if (lanes == 4) {
3199+
// make_int8x4
3200+
ICHECK(p);
3201+
int64_t v = *p & 0xFF;
3202+
v = (v << 24) | (v << 16) | (v << 8) | v;
3203+
if (op->dtype.is_uint()) {
3204+
os << "(uint)" << v;
3205+
} else {
3206+
os << "(int)" << v;
3207+
}
3208+
return;
3209+
} else if (lanes == 32) {
3210+
// make_int8x32
3211+
const int64_t *p = as_const_int(op->value);
3212+
ICHECK(p);
3213+
int64_t v = *p & 0xFF;
3214+
v = (v << 24) | (v << 16) | (v << 8) | v;
3215+
if (op->dtype.is_uint()) {
3216+
os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v
3217+
<< ")";
3218+
} else {
3219+
os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v
3220+
<< ")";
3221+
}
3222+
return;
32203223
}
3221-
return;
32223224
}
32233225
}
32243226

@@ -3284,7 +3286,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
32843286
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
32853287
bool fail = false;
32863288
const int64_t *p = as_const_int(op->value);
3287-
ICHECK(p);
3289+
ICHECK(p) << "BroadcastNode " << op << " value: " << op->value
3290+
<< " is not a constant";
32883291
int64_t v = *p & 0xF;
32893292

32903293
if (lanes == 4) {

testing/python/language/test_tilelang_language_vectorize.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,22 @@ def test_vectorize_all_dtypes(dtype, vec_num):
148148
kernel(x)
149149

150150

151+
@tilelang.jit
152+
def vectorize_broadcast_int8(vec_num):
153+
with T.Kernel(1, threads=128):
154+
a = T.alloc_local((64,), "int8")
155+
b = T.alloc_var("int8")
156+
157+
for i in T.vectorized(vec_num):
158+
a[i] = b
159+
160+
161+
@tilelang.testing.requires_cuda
162+
@pytest.mark.parametrize("vec_num", [4, 32])
163+
def test_vectorize_broadcast_int8(vec_num):
164+
"""Test broadcasting a non-constant int8 value to a vectorized store."""
165+
vectorize_broadcast_int8.compile(vec_num=vec_num)
166+
167+
151168
if __name__ == "__main__":
152169
tilelang.testing.main()

0 commit comments

Comments
 (0)