Skip to content

Commit 3116a10

Browse files
sklumpyu10055
andauthored
Infer dtype in wasm fill (#7953)
Co-authored-by: Ping Yu <[email protected]>
1 parent 7880437 commit 3116a10

File tree

1 file changed

+6
-2
lines changed
  • tfjs-backend-wasm/src/kernels

1 file changed

+6
-2
lines changed

tfjs-backend-wasm/src/kernels/Fill.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
* =============================================================================
1616
*/
1717

18-
import {KernelConfig, KernelFunc} from '@tensorflow/tfjs-core';
18+
import {KernelConfig, KernelFunc, util} from '@tensorflow/tfjs-core';
1919
import {Fill, FillAttrs} from '@tensorflow/tfjs-core';
2020

2121
import {BackendWasm} from '../backend_wasm';
2222

2323
export function fill(args: {attrs: FillAttrs, backend: BackendWasm}) {
24-
const {attrs: {shape, value, dtype}, backend} = args;
24+
const {attrs: {shape, value}, backend} = args;
25+
let {attrs: {dtype}} = args;
26+
27+
dtype = dtype || util.inferDtype(value);
28+
2529
const out = backend.makeOutput(shape, dtype);
2630
const outVals = backend.typedArrayFromHeap(out);
2731
outVals.fill(value as number);

0 commit comments

Comments
 (0)