Skip to content

Commit 8dc2705

Browse files
csyongheCopilot
andauthored
Don't rewrite entrypoint in to borrow during ir lowering. (#9869)
If the user compile a shader via `findAndCheckEntrypoint` API on a function that does not have a `[shader()]` attribute, we could be generating invalid spirv code. There is logic in lower-to-ir that silently changes the parameter passing convention of entrypoint in to borrow in, which is fundamentally conflicting with the compilation model. When we load a module, we compile every function down to IR. If we want to lower entrypoint differently, we must know some func is entrypoint during lower-to-ir of a module. But since this info is not provided at `loadModule` time, we fundamentally cannot do this kind of transform when lowering a module to IR. The right thing to do is to implement an IR pass that turns these in parameters into borrow in once we know a function is entrypoint. This PR deletes the special logic in lower-to-ir that does the problematic rewrite, and use a dedicated pass derived from the existing `transformParamsToConstRef` pass to do the rewrite after linking, where entrypoints are always specified. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent a9ba4fc commit 8dc2705

File tree

6 files changed

+149
-115
lines changed

6 files changed

+149
-115
lines changed

source/slang/slang-emit.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,8 @@ Result linkAndOptimizeIR(
746746
if (!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO)
747747
SLANG_PASS(lowerGLSLShaderStorageBufferObjectsToStructuredBuffers, sink);
748748

749+
SLANG_PASS(translateEntryPointInParamToBorrow, sink);
750+
749751
if (requiredLoweringPassSet.globalVaryingVar)
750752
SLANG_PASS(translateGlobalVaryingVar, codeGenContext);
751753

source/slang/slang-ir-transform-params-to-constref.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct TransformParamsToConstRefContext
2020
}
2121

2222
// Check if a type should be transformed (struct, array, or other composite types)
23-
bool shouldTransformParam(IRParam* param)
23+
virtual bool shouldTransformParam(IRParam* param)
2424
{
2525
auto type = param->getDataType();
2626
if (!type)
@@ -97,6 +97,7 @@ struct TransformParamsToConstRefContext
9797
auto fieldExtract = as<IRFieldExtract>(userInst);
9898
builder.setInsertBefore(fieldExtract);
9999
auto fieldAddr = builder.emitFieldAddress(
100+
builder.getPtrType(userInst->getDataType()),
100101
loadInst->getPtr(),
101102
fieldExtract->getField());
102103
auto loadFieldAddr = as<IRLoad>(builder.emitLoad(fieldAddr));
@@ -120,6 +121,7 @@ struct TransformParamsToConstRefContext
120121
auto getElement = as<IRGetElement>(userInst);
121122
builder.setInsertBefore(getElement);
122123
auto getElementPtr = builder.emitElementAddress(
124+
builder.getPtrType(userInst->getDataType()),
123125
loadInst->getPtr(),
124126
getElement->getIndex());
125127
auto loadElementPtr = as<IRLoad>(builder.emitLoad(getElementPtr));
@@ -253,7 +255,7 @@ struct TransformParamsToConstRefContext
253255
}
254256

255257
// Check if function should be excluded from transformation
256-
bool shouldProcessFunction(IRFunc* func)
258+
virtual bool shouldProcessFunction(IRFunc* func)
257259
{
258260
// Skip functions without definitions
259261
if (!func->isDefinition())
@@ -421,4 +423,58 @@ SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink)
421423
return context.processModule();
422424
}
423425

426+
struct EntryPointInParamToBorrowContext : public TransformParamsToConstRefContext
427+
{
428+
EntryPointInParamToBorrowContext(IRModule* module, DiagnosticSink* sink)
429+
: TransformParamsToConstRefContext(module, sink)
430+
{
431+
}
432+
virtual bool shouldProcessFunction(IRFunc* func) override
433+
{
434+
if (!func->isDefinition())
435+
return false;
436+
if (func->findDecoration<IREntryPointDecoration>() != nullptr)
437+
return true;
438+
return false;
439+
}
440+
virtual bool shouldTransformParam(IRParam* param) override
441+
{
442+
auto type = param->getDataType();
443+
if (as<IRPointerLikeType>(type))
444+
return false;
445+
if (as<IRPtrTypeBase>(type))
446+
return false;
447+
if (as<IRMeshOutputType>(type))
448+
return false;
449+
if (as<IRHLSLPatchType>(type))
450+
return false;
451+
452+
// Skip uniform parameters.
453+
// We expect all entry-point parameters to have layout information,
454+
// but we will be defensive and skip parameters without the required
455+
// information when we are in a release build.
456+
//
457+
auto layoutDecoration = param->findDecoration<IRLayoutDecoration>();
458+
SLANG_ASSERT(layoutDecoration);
459+
if (!layoutDecoration)
460+
return false;
461+
auto paramLayout = as<IRVarLayout>(layoutDecoration->getLayout());
462+
SLANG_ASSERT(paramLayout);
463+
if (!paramLayout)
464+
return false;
465+
if (!isVaryingParameter(paramLayout))
466+
return false;
467+
468+
// If we reach here, we are dealing with a varying in parameter.
469+
// We need to rewrite it to be a `borrow in` parameter.
470+
return true;
471+
}
472+
};
473+
474+
SlangResult translateEntryPointInParamToBorrow(IRModule* module, DiagnosticSink* sink)
475+
{
476+
EntryPointInParamToBorrowContext context(module, sink);
477+
return context.processModule();
478+
}
479+
424480
} // namespace Slang

source/slang/slang-ir-transform-params-to-constref.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ class DiagnosticSink;
99

1010
SlangResult transformParamsToConstRef(IRModule* module, DiagnosticSink* sink);
1111

12+
SlangResult translateEntryPointInParamToBorrow(IRModule* module, DiagnosticSink* sink);
13+
1214
} // namespace Slang

source/slang/slang-lower-to-ir.cpp

Lines changed: 0 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -3552,104 +3552,6 @@ void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resu
35523552
}
35533553
}
35543554

3555-
/// Does the given `declRef` appear to be a declaration of an entry point?
3556-
///
3557-
/// This function does a best-effort job of detecting whether something is
3558-
/// a shader entry point, but it cannot answer the question definitively,
3559-
/// because the compilation model of Slang allows functions to be tagged
3560-
/// as entry points separately from their declaration in the AST.
3561-
///
3562-
bool doesDeclAppearToBeAnEntryPoint(DeclRef<CallableDecl> const& declRef)
3563-
{
3564-
auto decl = declRef.getDecl();
3565-
if (decl->hasModifier<EntryPointAttribute>())
3566-
return true;
3567-
if (decl->hasModifier<NumThreadsAttribute>())
3568-
return true;
3569-
return false;
3570-
}
3571-
3572-
/// Does a parameter appear to be a declaration of an entry-point varying input.
3573-
///
3574-
/// The `paramInfo` represents the parameter being inspected, and
3575-
/// the `funcDeclRef` should represent the outer declaration that
3576-
/// might or might not be an entry point.
3577-
///
3578-
/// This function is only able to do a best-effort check for what is an
3579-
/// entry point (see `doesDeclAppearToBeAnEntryPoint()`), and cannot be
3580-
/// fully robust in detection.
3581-
///
3582-
bool doesParamAppearToBeAnEntryPointVaryingInput(
3583-
IRLoweringParameterInfo const& paramInfo,
3584-
DeclRef<CallableDecl> const& funcDeclRef)
3585-
{
3586-
// If the outer declaration doesn't appear to be an entry
3587-
// point, then it seems this isn't an entry point parameter
3588-
// at all, much less an input.
3589-
//
3590-
if (!doesDeclAppearToBeAnEntryPoint(funcDeclRef))
3591-
return false;
3592-
3593-
// We are only intereste in parameters that would otherwise
3594-
// be lowered to just use `in`.
3595-
//
3596-
if (paramInfo.actualParamPassingModeToUse != ParamPassingMode::In)
3597-
return false;
3598-
3599-
// We are only concerned with varying parameters, so `uniform`
3600-
// parameters aren't relevant.
3601-
//
3602-
if (paramInfo.decl->findModifier<HLSLUniformModifier>())
3603-
return false;
3604-
3605-
// Certain types conceptually represent uniform parameters even
3606-
// if they aren't declared with `uniform`, so we also want to
3607-
// ignore those.
3608-
//
3609-
// TODO: It would be best for logic like this to be factored into
3610-
// a shared subroutine somewhere, so that we don't have multiple
3611-
// places in the code doing ad hoc checks for a particular list
3612-
// of types, that might change over time.
3613-
//
3614-
if (as<HLSLPatchType>(paramInfo.type))
3615-
return false;
3616-
3617-
if (as<MeshOutputType>(paramInfo.type))
3618-
return false;
3619-
3620-
return true;
3621-
}
3622-
3623-
/// Modify the parameter passing mode for the given parameter, if it can
3624-
/// be detected that it seems to be used as a varying input to an entry point.
3625-
///
3626-
/// If a varying `in` parameter is detected, its actual parameter-passing
3627-
/// mode will be changed to `borrow in`.
3628-
///
3629-
/// This function is only able to do a best-effort check for what is an
3630-
/// entry point (see `doesDeclAppearToBeAnEntryPoint()`), and cannot be
3631-
/// fully robust in detection.
3632-
///
3633-
void maybeModifyParamPassingModeForDetectedEntryPointVaryingInput(
3634-
IRLoweringParameterInfo& ioParamInfo,
3635-
DeclRef<CallableDecl> const& funcDeclRef)
3636-
{
3637-
// If we cannot detect that this seems to be a varying `in`
3638-
// parameter of an entry point, then we'll just skip it.
3639-
//
3640-
// There's no way for this code to be sure it isn't skipping
3641-
// over a function it shouldn't, but there's nothign we can
3642-
// do about it from here.
3643-
//
3644-
if (!doesParamAppearToBeAnEntryPointVaryingInput(ioParamInfo, funcDeclRef))
3645-
return;
3646-
3647-
// We basically just want to change the parameter from `in`
3648-
// to `borrow in`, so that it is an immutable by-reference parameter.
3649-
//
3650-
ioParamInfo.actualParamPassingModeToUse = ParamPassingMode::BorrowIn;
3651-
}
3652-
36533555
//
36543556
// And here is our function that will do the recursive walk:
36553557
void collectParameterLists(
@@ -3894,20 +3796,6 @@ void collectParameterLists(
38943796
{
38953797
auto paramInfo = getParameterInfo(context, paramDeclRef);
38963798

3897-
// One unfortunate wrinkle that arises is that all the downstream
3898-
// logic in the Slang IR *really* wants the varying input parameters
3899-
// to a shader entry point to use `ParamPassingMode::BorrowIn`,
3900-
// so that they don't force copying, but syntactically such parameters
3901-
// are conventionally declared using `in` (meaning `ParamPassingMode::In`).
3902-
//
3903-
// We include logic here to switch up the parameter-passing mode
3904-
// in the case where we can statically detect that something is
3905-
// being compiled as an entry point. Note, however, that this logic
3906-
// is inherently fragile, since not every entry point that a user specifies
3907-
// is guaranteed to have had a `[shader(...)]` attribute on it.
3908-
//
3909-
maybeModifyParamPassingModeForDetectedEntryPointVaryingInput(paramInfo, callableDeclRef);
3910-
39113799
ioParameterLists->params.add(paramInfo);
39123800
}
39133801

tests/spirv/debug-return-types.slang

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ float4 main(VSOutput input) : SV_TARGET {
1818
}
1919

2020
// CHECK: %[[VECTOR:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeVector
21-
// CHECK: %[[VSOUTPUT:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypePointer
21+
// CHECK: %[[VSOUTPUT:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeComposite
2222
// CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[VECTOR]] %[[VSOUTPUT]]
2323
// CHECK: %[[MATRIX:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeMatrix
2424
// CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[MATRIX]] %[[MATRIX]]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// unit-test-entrypoint-compile.cpp
2+
3+
#include "../../source/core/slang-io.h"
4+
#include "../../source/core/slang-process.h"
5+
#include "slang-com-ptr.h"
6+
#include "slang.h"
7+
#include "unit-test/slang-unit-test.h"
8+
9+
#include <stdio.h>
10+
#include <stdlib.h>
11+
12+
using namespace Slang;
13+
14+
SLANG_UNIT_TEST(entryPointCompile)
15+
{
16+
ComPtr<slang::IGlobalSession> globalSession;
17+
SLANG_CHECK_ABORT(
18+
slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
19+
slang::TargetDesc targetDesc = {};
20+
// Request SPIR-V disassembly so we can check the content.
21+
targetDesc.format = SLANG_SPIRV_ASM;
22+
targetDesc.profile = globalSession->findProfile("spirv_1_5");
23+
slang::SessionDesc sessionDesc = {};
24+
sessionDesc.targetCount = 1;
25+
sessionDesc.targets = &targetDesc;
26+
27+
List<slang::CompilerOptionEntry> optionEntries;
28+
{
29+
slang::CompilerOptionEntry entry;
30+
entry.name = slang::CompilerOptionName::EnableEffectAnnotations;
31+
entry.value.kind = slang::CompilerOptionValueKind::Int;
32+
entry.value.intValue0 = 1;
33+
optionEntries.add(entry);
34+
}
35+
36+
sessionDesc.compilerOptionEntries = optionEntries.getBuffer();
37+
sessionDesc.compilerOptionEntryCount = (uint32_t)optionEntries.getCount();
38+
39+
ComPtr<slang::ISession> session;
40+
SLANG_CHECK_ABORT(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
41+
42+
ComPtr<slang::IBlob> diagnosticBlob;
43+
String userSourceBody = R"(
44+
struct VS_INPUT
45+
{
46+
float3 vPositionOs : POSITION;
47+
uint nInstanceIdx : TEXCOORD13;
48+
uint instanceId : SV_InstanceID;
49+
}
50+
51+
struct PS_INPUT
52+
{
53+
float4 vPositionPs : SV_Position ;
54+
}
55+
56+
StructuredBuffer<float3> gBuffer;
57+
58+
PS_INPUT vsMain(VS_INPUT i)
59+
{
60+
PS_INPUT o = {};
61+
float3 v = gBuffer[i.instanceId].xyz ;
62+
o.vPositionPs = float4(v, 1.0f);
63+
return o ;
64+
})";
65+
auto srcBlob = StringBlob::moveCreate(_Move(userSourceBody));
66+
auto module = session->loadModuleFromSource("m", "m.slang", srcBlob, diagnosticBlob.writeRef());
67+
SLANG_CHECK_ABORT(module != nullptr);
68+
69+
ComPtr<slang::IEntryPoint> entryPoint;
70+
module->findAndCheckEntryPoint(
71+
"vsMain",
72+
SLANG_STAGE_VERTEX,
73+
entryPoint.writeRef(),
74+
diagnosticBlob.writeRef());
75+
SLANG_CHECK_ABORT(entryPoint != nullptr);
76+
77+
ComPtr<slang::IComponentType> linkedProgram;
78+
entryPoint->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
79+
SLANG_CHECK_ABORT(linkedProgram != nullptr);
80+
81+
ComPtr<slang::IBlob> code;
82+
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
83+
SLANG_CHECK_ABORT(code != nullptr);
84+
85+
SLANG_CHECK_ABORT(code->getBufferSize() != 0);
86+
}

0 commit comments

Comments
 (0)