Skip to content

Commit 38b5b83

Browse files
Fix issue with custom derivatives that mutate differentiable pointer types (#10464)
Fixes #10195. Enables users to write custom functions that mutate differentiable pointer types (like `getOffset`) and to use forward-mode custom derivatives to control how the derivative counterpart is mutated. Adds a diagnostic to make sure this cannot be directly mixed with differentiable value types (i.e. the same function cannot both mutate pointer parameters and have value parameters that require derivatives). In the future, it may make more sense to separate such custom derivatives into their own class, as they aren't really forward-mode derivatives, but for now this should help unblock some users.
1 parent 99e2a4f commit 38b5b83

File tree

4 files changed

+180
-1
lines changed

4 files changed

+180
-1
lines changed

source/slang/slang-diagnostics.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,6 +1799,13 @@ err(
17991799
span { loc = "expr:Expr", message = "generic constraint for parameter '~param:Type' references type parameter '~referenced:Decl' before it is declared" }
18001800
)
18011801

1802+
-- Error 30118: reported when a function has both IDifferentiable value types
-- and IDifferentiablePtrType outputs. The span message asks the user to split
-- the function so the two kinds of differentiable data live in separate functions.
err(
1803+
"cannot-mix-differentiable-value-and-ptr-outputs",
1804+
30118,
1805+
"cannot mix differentiable value types with differentiable pointer outputs",
1806+
span { loc = "location", message = "function has both IDifferentiable value types and IDifferentiablePtrType outputs, which is not currently supported. Please split the function so that differentiable value parameters and pointer differentiable outputs are in separate functions." }
1807+
)
1808+
18021809

18031810
-- Load semantic checking diagnostics (part 3) - Include, Visibility, and Capability
18041811
-- (inlined from slang-diagnostics-semantic-checking-3.lua)

source/slang/slang-ir-autodiff-fwd.cpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,10 +924,75 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
924924
diffReturnType = primalReturnType;
925925
}
926926

927+
// Determine if this call should use "pointer pair method" mode.
928+
// If any output (out/inout/ref parameter, or return value) is a differentiable pointer type
929+
// (conforms to IDifferentiablePtrType), we should not mark the call as mixed-differential
930+
// so that the transposition logic won't try to reverse it (doesn't make sense for
931+
// differential pointers).
932+
//
933+
// If we detect a differentiable pointer output but there are also any differentiable
934+
// value types (either inputs or outputs, i.e. conforming to IDifferentiable), we
935+
// emit a diagnostic since we cannot transpose such mixed functions.
936+
//
937+
bool hasDiffPtrOutput = false;
938+
bool hasDiffValueType = false;
939+
940+
// Check all parameters for differentiable value types (inputs and outputs)
941+
// and output parameters for differentiable pointer types.
942+
for (UIndex ii = 0; ii < calleeType->getParamCount(); ii++)
943+
{
944+
auto paramType = calleeType->getParamType(ii);
945+
bool isOutput = as<IROutParamTypeBase>(paramType) || as<IRRefParamType>(paramType);
946+
947+
// For outputs (out/inout/ref), get the value type from the pointer.
948+
// For inputs, use the param type directly.
949+
IRType* valueType;
950+
if (isOutput)
951+
valueType = (IRType*)as<IRPtrTypeBase>(paramType)->getValueType();
952+
else
953+
valueType = paramType;
954+
955+
// Only outputs can contribute differentiable pointer types.
956+
if (isOutput && differentiableTypeConformanceContext.isDifferentiablePtrType(valueType))
957+
hasDiffPtrOutput = true;
958+
959+
// Any parameter (input or output) can contribute differentiable value types.
960+
if (differentiableTypeConformanceContext.isDifferentiableValueType(valueType))
961+
hasDiffValueType = true;
962+
}
963+
964+
// Check return type.
965+
if (auto retType = calleeType->getResultType())
966+
{
967+
if (differentiableTypeConformanceContext.isDifferentiablePtrType((IRType*)retType))
968+
hasDiffPtrOutput = true;
969+
if (differentiableTypeConformanceContext.isDifferentiableValueType((IRType*)retType))
970+
hasDiffValueType = true;
971+
}
972+
973+
// If there are pointer differentiable outputs AND any differentiable value types
974+
// (inputs or outputs), emit a diagnostic.
975+
if (hasDiffPtrOutput && hasDiffValueType)
976+
{
977+
getSink()->diagnose(Diagnostics::CannotMixDifferentiableValueAndPtrOutputs{
978+
.location = origCall->sourceLoc});
979+
980+
// We'll continue transcribing instead of erroring out right away since we
981+
// can still generate code (it just won't be correct). The diagnostic
982+
// will still disallow any final code from being generated.
983+
}
984+
985+
bool usePointerPairMode = hasDiffPtrOutput;
986+
927987
auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args);
928988
placeholderCall->removeAndDeallocate();
929989

930-
argBuilder.markInstAsMixedDifferential(callInst, diffReturnType);
990+
// In pointer pair mode, do not mark as mixed differential since
991+
// the transposition logic should not try to reverse this call.
992+
if (!usePointerPairMode)
993+
{
994+
argBuilder.markInstAsMixedDifferential(callInst, diffReturnType);
995+
}
931996
argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee);
932997

933998
*builder = afterBuilder;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//DIAGNOSTIC_TEST:SIMPLE(diag=CHECK,non-exhaustive):-target spirv
2+
3+
// Test that mixing IDifferentiable value types with IDifferentiablePtrType outputs
4+
// in a differentiable function produces an error diagnostic.
5+
6+
// Minimal IDifferentiablePtrType implementation used by this test: its
// Differential type is MyPtr itself, and it carries a single uint offset.
struct MyPtr : IDifferentiablePtrType
7+
{
8+
typealias Differential = MyPtr;
9+
uint offset;
10+
}
11+
12+
// This function has both an IDifferentiablePtrType output (inout MyPtr)
13+
// and IDifferentiable value types (the float parameter x and the float return
// value), which is not supported. Its forward derivative is supplied by
// fwd_mixed through the [ForwardDerivative] attribute.
14+
[Differentiable]
15+
[ForwardDerivative(fwd_mixed)]
16+
float mixedFunc(inout MyPtr ptr, float x)
17+
{
18+
return x * ptr.offset;
19+
}
20+
21+
// Custom forward derivative registered for mixedFunc. Returns the primal
// x.p * ptr.p.offset paired with the differential x.d * ptr.p.offset; the
// pointer pair is only read, never written.
DifferentialPair<float> fwd_mixed(inout DifferentialPtrPair<MyPtr> ptr, DifferentialPair<float> x)
22+
{
23+
return diffPair(x.p * ptr.p.offset, x.d * ptr.p.offset);
24+
}
25+
26+
// Differentiable function containing the call to mixedFunc. The check
// directive placed directly under that call expects the "cannot mix" error
// to be reported there.
[Differentiable]
27+
float caller(float x)
28+
{
29+
var ptr = MyPtr(2.0);
30+
let r = mixedFunc(ptr, x);
31+
//CHECK: ^ cannot mix differentiable value types with differentiable pointer outputs
32+
return r;
33+
}
34+
35+
// Compute entry point: requests the forward derivative of caller, evaluated
// at diffPair(1.0, 1.0); the result is discarded.
[shader("compute")]
36+
[numthreads(1, 1, 1)]
37+
void computeMain()
38+
{
39+
fwd_diff(caller)(diffPair(1.0, 1.0));
40+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUFFER): -vk -shaderobj -output-using-type -emit-spirv-directly
2+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type
3+
4+
// Differentiable pointer type that addresses floats in a
// RWStructuredBuffer<float> handle, starting at baseIndex. Its Differential
// type is MyAddress itself.
struct MyAddress : IDifferentiablePtrType
5+
{
6+
typealias Differential = MyAddress;
7+
8+
RWStructuredBuffer<float>.Handle handle;
9+
uint baseIndex;
10+
11+
// Wraps the buffer handle h with a base index of 0.
__init(RWStructuredBuffer<float>.Handle h)
12+
{
13+
handle = h;
14+
baseIndex = 0;
15+
}
16+
17+
// Read-only element access; index is relative to baseIndex.
__subscript(uint index) -> float
18+
{
19+
[nonmutating]
20+
get { return handle[baseIndex + index]; }
21+
}
22+
23+
// Returns a new address over the same handle whose base index is advanced
// by `elements`. The derivative is provided by fwd_getOffset.
[Differentiable]
24+
[ForwardDerivative(fwd_getOffset)]
25+
This getOffset(int elements)
26+
{
27+
This addr = This(handle);
28+
addr.baseIndex = baseIndex + uint(elements);
29+
return addr;
30+
}
31+
32+
// Custom forward derivative of getOffset: applies the same offset to both
// the primal address (self.p) and the differential address (self.d).
static DifferentialPtrPair<This> fwd_getOffset(DifferentialPtrPair<This> self, int elements)
33+
{
34+
return DifferentialPtrPair<This>(
35+
self.p.getOffset(elements),
36+
self.d.getOffset(elements));
37+
}
38+
}
39+
40+
//TEST_INPUT: set buf = ubuffer(data=[1.0 2.0 3.0 4.0], stride=4)
41+
uniform RWStructuredBuffer<float>.Handle buf;
42+
43+
//TEST_INPUT: set gradBuf = ubuffer(data=[0.0 0.0 0.0 0.0], stride=4)
44+
uniform RWStructuredBuffer<float>.Handle gradBuf;
45+
46+
//TEST_INPUT: ubuffer(data=[0], stride=4):out,name=result
47+
RWStructuredBuffer<uint> result;
48+
49+
// Multiplies x by the element two slots past addr. The buffer read is wrapped
// in no_diff, so the loaded value is treated as non-differentiable.
[Differentiable]
50+
float compute(MyAddress addr, float x)
51+
{
52+
let addr2 = addr.getOffset(2);
53+
return x * no_diff addr2[0];
54+
}
55+
56+
// Compute entry point: runs bwd_diff(compute) with an output gradient of 1.0
// and writes 1 to result[0] when the gradient accumulated for x equals 3.0
// (the value at index 2 of the test input for buf), otherwise 0.
[shader("compute")]
57+
[numthreads(1, 1, 1)]
58+
void computeMain()
59+
{
60+
let addr = MyAddress(buf);
61+
let gradAddr = MyAddress(gradBuf);
62+
var addrPair = DifferentialPtrPair<MyAddress>(addr, gradAddr);
63+
var xPair = diffPair(1.0f);
64+
bwd_diff(compute)(addrPair, xPair, 1.0f);
65+
result[0] = xPair.d == 3.0 ? 1 : 0;
66+
// BUFFER: 1
67+
}

0 commit comments

Comments
 (0)