Skip to content

Commit 73b2fd1

Browse files
authored
[WebGPU] Fix shader key for ScatterProgram (#7932)
* Add uniforms * Update webgpu_program.ts * Update webgpu_program.ts * Revert "Add uniforms" This reverts commit 645802f. * Revert "Update webgpu_program.ts" This reverts commit 58ea96f. * Revert "Update webgpu_program.ts" This reverts commit 32386ac. * Add key to scatter webgpu program
1 parent a7bec12 commit 73b2fd1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tfjs-backend-webgpu/src/scatter_webgpu.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ export class ScatterProgram implements WebGPUProgram {
4848
this.dispatch =
4949
computeDispatch(this.dispatchLayout, flattenXShape, this.workgroupSize);
5050
this.sliceDimGreaterThanOne = sliceDim > 1;
51-
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
52-
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
51+
this.shaderKey =
52+
`scatter_${indicesRank}_${updatesRank}_${this.sliceDimGreaterThanOne}_${
53+
outputDtype}_${sumDupeIndices}_${strides.length}`;
5354
const stridesType = getCoordsDataType(strides.length);
5455
this.uniforms =
5556
`sliceDim : i32, strides: ${stridesType}, updatesSize: i32,`;

0 commit comments

Comments
 (0)