26 changes: 19 additions & 7 deletions src/op/atomic_add.cc
@@ -273,8 +273,13 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
src_value = Cast(dst->dtype, src_value);

// Build a pointer to destination element using tvm_access_ptr
PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(dst, dst_indices)});
BufferLoad dst_load = BufferLoad(dst, dst_indices);
Array<Range> dst_ranges;
for (const PrimExpr &index : dst_indices) {
dst_ranges.push_back(Range::FromMinExtent(index, 1));
}
BufferRegion dst_region = BufferRegion(dst, dst_ranges);
PrimExpr dst_ptr = MakeAccessPtrFromRegion(dst_region, 2); // 2 = write access

new_args.push_back(dst_ptr);
new_args.push_back(src_value);
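
For context, a minimal sketch of what a helper like MakeAccessPtrFromRegion is assumed to do: flatten the one-element region into an offset/extent pair and wrap it in a builtin::tvm_access_ptr() call. The real helper declared in src/op/utils.h may differ in details.

// Illustrative sketch only; assumed behavior of MakeAccessPtrFromRegion.
PrimExpr MakeAccessPtrFromRegionSketch(const BufferRegion &region, int rw_mask) {
  const Buffer &buf = region->buffer;
  // Flatten the per-dimension mins into a linear element offset.
  PrimExpr offset = make_const(DataType::Int(32), 0);
  PrimExpr stride = make_const(DataType::Int(32), 1);
  for (int i = static_cast<int>(region->region.size()) - 1; i >= 0; --i) {
    offset = offset + region->region[i]->min * stride;
    stride = stride * buf->shape[i];
  }
  // The regions built above are single-element, so the extent is just the
  // innermost range's extent (1).
  PrimExpr extent = region->region.empty()
                        ? make_const(DataType::Int(32), 1)
                        : region->region.back()->extent;
  return Call(DataType::Handle(), builtin::tvm_access_ptr(),
              {TypeAnnotation(buf->dtype), buf->data, offset, extent,
               make_const(DataType::Int(32), rw_mask)});
}
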
@@ -381,15 +386,22 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< "src_size = " << src_size << ", dst_size = " << dst_size;
BufferLoad src_node = BufferLoad(src, src_indices);
BufferLoad dst_node = BufferLoad(dst, dst_indices);
Call address_of_src =
Call(DataType::Handle(), builtin::address_of(), {src_node});
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
Array<Range> src_ranges, dst_ranges;
for (const PrimExpr &index : src_indices) {
src_ranges.push_back(Range::FromMinExtent(index, 1));
}
for (const PrimExpr &index : dst_indices) {
dst_ranges.push_back(Range::FromMinExtent(index, 1));
}
BufferRegion src_region = BufferRegion(src, src_ranges);
BufferRegion dst_region = BufferRegion(dst, dst_ranges);
PrimExpr src_ptr = MakeAccessPtrFromRegion(src_region, 1); // 1 = read access
PrimExpr dst_ptr = MakeAccessPtrFromRegion(dst_region, 2); // 2 = write access

int need_reduce = 1;
int eviction_policy = 0;
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
{address_of_src, address_of_dst,
{src_ptr, dst_ptr,
ceildiv(src_size * src->dtype.bits(), 8),
need_reduce, eviction_policy}));
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
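
The integer flags passed to MakeAccessPtrFromRegion follow the usual tvm_access_ptr convention (bit 0 = read, bit 1 = write), which the inline comments spell out. A possible follow-up, using hypothetical named constants, would make the call sites self-documenting:

// Hypothetical named masks for builtin::tvm_access_ptr's rw_mask argument.
enum AccessMask : int {
  kReadAccess = 1,      // 0b01: pointer may be read through
  kWriteAccess = 2,     // 0b10: pointer may be written through
  kReadWriteAccess = 3  // 0b11: both
};

PrimExpr src_ptr = MakeAccessPtrFromRegion(src_region, kReadAccess);
PrimExpr dst_ptr = MakeAccessPtrFromRegion(dst_region, kWriteAccess);
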
44 changes: 29 additions & 15 deletions src/transform/atomicadd_vectorize.cc
@@ -4,6 +4,7 @@
*/

#include "atomicadd_vectorize.h"
#include "../op/utils.h"

namespace tvm {
namespace tl {
@@ -233,22 +234,30 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
const IntImm memory_order =
node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);
Array<PrimExpr> new_args;
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
Call address_of_value =
Call(DataType::Handle(), builtin::address_of(), {value_node});
// Convert BufferLoad to access_ptr
Array<Range> dst_ranges, value_ranges;
for (const PrimExpr &index : dst_node->indices) {
dst_ranges.push_back(Range::FromMinExtent(index, 1));
}
for (const PrimExpr &index : value_node->indices) {
value_ranges.push_back(Range::FromMinExtent(index, 1));
}
BufferRegion dst_region = BufferRegion(dst_node->buffer, dst_ranges);
BufferRegion value_region = BufferRegion(value_node->buffer, value_ranges);
PrimExpr dst_ptr = MakeAccessPtrFromRegion(dst_region, 2); // 2 = write access
PrimExpr value_ptr = MakeAccessPtrFromRegion(value_region, 1); // 1 = read access
if (vector_size_ == 4) {
new_args.push_back(StringImm("AtomicAddx4"));
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
new_args.push_back(dst_ptr);
new_args.push_back(value_ptr);
} else if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2"));
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
new_args.push_back(dst_ptr);
new_args.push_back(value_ptr);
} else {
// Scalar case: AtomicAdd now expects a pointer to destination.
new_args.push_back(StringImm("AtomicAdd"));
new_args.push_back(address_of_dst);
new_args.push_back(dst_ptr);
new_args.push_back(value_node);
}
new_args.push_back(memory_order);
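
The BufferLoad-to-single-element-region conversion above repeats the same loop at every call site (here and in atomic_add.cc). A small helper along these lines (name hypothetical) would remove the duplication:

// Hypothetical helper: wrap a BufferLoad's indices into a one-element
// BufferRegion that MakeAccessPtrFromRegion can consume.
static BufferRegion PointRegionFromLoad(const BufferLoad &load) {
  Array<Range> ranges;
  for (const PrimExpr &index : load->indices) {
    ranges.push_back(Range::FromMinExtent(index, 1));
  }
  return BufferRegion(load->buffer, ranges);
}

// The call site above would then reduce to:
//   PrimExpr dst_ptr =
//       MakeAccessPtrFromRegion(PointRegionFromLoad(dst_node), 2);
//   PrimExpr value_ptr =
//       MakeAccessPtrFromRegion(PointRegionFromLoad(value_node), 1);
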
@@ -263,13 +272,18 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
// Ensure first argument is an address; keep value as-is.
if (!node->args.empty()) {
if (const auto *bl = node->args[0].as<BufferLoadNode>()) {
Call address_of_dst = Call(DataType::Handle(), builtin::address_of(),
{Downcast<BufferLoad>(node->args[0])});
new_args.push_back(address_of_dst);
// Convert BufferLoad to access_ptr
Array<Range> dst_ranges;
for (const PrimExpr &index : bl->indices) {
dst_ranges.push_back(Range::FromMinExtent(index, 1));
}
BufferRegion dst_region = BufferRegion(bl->buffer, dst_ranges);
PrimExpr dst_ptr = MakeAccessPtrFromRegion(dst_region, 2); // 2 = write access
new_args.push_back(dst_ptr);
} else if (const auto *call = node->args[0].as<CallNode>()) {
// If it's already an address_of, forward it; otherwise, keep
// original.
if (call->op.same_as(builtin::address_of())) {
// If it's already an address_of or access_ptr, forward it; otherwise, keep original.
if (call->op.same_as(builtin::address_of()) ||
call->op.same_as(builtin::tvm_access_ptr())) {
new_args.push_back(node->args[0]);
} else {
new_args.push_back(node->args[0]);
68 changes: 68 additions & 0 deletions src/transform/thread_storage_sync.cc
@@ -496,6 +496,70 @@ class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
StorageScope sync_scope_;
};

// This class inserts a storage sync (__syncthreads) after AtomicAdd calls
// whose destination lives in shared memory. The barrier is needed because
// atomic updates to shared memory must be made visible to all threads
// before subsequent reads/writes of the same buffer.
class ThreadSyncAfterAtomicInserter : public StmtExprMutator {
public:
explicit ThreadSyncAfterAtomicInserter(StorageScope sync_scope)
: sync_scope_(std::move(sync_scope)) {}

Stmt VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::call_extern()) && call->args.size() >= 1) {
if (const auto *func_name = call->args[0].as<StringImmNode>()) {
std::string name = func_name->value;
// Check if this is an AtomicAdd call (AtomicAdd, AtomicAddx2, AtomicAddx4, etc.)
if (name == "AtomicAdd" || name == "AtomicAddx2" || name == "AtomicAddx4" ||
name == "AtomicAddRet" || name == "AtomicAddx2Ret" || name == "AtomicAddx4Ret") {
// Check if the first argument (destination pointer) is an access_ptr to shared memory
if (call->args.size() >= 2) {
if (const auto *ptr_call = call->args[1].as<CallNode>()) {
if (ptr_call->op.same_as(builtin::tvm_access_ptr()) &&
ptr_call->args.size() >= 2) {
if (const auto *buffer_var = ptr_call->args[1].as<VarNode>()) {
StorageScope buffer_scope = StorageScope::Create(
GetPtrStorageScope(tvm::ffi::GetRef<Var>(buffer_var)));
// Check if the buffer is in shared memory scope
if (sync_scope_.rank == StorageRank::kShared &&
buffer_scope.rank == StorageRank::kShared) {
// Insert sync after the atomic operation
auto sync = Evaluate(Call(DataType::Int(32),
builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto ret = StmtExprMutator::VisitStmt_(op);
return SeqStmt({ret, sync});
}
}
} else if (ptr_call->op.same_as(builtin::address_of()) &&
ptr_call->args.size() >= 1) {
// Handle legacy address_of case (for backward compatibility)
if (const auto *load = ptr_call->args[0].as<BufferLoadNode>()) {
StorageScope buffer_scope = StorageScope::Create(
GetPtrStorageScope(load->buffer->data));
if (sync_scope_.rank == StorageRank::kShared &&
buffer_scope.rank == StorageRank::kShared) {
auto sync = Evaluate(Call(DataType::Int(32),
builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto ret = StmtExprMutator::VisitStmt_(op);
return SeqStmt({ret, sync});
}
}
}
}
}
}
}
}
}
return StmtExprMutator::VisitStmt_(op);
}

private:
StorageScope sync_scope_;
};
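
As a rough illustration of the intended effect (all names below are made up for the example), feeding a shared-memory AtomicAdd through the new inserter appends a storage sync right after it:

// Illustrative only; assumes the destination access_ptr's buffer variable
// carries "shared" storage scope in its pointer type annotation.
Var shared_data("C_shared",
                PointerType(PrimType(DataType::Float(16)), "shared"));
PrimExpr dst_ptr =
    Call(DataType::Handle(), builtin::tvm_access_ptr(),
         {TypeAnnotation(DataType::Float(16)), shared_data,
          /*offset=*/IntImm(DataType::Int(32), 0),
          /*extent=*/IntImm(DataType::Int(32), 2),
          /*rw_mask=*/IntImm(DataType::Int(32), 2)});
Stmt atomic = Evaluate(
    Call(DataType::Handle(), builtin::call_extern(),
         {StringImm("AtomicAddx2"), dst_ptr,
          /*value pointer (stand-in for the example)*/ dst_ptr,
          /*memory_order=*/IntImm(DataType::Int(32), 0)}));
Stmt synced =
    ThreadSyncAfterAtomicInserter(StorageScope::Create("shared"))(atomic);
// synced is SeqStmt({atomic, Evaluate(tvm_storage_sync("shared"))}) because
// the destination pointer's buffer variable is in shared memory.
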

class ThreadSyncInserter : public StmtExprMutator {
public:
ThreadSyncInserter(StorageScope sync_scope,
@@ -826,6 +890,10 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {

stmt =
ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
// Insert sync after AtomicAdd operations on shared memory
if (sync_scope.rank == StorageRank::kShared) {
stmt = ThreadSyncAfterAtomicInserter(sync_scope)(std::move(stmt));
}
n->body = ThreadPartialSyncRewriter::Rewrite(std::move(stmt));
return func;
}