Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ class TargetInfo : public mlir::triton::proton::gpu::TargetInfoBase {
return static_cast<const mlir::triton::AMD::TargetInfo &>(helper);
}

Value clock(ConversionPatternRewriter &rewriter, Location loc,
bool isClock64) const override;

Value globalTime(ConversionPatternRewriter &rewriter,
Location loc) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ class TargetInfo : public mlir::triton::proton::gpu::TargetInfoBase {
return static_cast<const mlir::triton::NVIDIA::TargetInfo &>(helper);
}

Value clock(ConversionPatternRewriter &rewriter, Location loc,
bool isClock64) const override;

Value globalTime(ConversionPatternRewriter &rewriter,
Location loc) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ class TargetInfoBase {
return helper;
}

// Return the local cycle counter value.
virtual Value clock(ConversionPatternRewriter &rewriter, Location loc,
bool isClock64) const = 0;

// Return the global cycle counter value (i.e., synchronized across SMs) in
// nanoseconds, regardless of the clock frequency.
virtual Value globalTime(ConversionPatternRewriter &rewriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ struct ReadCounterOpConversion
matchAndRewrite(mlir::triton::proton::gpu::ReadCounterOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool isClock64 = false;
auto intType = mlir::cast<IntegerType>(op.getResult().getType());
isClock64 = intType.getWidth() == 64;
Value clock = targetInfo.clock(rewriter, op.getLoc(), isClock64);
rewriter.replaceOp(op, clock);
auto time = targetInfo.globalTime(rewriter, op.getLoc());
auto trunc = arith::TruncIOp::create(rewriter, op.getLoc(), intType, time);
rewriter.replaceOp(op, trunc);
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@

namespace mlir::triton::proton::gpu::AMD {

// Emit a read of the local cycle counter for AMD targets.
//
// `isClock64` selects the width of the returned value: i64 when true, i32
// when false.
//
// NV has both a 32 bit and a 64 bit clock intrinsic; on AMD only s_memtime is
// available, and it is 64 bit. A 32 bit request is therefore served by
// truncating the 64 bit reading. That is expected to be fine: in 64 bits,
// 0x0000.0000.ffff.ffff is followed by 0x0000.0001.0000.0000, whose low
// 32 bits are zero, so the truncated value wraps from 0xffff.ffff to
// 0x0000.0000 just like a native 32 bit counter would.
Value TargetInfo::clock(ConversionPatternRewriter &rewriter, Location loc,
                        bool isClock64) const {
  // NOTE(review): `b` is not referenced by name below; kept in case the
  // project type macros rely on it being in scope -- confirm before removing.
  auto b = TritonLLVMOpBuilder(loc, rewriter);

  // Always read the full 64 bit counter first.
  Value memtime64 = LLVM::createLLVMIntrinsicCallOp(
                        rewriter, loc, "llvm.amdgcn.s.memtime", i64_ty, {})
                        .getResult(0);
  if (isClock64)
    return memtime64;

  // 32 bit request: keep only the low half of the reading.
  Value memtime32 = LLVM::TruncOp::create(rewriter, loc, i32_ty, memtime64);
  return memtime32;
}

Value TargetInfo::globalTime(ConversionPatternRewriter &rewriter,
Location loc) const {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,9 @@

namespace mlir::triton::proton::gpu::NVIDIA {

// Emit a read of the local cycle counter for NVIDIA targets.
//
// The low 32 bits come from the `%clock` special register. When `isClock64`
// is true, `%clock_hi` is read as well and the two halves are combined into
// an i64 as (hi << 32) | lo; otherwise the 32 bit `%clock` value is returned
// directly.
Value TargetInfo::clock(ConversionPatternRewriter &rewriter, Location loc,
                        bool isClock64) const {
  // Builds and launches `mov.u32 <dst>, <sregName>` as inline PTX and returns
  // the 32 bit result. The trailing `true` is forwarded to `launch` as-is
  // (presumably the side-effect flag -- confirm against PTXBuilder).
  auto readSpecialReg32 = [&](const std::string &sregName) -> Value {
    PTXBuilder ptx;
    auto &mov = ptx.create("mov")->o("u32");
    auto *dst = ptx.newOperand("=r");
    auto *src = ptx.newConstantOperand(sregName);
    mov(dst, src);
    return ptx.launch(rewriter, loc, rewriter.getIntegerType(32), true);
  };

  Value lo32 = readSpecialReg32("%clock");
  if (isClock64) {
    Value hi32 = readSpecialReg32("%clock_hi");

    // Widen both halves, then place the high half in the upper 32 bits.
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Value lo64 = b.zext(i64_ty, lo32);
    Value hi64 = b.zext(i64_ty, hi32);
    Value shiftBy = b.i64_val(32);
    Value hiShifted = b.shl(hi64, shiftBy);
    Value combined = b.or_(hiShifted, lo64);
    return combined;
  }

  return lo32;
}

// Emit a read of the global timer for NVIDIA targets and return it as an i64.
//
// globaltimer is a 64-bit global clock counter in nanoseconds.
// Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-globaltimer
Value TargetInfo::globalTime(ConversionPatternRewriter &rewriter,
                             Location loc) const {
  // NOTE(review): `b` is not referenced by name below; kept in case the
  // project type macros rely on it being in scope -- confirm before removing.
  auto b = TritonLLVMOpBuilder(loc, rewriter);

  // The register is read by calling the LLVM intrinsic by name; the call's
  // first (only) result is the timer value.
  StringRef globalTimeIntrinsicName = "llvm.nvvm.read.ptx.sreg.globaltimer";
  Value globalTimeVal = LLVM::createLLVMIntrinsicCallOp(
                            rewriter, loc, globalTimeIntrinsicName, i64_ty, {})
                            .getResult(0);

  // Single exit point: a second `return` following this one could never
  // execute, so the function ends here.
  return globalTimeVal;
}

Value TargetInfo::processorId(ConversionPatternRewriter &rewriter,
Expand Down
34 changes: 9 additions & 25 deletions third_party/proton/common/lib/TraceDataIO/TraceWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,8 @@ using BlockTraceVec =
std::vector<const CircularLayoutParserResult::BlockTrace *>;

void populateTraceInfo(std::shared_ptr<CircularLayoutParserResult> result,
std::map<int, uint64_t> &blockToMinCycle,
std::map<int, BlockTraceVec> &procToBlockTraces) {
for (auto &bt : result->blockTraces) {
// Find the minimum cycle for each block
uint64_t minCycle = std::numeric_limits<uint64_t>::max();
for (auto &trace : bt.traces)
for (auto &event : trace.profileEvents)
if (event.first->cycle < minCycle)
minCycle = event.first->cycle;
blockToMinCycle[bt.blockId] = minCycle;

// Group block traces by proc id
int procId = bt.procId;
if (!procToBlockTraces.count(procId)) {
Expand Down Expand Up @@ -174,12 +165,10 @@ void StreamChromeTraceWriter::writeKernel(json &object,
int curColorIndex = 0;
// scope id -> color index in chrome color
std::map<int, int> scopeColor;
// block id -> min cycle observed
std::map<int, uint64_t> blockToMinCycle;
// proc id -> block traces
std::map<int, BlockTraceVec> procToBlockTraces;

populateTraceInfo(result, blockToMinCycle, procToBlockTraces);
populateTraceInfo(result, procToBlockTraces);

std::string name;
std::string pid;
Expand Down Expand Up @@ -210,17 +199,13 @@ void StreamChromeTraceWriter::writeKernel(json &object,
else
name = metadata->scopeName.at(scopeId);

// Unit: MHz, we assume freq is 1000MHz (1GHz)
double freq = 1000.0;
// All cycle values are in nanoseconds (from globalTime).
const double USEC_PER_CYCLE = 1000.0;

// Global time is in `ns` unit. With 1GHz assumption, we
// could subtract with blockToMInCycle: (ns - ns) / 1GHz - cycle
int64_t cycleAdjust =
static_cast<int64_t>(bt->initTime - minInitTime) -
static_cast<int64_t>(blockToMinCycle[ctaId]);
int64_t ts = static_cast<int64_t>(event.first->cycle) + cycleAdjust;
int64_t dur =
static_cast<int64_t>(event.second->cycle) - event.first->cycle;
int64_t ts = static_cast<int64_t>(event.first->cycle) +
static_cast<int64_t>(minInitTime);
int64_t dur = static_cast<int64_t>(event.second->cycle) -
static_cast<int64_t>(event.first->cycle);

json element;
element["cname"] = color;
Expand All @@ -229,13 +214,12 @@ void StreamChromeTraceWriter::writeKernel(json &object,
element["ph"] = "X";
element["pid"] = pid;
element["tid"] = tid;
element["ts"] = static_cast<double>(ts) / freq;
element["dur"] = static_cast<double>(dur) / freq;
element["ts"] = static_cast<double>(ts) / USEC_PER_CYCLE;
element["dur"] = static_cast<double>(dur) / USEC_PER_CYCLE;
json args;
args["Init Time (ns)"] = bt->initTime;
args["Post Final Time (ns)"] = bt->postFinalTime;
args["Finalization Time (ns)"] = bt->postFinalTime - bt->preFinalTime;
args["Frequency (MHz)"] = freq;
element["args"] = args;
element["args"]["call_stack"] = callStack;

Expand Down