Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ __intrinsic_op($(kIROp_AlignedAttr))
internal int __align_attr(int alignment);

__intrinsic_op($(kIROp_Load))
internal T __load_aligned<T>(T* ptr, int alignmentAttr);
internal T __load_aligned<T, Access access>(Ptr<T, access, AddressSpace.Device> ptr, int alignmentAttr);

__intrinsic_op($(kIROp_Store))
internal void __store_aligned<T>(T* ptr, T value, int alignmentAttr);
Expand All @@ -1443,7 +1443,7 @@ internal void __store_aligned<T>(T* ptr, T value, int alignmentAttr);
///
[__NoSideEffect]
[ForceInline]
T loadAligned<int alignment, T>(T* ptr)
T loadAligned<int alignment, T, Access access>(Ptr<T, access, AddressSpace.Device> ptr)
{
return __load_aligned(ptr, __align_attr(alignment));
}
Expand Down
10 changes: 10 additions & 0 deletions source/slang/slang-emit-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,16 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(
SLANG_UNUSED(entryPointDecor);
}

/// Writes a CUDA `__align__(N) ` attribute for `inst` when it carries an
/// `IRAlignmentDecoration`; `N` is the integer value of the decoration's
/// alignment operand. Emits nothing when the decoration is absent.
void CUDASourceEmitter::emitPostKeywordTypeAttributesImpl(IRInst* inst)
{
    auto alignmentDecor = inst->findDecoration<IRAlignmentDecoration>();
    if (!alignmentDecor)
        return;

    // Resolve the alignment first so the three emits below read as one unit.
    auto alignmentValue = getIntVal(alignmentDecor->getAlignmentOperand());

    m_writer->emit("__align__(");
    m_writer->emit(alignmentValue);
    m_writer->emit(") ");
}

void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst)
{
if (!inst)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-emit-cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class CUDASourceEmitter : public CPPSourceEmitter
virtual void emitEntryPointAttributesImpl(
IRFunc* irFunc,
IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) SLANG_OVERRIDE;

virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
SLANG_OVERRIDE;
Expand Down
4 changes: 3 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1924,9 +1924,11 @@ Result linkAndOptimizeIR(

// If we are generating code for CUDA, we should translate all immutable buffer loads to
// use the `__ldg` intrinsic for improved performance.
// For aligned loads, we also create a wrapper struct type with an alignment decoration,
// and rewrite the load so that it loads the wrapper struct instead.
if (isCUDATarget(targetRequest))
{
SLANG_PASS(lowerImmutableBufferLoadForCUDA, targetProgram);
SLANG_PASS(lowerImmutableOrAlignedBufferLoadForCUDA, targetProgram);
}

SLANG_PASS(performForceInlining);
Expand Down
107 changes: 100 additions & 7 deletions source/slang/slang-ir-cuda-immutable-load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,52 @@ struct LoadMethod
struct ImmutableBufferLoadLoweringContext : InstPassBase
{
Dictionary<IRType*, LoadMethod> loadFuncs;

// Lookup key that pairs a wrapped type with the alignment requested for it.
struct AlignedTypeWrapperKey
{
    IRType* innerType;
    IRIntegerValue alignment;

    // Keys are equal only when both the inner type pointer and the alignment match.
    bool operator==(const AlignedTypeWrapperKey& other) const
    {
        if (innerType != other.innerType)
            return false;
        return alignment == other.alignment;
    }

    // Hash is the combination of the hashes of both members.
    HashCode getHashCode() const
    {
        auto typeHash = Slang::getHashCode(innerType);
        auto alignmentHash = Slang::getHashCode(alignment);
        return combineHash(typeHash, alignmentHash);
    }
};

Dictionary<AlignedTypeWrapperKey, IRType*> alignedWrapperTypes;

// Returns a type that can be used to load `innerType` with at least `alignment`
// alignment.
//
// - If the natural alignment of `innerType` for the current target is already
//   >= `alignment`, `innerType` itself is returned and no wrapper is created.
// - Otherwise returns a struct type with a single field (name hint "val") of
//   type `innerType`, decorated with an alignment decoration of `alignment`.
//   Wrappers are cached in `alignedWrapperTypes` per (innerType, alignment),
//   so repeated requests return the same struct type.
IRType* getOrCreateAlignedWrapper(IRType* innerType, IRIntegerValue alignment)
{
// Query the target's natural layout; a wrapper is only needed when the
// requested alignment is stricter than the natural one.
IRSizeAndAlignment naturalSizeAlignment;
getNaturalSizeAndAlignment(targetProgram->getTargetReq(), innerType, &naturalSizeAlignment);
if (naturalSizeAlignment.alignment >= alignment)
return innerType;

// Reuse a previously created wrapper for the same (type, alignment) pair.
auto key = AlignedTypeWrapperKey{innerType, alignment};
if (auto* wrappedType = alignedWrapperTypes.tryGetValue(key))
return *wrappedType;

// Create the wrapper struct immediately after the inner type in the IR.
IRBuilder builder(innerType);
builder.setInsertAfter(innerType);

// Name hint is the inner type's name hint followed by "_aligned<alignment>".
auto structType = builder.createStructType();
StringBuilder nameSb;
getTypeNameHint(nameSb, innerType);
nameSb << "_aligned" << alignment;
builder.addNameHintDecoration(structType, nameSb.getUnownedSlice());

// Single field "val" holding the wrapped value; the alignment decoration is
// attached to the struct type itself.
auto fieldKey = builder.createStructKey();
builder.addNameHintDecoration(fieldKey, toSlice("val"));
builder.createStructField(structType, fieldKey, innerType);
builder.addAlignmentDecoration(structType, alignment);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we re-use IRAlignedAttr as the decoration?


alignedWrapperTypes[key] = structType;
return structType;
}

TargetProgram* targetProgram;

IRFunc* createLoadFunc(IRBuilder& builder, IRType* valueType, IRParam*& outParam)
Expand Down Expand Up @@ -300,16 +346,63 @@ struct ImmutableBufferLoadLoweringContext : InstPassBase
case kIROp_Load:
{
auto load = as<IRLoad>(inst);
if (isPointerToImmutableLocation(getRootAddr(load->getPtr())))
IRInst* loadedValue = load;

IRBuilder builder(load);
builder.setInsertBefore(load);
auto rootAddr = getRootAddr(load->getPtr());
auto alignmentAttr = load->findAttr<IRAlignedAttr>();
bool needUnwrap = false;
if (alignmentAttr)
{
auto wrappedType = getOrCreateAlignedWrapper(
load->getDataType(),
getIntVal(alignmentAttr->getAlignment()));
if (wrappedType != load->getDataType())
{
auto newPtr = builder.emitBitCast(
builder.getPtrType(
kIROp_PtrType,
wrappedType,
as<IRPtrTypeBase>(load->getPtr()->getDataType())),
load->getPtr());
builder.replaceOperand(load->getPtrOperand(), newPtr);
load->setFullType(wrappedType);
needUnwrap = true;
}
}
if (isPointerToImmutableLocation(rootAddr))
{
if (auto immutableLoad = emitImmutableLoad(builder, load->getPtr()))
{
loadedValue = immutableLoad;
}
}

if (needUnwrap)
{
IRBuilder builder(load);
builder.setInsertBefore(load);
if (auto newLoad = emitImmutableLoad(builder, load->getPtr()))
List<IRUse*> uses;
for (auto use = inst->firstUse; use; use = use->nextUse)
{
uses.add(use);
}
auto wrappedStruct = as<IRStructType>(loadedValue->getDataType());
SLANG_ASSERT(wrappedStruct);
auto firstField = wrappedStruct->getFields().getFirst();
SLANG_ASSERT(firstField);
auto fieldKey = firstField->getKey();
builder.setInsertAfter(loadedValue);
loadedValue = builder.emitFieldExtract(loadedValue, fieldKey);
for (auto use : uses)
{
load->replaceUsesWith(newLoad);
load->removeAndDeallocate();
builder.replaceOperand(use, loadedValue);
}
}
else if (loadedValue != inst)
{
inst->replaceUsesWith(loadedValue);
inst->removeAndDeallocate();
}
}
break;
case kIROp_StructuredBufferLoad:
Expand Down Expand Up @@ -365,7 +458,7 @@ struct ImmutableBufferLoadLoweringContext : InstPassBase
}
};

void lowerImmutableBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram)
void lowerImmutableOrAlignedBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram)
{
ImmutableBufferLoadLoweringContext context(module);
context.targetProgram = targetProgram;
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-cuda-immutable-load.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ for better performance.
struct IRModule;
class TargetProgram;

void lowerImmutableBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram);
void lowerImmutableOrAlignedBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram);

} // namespace Slang
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts-stable-names.lua
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,5 @@ return {
["Type.OptionalNoneType"] = 725,
["ReinterpretOptional"] = 726,
["Type.DefaultPushConstantLayout"] = 727,
["Decoration.AlignmentDecoration"] = 728,
}
17 changes: 16 additions & 1 deletion source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@ struct IRLoad : IRInst
{
FIDDLE(leafInst())
IRUse ptr;

IRUse* getPtrOperand() { return &ptr; }
IRInst* getPtr() { return ptr.get(); }
};

Expand Down Expand Up @@ -3344,6 +3344,16 @@ struct IRBuilder
oldPtrType->getAccessQualifier(),
oldPtrType->getAddressSpace());
}
// Builds a pointer type with opcode `ptrTypeOp` pointing at `valueType`, taking
// the access qualifier and address space from `oldPtrType`.
IRPtrTypeBase* getPtrType(IROp ptrTypeOp, IRType* valueType, IRPtrTypeBase* oldPtrType)
{
    auto inheritedAccess = oldPtrType->getAccessQualifier();
    auto inheritedAddressSpace = oldPtrType->getAddressSpace();
    return getPtrType(ptrTypeOp, valueType, inheritedAccess, inheritedAddressSpace);
}

/// Get a GLSL output parameter group type
IRGLSLOutputParameterGroupType* getGLSLOutputParameterGroupType(IRType* elementType);
Expand Down Expand Up @@ -4545,6 +4555,11 @@ struct IRBuilder
addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters));
}

// Attaches a `kIROp_AlignmentDecoration` to `value`, whose single operand is
// `alignment` materialized as an integer literal of the builder's int type.
void addAlignmentDecoration(IRInst* value, IntegerLiteralValue alignment)
{
    auto alignmentLiteral = getIntValue(getIntType(), alignment);
    addDecoration(value, kIROp_AlignmentDecoration, alignmentLiteral);
}

IRSemanticDecoration* addSemanticDecoration(
IRInst* value,
UnownedStringSlice const& text,
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,9 @@ local insts = {
{
AnyValueSize = { struct_name = "AnyValueSizeDecoration", operands = { { "sizeOperand", "IRIntLit" } } },
},
{
AlignmentDecoration = { operands= { { "alignmentOperand", "IRIntLit" } } },
},
{ SpecializeDecoration = {} },
{ SequentialIDDecoration = { operands = { { "sequentialIdOperand", "IRIntLit" } } } },
{ DynamicDispatchWitnessDecoration = {} },
Expand Down
10 changes: 10 additions & 0 deletions tests/spirv/aligned-load-store.slang
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//TEST:SIMPLE(filecheck=PTX): -target ptx -entry computeMain -stage compute
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// CHECK: OpLoad {{.*}} Aligned 8
Expand All @@ -6,11 +7,16 @@
// CHECK: OpLoad {{.*}} Aligned 16
// CHECK: OpStore {{.*}} Aligned 16

// PTX: ld.global.v2.f32
// PTX: ld.global.v4.f32
// PTX: ld.global.nc.v4.f32

uniform float4* data;

struct C { float4 v0; float4 v1; }
uniform C* data2;

uniform ImmutablePtr<C> data3;

[numthreads(1,1,1)]
void computeMain()
Expand All @@ -21,4 +27,8 @@ void computeMain()
var v1 = loadAligned<16>(data2);
v1.v0 += 1.0f;
storeAligned<16>(data2, v1);

var v2 = loadAligned<16>(data3);
v2.v0 += 1.0f;
storeAligned<16>(data2, v2);
}
Loading