-
Notifications
You must be signed in to change notification settings - Fork 429
Add custom derivatives for getOffset #10499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -77,6 +77,8 @@ public struct BindlessAddress<T> : IPointerLikeAddress<T> | |||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| [Differentiable] | ||||||||
| [ForwardDerivative(fwd_getOffset)] | ||||||||
| public This getOffset(int elements) | ||||||||
| { | ||||||||
| uint newBaseIndex = baseIndex + elements; | ||||||||
|
|
@@ -86,6 +88,13 @@ public struct BindlessAddress<T> : IPointerLikeAddress<T> | |||||||
| return address; | ||||||||
| } | ||||||||
|
|
||||||||
| static DifferentialPtrPair<This> fwd_getOffset(DifferentialPtrPair<This> self, int elements) | ||||||||
| { | ||||||||
| return DifferentialPtrPair<This>( | ||||||||
| self.p.getOffset(elements), | ||||||||
| self.d.getOffset(elements)); | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor style inconsistency — nit, not blocking. |
||||||||
| } | ||||||||
|
|
||||||||
| [ForceInline] | ||||||||
| internal uint4 readUint4<DstType, bool IsAligned, uint ActualBoundary>(int offsetIndex) | ||||||||
| where DstType : __BuiltinFloatingPointType | ||||||||
|
|
@@ -129,9 +138,18 @@ public struct PointerAddress<T> : IPointerLikeAddress<T> | |||||||
| set { ptr[index] = newValue; } | ||||||||
| } | ||||||||
|
|
||||||||
| [Differentiable] | ||||||||
| [ForwardDerivative(fwd_getOffset)] | ||||||||
| public This getOffset(int elements) | ||||||||
| { | ||||||||
| return This(ptr + elements); | ||||||||
| return no_diff(This(ptr + elements)); | ||||||||
| } | ||||||||
|
|
||||||||
| static DifferentialPtrPair<This> fwd_getOffset(DifferentialPtrPair<This> self, int elements) | ||||||||
| { | ||||||||
| return DifferentialPtrPair<This>( | ||||||||
| self.p.getOffset(elements), | ||||||||
| self.d.getOffset(elements)); | ||||||||
| } | ||||||||
|
|
||||||||
| [ForceInline] | ||||||||
|
|
@@ -229,13 +247,22 @@ public struct TorchTensorViewAddress<T> : IPointerLikeAddress<T> | |||||||
| } | ||||||||
|
|
||||||||
| [require(cuda)] | ||||||||
| [Differentiable] | ||||||||
| [ForwardDerivative(fwd_getOffset)] | ||||||||
| public This getOffset(int elements) | ||||||||
| { | ||||||||
| This result; | ||||||||
| result.inner = inner.getOffset(elements); | ||||||||
| return result; | ||||||||
| } | ||||||||
|
|
||||||||
| static DifferentialPtrPair<This> fwd_getOffset(DifferentialPtrPair<This> self, int elements) | ||||||||
| { | ||||||||
| return DifferentialPtrPair<This>( | ||||||||
| self.p.getOffset(elements), | ||||||||
| self.d.getOffset(elements)); | ||||||||
| } | ||||||||
|
Comment on lines
+259
to
+264
|
||||||||
|
|
||||||||
| [ForceInline] | ||||||||
| [require(cuda_glsl_hlsl_metal_spirv, sm_6_6)] | ||||||||
| public void atomicAdd(uint index, T value) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Missing `[require(cuda)]` on this derivative. The primal `getOffset` is CUDA-only, so the derivative should carry the same capability requirement.
Suggested change
Without it, a non-CUDA target could theoretically resolve this derivative function even though the primal is CUDA-only, producing a confusing error instead of a clean capability mismatch. |
||||||||
|
|
@@ -300,9 +327,18 @@ internal extension<T> Ptr<T, Access.ReadWrite, AddressSpace.Device> : IPointerLi | |||||||
| set { this[index] = newValue; } | ||||||||
| } | ||||||||
|
|
||||||||
| [Differentiable] | ||||||||
| [ForwardDerivative(fwd_getOffset)] | ||||||||
| internal This getOffset(int elements) | ||||||||
| { | ||||||||
| return This(this + elements); | ||||||||
| return no_diff(This(this + elements)); | ||||||||
| } | ||||||||
|
|
||||||||
| static DifferentialPtrPair<This> fwd_getOffset(DifferentialPtrPair<This> self, int elements) | ||||||||
| { | ||||||||
| return DifferentialPtrPair<This>( | ||||||||
| self.p + elements, | ||||||||
| self.d + elements); | ||||||||
| } | ||||||||
|
|
||||||||
| [require(hlsl, sm_6_6)] | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| // Unit test for ILayer with multi-layer FFN (backward pass via autodiff). | ||
| // Exercises getOffset inside the differentiable function so the custom | ||
| // derivative of getOffset is used to propagate DifferentialPtrPair. | ||
| // | ||
| //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -xslang -experimental-feature -output-using-type -emit-spirv-directly | ||
| //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-mtl -compute -shaderobj -output-using-type -xslang -experimental-feature | ||
| //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type -capability cuda_sm_7_0 -xslang -experimental-feature | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good test — the math checks out and it exercises the key scenario (getOffset inside a differentiable function with a multi-layer FFN). One observation: this test only exercises one of the address types. Also, this test appears to be missing the DX12 backend line that the corresponding forward test includes.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion: Consider adding a variant covering the other address types — this backward test only exercises one of them. Not blocking — the existing backward test suite provides reasonable coverage. |
||
|
|
||
| import slang.neural; | ||
|
|
||
| // Same 2-layer FFN as the forward test: | ||
| // Layer1 (2->2): W1 = [[2,-1],[0.5,3]], b1 = [1,-2] (6 params at offset 0) | ||
| // Layer2 (2->1): W2 = [[-2,4]], b2 = [0.5] (3 params at offset 6) | ||
| // Total: 9 parameters | ||
|
|
||
| //TEST_INPUT: set parametersFloat = ubuffer(data=[2.0 -1.0 0.5 3.0 1.0 -2.0 -2.0 4.0 0.5], stride=4) | ||
| RWStructuredBuffer<float> parametersFloat; | ||
|
|
||
| //TEST_INPUT: set params = ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4) | ||
| uniform RWStructuredBuffer<float>.Handle params; | ||
|
|
||
| //TEST_INPUT: set gradParams = ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4) | ||
| uniform RWStructuredBuffer<float>.Handle gradParams; | ||
|
|
||
| //TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=resultBuffer | ||
| RWStructuredBuffer<uint> resultBuffer; | ||
|
|
||
| typealias Address = BindlessAddress<float>; | ||
| typealias V2 = InlineVector<float, 2>; | ||
| typealias V1 = InlineVector<float, 1>; | ||
| typealias Act = IdentityActivation<float>; | ||
| typealias Layer1 = FFLayer<float, V2, V2, LinearLayout, Act, true>; | ||
| typealias Layer2 = FFLayer<float, V2, V1, LinearLayout, Act, true>; | ||
|
|
||
| bool approxEqual(float a, float b, float eps = 0.001) | ||
| { | ||
| return abs(a - b) < eps; | ||
| } | ||
|
|
||
| // All address arithmetic happens inside the differentiable function, | ||
| // so getOffset's custom derivative propagates the DifferentialPtrPair. | ||
| [Differentiable] | ||
| V1 computeFFN(Address baseAddr, V2 input, Layer1 layer1, Layer2 layer2) | ||
| { | ||
| let layer1Addr = baseAddr.getOffset(0); | ||
| let layer2Addr = baseAddr.getOffset(6); | ||
|
|
||
| let h = layer1.eval<Address>(input, layer1Addr); | ||
| return layer2.eval<Address>(h, layer2Addr); | ||
| } | ||
|
|
||
| [shader("compute")] | ||
| [numthreads(1, 1, 1)] | ||
| void computeMain() | ||
| { | ||
| for (int i = 0; i < 9; i++) | ||
| { | ||
| params[i] = parametersFloat[i]; | ||
| gradParams[i] = 0.0; | ||
| } | ||
|
|
||
| let baseAddr = Address(params); | ||
| let gradBaseAddr = Address(gradParams); | ||
|
|
||
| float[2] xArr = { 1.5, -2.0 }; | ||
| let x = V2(xArr); | ||
| let layer1 = Layer1(); | ||
| let layer2 = Layer2(); | ||
|
|
||
| var baseAddrPair = DifferentialPtrPair<Address>(baseAddr, gradBaseAddr); | ||
| var inputPair = diffPair(x); | ||
| let dOutput = V1(1.0); | ||
|
|
||
| bwd_diff(computeFFN)(baseAddrPair, inputPair, layer1, layer2, dOutput); | ||
|
|
||
| // Expected gradients (dL/dy = [1]): | ||
| // Forward: h = W1*x + b1 = [6, -7.25], y = W2*h + b2 = -40.5 | ||
| // | ||
| // dL/dW1 = outer(W2^T * dL/dy, x) = outer([-2, 4], [1.5, -2]) | ||
| // = [-3, 4, 6, -8] | ||
| // dL/db1 = W2^T * dL/dy = [-2, 4] | ||
| // dL/dW2 = outer(dL/dy, h) = [6, -7.25] | ||
| // dL/db2 = dL/dy = [1] | ||
| uint idx = 0; | ||
| resultBuffer[idx++] = approxEqual(gradParams[0], -3.0); // dW1[0,0] | ||
| resultBuffer[idx++] = approxEqual(gradParams[1], 4.0); // dW1[0,1] | ||
| resultBuffer[idx++] = approxEqual(gradParams[2], 6.0); // dW1[1,0] | ||
| resultBuffer[idx++] = approxEqual(gradParams[3], -8.0); // dW1[1,1] | ||
| resultBuffer[idx++] = approxEqual(gradParams[4], -2.0); // db1[0] | ||
| resultBuffer[idx++] = approxEqual(gradParams[5], 4.0); // db1[1] | ||
| resultBuffer[idx++] = approxEqual(gradParams[6], 6.0); // dW2[0,0] | ||
| resultBuffer[idx++] = approxEqual(gradParams[7], -7.25); // dW2[0,1] | ||
| resultBuffer[idx++] = approxEqual(gradParams[8], 1.0); // db2 | ||
|
|
||
| // BUFFER: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| // BUFFER-NEXT: 1 | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🌐 Web query:
In current Slang autodiff documentation, can a method marked `[Differentiable]` legally return a type that implements `IDifferentiablePtrType`, or are such types restricted to input parameters only? 💡 Result:
In Slang's current autodiff docs, a function/method marked
`[Differentiable]` may not return a type that implements `IDifferentiablePtrType`. Types implementing `IDifferentiablePtrType` are restricted to input (`in`) parameters only (i.e., they cannot be used as `out`, `inout`, or return types), and attempting to do so is a compile-time error. [1] `getOffset()` methods cannot be marked `[Differentiable]` because they return an `IDifferentiablePtrType`. In Slang's autodiff system, types implementing
`IDifferentiablePtrType` are restricted to input parameters only and cannot be returned from `[Differentiable]` methods — this is a compile-time error per the documented rules. The current annotations on these methods violate that restriction. Also applies to: 141–143, 250–252, 330–332