Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ __intrinsic_op($(kIROp_AlignedAttr))
internal int __align_attr(int alignment);

__intrinsic_op($(kIROp_Load))
internal T __load_aligned<T>(T* ptr, int alignmentAttr);
internal T __load_aligned<T, Access access>(Ptr<T, access, AddressSpace.Device> ptr, int alignmentAttr);

__intrinsic_op($(kIROp_Store))
internal void __store_aligned<T>(T* ptr, T value, int alignmentAttr);
Expand All @@ -1443,7 +1443,7 @@ internal void __store_aligned<T>(T* ptr, T value, int alignmentAttr);
///
[__NoSideEffect]
[ForceInline]
T loadAligned<int alignment, T>(T* ptr)
T loadAligned<int alignment, T, Access access>(Ptr<T, access, AddressSpace.Device> ptr)
{
return __load_aligned(ptr, __align_attr(alignment));
}
Expand Down
10 changes: 10 additions & 0 deletions source/slang/slang-emit-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,16 @@ void CUDASourceEmitter::emitEntryPointAttributesImpl(
SLANG_UNUSED(entryPointDecor);
}

/// Writes an `__align__(N) ` attribute for `inst` when it carries an
/// `IRAlignmentDecoration`, where N is the integer value of the decoration's
/// alignment operand. Writes nothing when the decoration is absent.
void CUDASourceEmitter::emitPostKeywordTypeAttributesImpl(IRInst* inst)
{
    auto alignmentDecor = inst->findDecoration<IRAlignmentDecoration>();
    if (!alignmentDecor)
        return;

    // Emitted in three pieces: opening text, the alignment value, closing text
    // (with a trailing space separating the attribute from what follows).
    m_writer->emit("__align__(");
    m_writer->emit(getIntVal(alignmentDecor->getAlignmentOperand()));
    m_writer->emit(") ");
}

void CUDASourceEmitter::emitFunctionPreambleImpl(IRInst* inst)
{
if (!inst)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-emit-cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class CUDASourceEmitter : public CPPSourceEmitter
virtual void emitEntryPointAttributesImpl(
IRFunc* irFunc,
IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) SLANG_OVERRIDE;

virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
SLANG_OVERRIDE;
Expand Down
4 changes: 3 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1924,9 +1924,11 @@ Result linkAndOptimizeIR(

// If we are generating code for CUDA, we should translate all immutable buffer loads to
// use the `__ldg` intrinsic for improved performance.
// For aligned loads, we also create a wrapper struct type with an alignment decoration,
// and rewrite the load to load the wrapper struct instead.
if (isCUDATarget(targetRequest))
{
SLANG_PASS(lowerImmutableBufferLoadForCUDA, targetProgram);
SLANG_PASS(lowerImmutableOrAlignedBufferLoadForCUDA, targetProgram);
}

SLANG_PASS(performForceInlining);
Expand Down
110 changes: 103 additions & 7 deletions source/slang/slang-ir-cuda-immutable-load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,55 @@ struct LoadMethod
struct ImmutableBufferLoadLoweringContext : InstPassBase
{
Dictionary<IRType*, LoadMethod> loadFuncs;

/// Lookup key for the aligned-wrapper cache: the wrapped type paired with the
/// alignment requested for it.
struct AlignedTypeWrapperKey
{
    IRType* innerType;        // Type held inside the wrapper.
    IRIntegerValue alignment; // Alignment requested for the wrapper.

    /// Keys are equal only when both the wrapped type and the alignment agree.
    bool operator==(const AlignedTypeWrapperKey& rhs) const
    {
        if (innerType != rhs.innerType)
            return false;
        return alignment == rhs.alignment;
    }

    /// Hash built from the wrapped type and the alignment, in that order.
    HashCode getHashCode() const
    {
        const auto typeHash = Slang::getHashCode(innerType);
        const auto alignmentHash = Slang::getHashCode(alignment);
        return combineHash(typeHash, alignmentHash);
    }
};

Dictionary<AlignedTypeWrapperKey, IRType*> alignedWrapperTypes;

/// Returns a type that can be used to load a value of `innerType` with at
/// least `alignment` alignment.
///
/// If the natural size/alignment query for `innerType` succeeds and the
/// natural alignment is already >= `alignment`, `innerType` is returned
/// unchanged. Otherwise this returns a struct type with a single field named
/// `val` of type `innerType`, decorated with the requested alignment. Wrapper
/// structs are cached per (innerType, alignment) pair, so repeated requests
/// return the same type.
IRType* getOrCreateAlignedWrapper(IRType* innerType, IRIntegerValue alignment)
{
    IRSizeAndAlignment naturalLayout;
    const bool hasNaturalLayout = SLANG_SUCCEEDED(getNaturalSizeAndAlignment(
        targetProgram->getTargetReq(),
        innerType,
        &naturalLayout));

    // No wrapper is needed when the type is already sufficiently aligned.
    if (hasNaturalLayout && naturalLayout.alignment >= alignment)
        return innerType;

    const AlignedTypeWrapperKey cacheKey{innerType, alignment};
    if (auto* cachedWrapper = alignedWrapperTypes.tryGetValue(cacheKey))
        return *cachedWrapper;

    // Place the new struct right after the type it wraps.
    IRBuilder builder(innerType);
    builder.setInsertAfter(innerType);

    auto wrapperType = builder.createStructType();

    // Name hint: "<inner type name hint>_aligned<alignment>".
    StringBuilder wrapperName;
    getTypeNameHint(wrapperName, innerType);
    wrapperName << "_aligned" << alignment;
    builder.addNameHintDecoration(wrapperType, wrapperName.getUnownedSlice());

    // The wrapper's only field holds the original value.
    auto valueKey = builder.createStructKey();
    builder.addNameHintDecoration(valueKey, toSlice("val"));
    builder.createStructField(wrapperType, valueKey, innerType);
    builder.addAlignmentDecoration(wrapperType, alignment);

    alignedWrapperTypes[cacheKey] = wrapperType;
    return wrapperType;
}

TargetProgram* targetProgram;

IRFunc* createLoadFunc(IRBuilder& builder, IRType* valueType, IRParam*& outParam)
Expand Down Expand Up @@ -300,16 +349,63 @@ struct ImmutableBufferLoadLoweringContext : InstPassBase
case kIROp_Load:
{
auto load = as<IRLoad>(inst);
if (isPointerToImmutableLocation(getRootAddr(load->getPtr())))
IRInst* loadedValue = load;

IRBuilder builder(load);
builder.setInsertBefore(load);
auto rootAddr = getRootAddr(load->getPtr());
auto alignmentAttr = load->findAttr<IRAlignedAttr>();
bool needUnwrap = false;
if (alignmentAttr)
{
auto wrappedType = getOrCreateAlignedWrapper(
load->getDataType(),
getIntVal(alignmentAttr->getAlignment()));
if (wrappedType != load->getDataType())
{
auto newPtr = builder.emitBitCast(
builder.getPtrType(
kIROp_PtrType,
wrappedType,
as<IRPtrTypeBase>(load->getPtr()->getDataType())),
load->getPtr());
builder.replaceOperand(load->getPtrOperand(), newPtr);
load->setFullType(wrappedType);
needUnwrap = true;
}
}
if (isPointerToImmutableLocation(rootAddr))
{
if (auto immutableLoad = emitImmutableLoad(builder, load->getPtr()))
{
loadedValue = immutableLoad;
}
}

if (needUnwrap)
{
IRBuilder builder(load);
builder.setInsertBefore(load);
if (auto newLoad = emitImmutableLoad(builder, load->getPtr()))
List<IRUse*> uses;
for (auto use = inst->firstUse; use; use = use->nextUse)
{
uses.add(use);
}
auto wrappedStruct = as<IRStructType>(loadedValue->getDataType());
SLANG_ASSERT(wrappedStruct);
auto firstField = wrappedStruct->getFields().getFirst();
SLANG_ASSERT(firstField);
auto fieldKey = firstField->getKey();
builder.setInsertAfter(loadedValue);
loadedValue = builder.emitFieldExtract(loadedValue, fieldKey);
for (auto use : uses)
{
load->replaceUsesWith(newLoad);
load->removeAndDeallocate();
builder.replaceOperand(use, loadedValue);
}
}
else if (loadedValue != inst)
{
inst->replaceUsesWith(loadedValue);
inst->removeAndDeallocate();
}
}
break;
case kIROp_StructuredBufferLoad:
Expand Down Expand Up @@ -365,7 +461,7 @@ struct ImmutableBufferLoadLoweringContext : InstPassBase
}
};

void lowerImmutableBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram)
void lowerImmutableOrAlignedBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram)
{
ImmutableBufferLoadLoweringContext context(module);
context.targetProgram = targetProgram;
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-cuda-immutable-load.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ for better performance.
struct IRModule;
class TargetProgram;

void lowerImmutableBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram);
void lowerImmutableOrAlignedBufferLoadForCUDA(IRModule* module, TargetProgram* targetProgram);

} // namespace Slang
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts-stable-names.lua
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,5 @@ return {
["Type.OptionalNoneType"] = 725,
["ReinterpretOptional"] = 726,
["Type.DefaultPushConstantLayout"] = 727,
["Decoration.AlignmentDecoration"] = 728,
}
17 changes: 16 additions & 1 deletion source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@ struct IRLoad : IRInst
{
FIDDLE(leafInst())
IRUse ptr;

IRUse* getPtrOperand() { return &ptr; }
IRInst* getPtr() { return ptr.get(); }
};

Expand Down Expand Up @@ -3344,6 +3344,16 @@ struct IRBuilder
oldPtrType->getAccessQualifier(),
oldPtrType->getAddressSpace());
}
/// Builds a pointer type with opcode `ptrTypeOp` pointing at `valueType`,
/// taking the access qualifier and address space from `oldPtrType`.
IRPtrTypeBase* getPtrType(IROp ptrTypeOp, IRType* valueType, IRPtrTypeBase* oldPtrType)
{
    auto inheritedAccess = oldPtrType->getAccessQualifier();
    auto inheritedAddressSpace = oldPtrType->getAddressSpace();
    return getPtrType(ptrTypeOp, valueType, inheritedAccess, inheritedAddressSpace);
}

/// Get a GLSL output parameter group type
IRGLSLOutputParameterGroupType* getGLSLOutputParameterGroupType(IRType* elementType);
Expand Down Expand Up @@ -4545,6 +4555,11 @@ struct IRBuilder
addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters));
}

/// Attaches a `kIROp_AlignmentDecoration` to `value`, with `alignment` stored
/// as an integer literal of the builder's int type.
void addAlignmentDecoration(IRInst* value, IntegerLiteralValue alignment)
{
    auto alignmentLiteral = getIntValue(getIntType(), alignment);
    addDecoration(value, kIROp_AlignmentDecoration, alignmentLiteral);
}

IRSemanticDecoration* addSemanticDecoration(
IRInst* value,
UnownedStringSlice const& text,
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,9 @@ local insts = {
{
AnyValueSize = { struct_name = "AnyValueSizeDecoration", operands = { { "sizeOperand", "IRIntLit" } } },
},
{
AlignmentDecoration = { operands= { { "alignmentOperand", "IRIntLit" } } },
},
{ SpecializeDecoration = {} },
{ SequentialIDDecoration = { operands = { { "sequentialIdOperand", "IRIntLit" } } } },
{ DynamicDispatchWitnessDecoration = {} },
Expand Down
10 changes: 10 additions & 0 deletions tests/spirv/aligned-load-store.slang
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//TEST:SIMPLE(filecheck=PTX): -target ptx -entry computeMain -stage compute
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// CHECK: OpLoad {{.*}} Aligned 8
Expand All @@ -6,11 +7,16 @@
// CHECK: OpLoad {{.*}} Aligned 16
// CHECK: OpStore {{.*}} Aligned 16

// PTX: ld.global.v2.f32
// PTX: ld.global.v4.f32
// PTX: ld.global.nc.v4.f32

uniform float4* data;

struct C { float4 v0; float4 v1; }
uniform C* data2;

uniform ImmutablePtr<C> data3;

[numthreads(1,1,1)]
void computeMain()
Expand All @@ -21,4 +27,8 @@ void computeMain()
var v1 = loadAligned<16>(data2);
v1.v0 += 1.0f;
storeAligned<16>(data2, v1);

var v2 = loadAligned<16>(data3);
v2.v0 += 1.0f;
storeAligned<16>(data2, v2);
}
Loading