Skip to content

Commit cfc3001

Browse files
authored
[webgl]Add functions for parallel compilation (#5826)
1 parent b73ea50 commit cfc3001

File tree

7 files changed

+258
-31
lines changed

7 files changed

+258
-31
lines changed

tfjs-backend-webgl/src/backend_webgl.ts

Lines changed: 102 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import './flags_webgl';
2020

2121
import * as tf from '@tensorflow/tfjs-core';
22-
import {backend_util, BackendValues, buffer, DataId, DataStorage, DataToGPUWebGLOption, DataType, DataValues, engine, env, GPUData, kernel_impls, KernelBackend, MemoryInfo, NumericDataType, Rank, RecursiveArray, scalar, ShapeMap, Tensor, Tensor2D, TensorBuffer, TensorInfo, tidy, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core';
23-
22+
import {backend_util, BackendValues, buffer, DataId, DataStorage, DataToGPUWebGLOption, DataType, DataValues, engine, env, GPUData, kernel_impls, KernelBackend, MemoryInfo, nextFrame, NumericDataType, Rank, RecursiveArray, scalar, ShapeMap, Tensor, Tensor2D, TensorBuffer, TensorInfo, tidy, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core';
2423
import {getWebGLContext} from './canvas_util';
2524
import {DecodeMatrixProgram} from './decode_matrix_gpu';
2625
import {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu';
@@ -30,7 +29,7 @@ import {EncodeMatrixProgram} from './encode_matrix_gpu';
3029
import {EncodeMatrixPackedProgram} from './encode_matrix_packed_gpu';
3130
import {GPGPUContext} from './gpgpu_context';
3231
import * as gpgpu_math from './gpgpu_math';
33-
import {GPGPUBinary, GPGPUProgram, TensorData} from './gpgpu_math';
32+
import {getUniformLocations, GPGPUBinary, GPGPUProgram, TensorData} from './gpgpu_math';
3433
import {simpleAbsImplCPU} from './kernel_utils/shared';
3534
import {PackProgram} from './pack_gpu';
3635
import {ReshapePackedProgram} from './reshape_packed_gpu';
@@ -549,15 +548,16 @@ export class MathBackendWebGL extends KernelBackend {
549548
};
550549

551550
return (async () => {
552-
if (env()
553-
.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
551+
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') >
552+
0) {
554553
const kernelMs = await Promise.all(flattenedActiveTimerQueries);
555554

556555
res['kernelMs'] = util.sum(kernelMs);
557556
res['getExtraProfileInfo'] = () =>
558-
kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))
559-
.map(d => `${d.name}: ${d.ms}`)
560-
.join(', ');
557+
kernelMs
558+
.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))
559+
.map(d => `${d.name}: ${d.ms}`)
560+
.join(', ');
561561
} else {
562562
res['kernelMs'] = {
563563
error: 'WebGL query timers are not supported in this environment.'
@@ -949,8 +949,10 @@ export class MathBackendWebGL extends KernelBackend {
949949
query = this.startTimer();
950950
}
951951

952-
gpgpu_math.runProgram(
953-
this.gpgpu, binary, inputsData, outputData, customUniformValues);
952+
if (!env().get('ENGINE_COMPILE_ONLY')) {
953+
gpgpu_math.runProgram(
954+
this.gpgpu, binary, inputsData, outputData, customUniformValues);
955+
}
954956

955957
dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
956958

@@ -1130,16 +1132,21 @@ export class MathBackendWebGL extends KernelBackend {
11301132

11311133
// Have the original texture assume the identity of the encoded output.
11321134
const outputTexData = this.texData.get(encodedOutputTarget.dataId);
1133-
texData.texture = outputTexData.texture;
11341135
texData.texShape = outputTexData.texShape;
11351136
texData.isPacked = outputTexData.isPacked;
11361137
texData.usage = outputTexData.usage;
11371138

1139+
if (!env().get('ENGINE_COMPILE_ONLY')) {
1140+
texData.texture = outputTexData.texture;
1141+
// Once uploaded, don't store the values on cpu.
1142+
texData.values = null;
1143+
this.texData.delete(encodedOutputTarget.dataId);
1144+
} else {
1145+
this.disposeData(encodedOutputTarget.dataId);
1146+
}
1147+
11381148
this.disposeIntermediateTensorInfo(tempDenseInputHandle);
1139-
this.texData.delete(encodedOutputTarget.dataId);
11401149

1141-
// Once uploaded, don't store the values on cpu.
1142-
texData.values = null;
11431150
if (shouldTimeProgram) {
11441151
this.uploadWaitMs += util.now() - start;
11451152
}
@@ -1180,6 +1187,87 @@ export class MathBackendWebGL extends KernelBackend {
11801187
private computeBytes(shape: [number, number], dtype: DataType) {
11811188
return shape[0] * shape[1] * util.bytesPerElement(dtype);
11821189
}
1190+
1191+
/**
 * Synchronously verifies every program currently held in `binaryCache`.
 *
 * Each cached binary is handed to `checkCompletion_`, which throws if the
 * program failed to link; the first failure aborts the walk.
 */
checkCompileCompletion() {
  Object.values(this.binaryCache).forEach(cachedBinary => {
    this.checkCompletion_(cachedBinary);
  });
}
1196+
1197+
/**
 * Asynchronously verifies every program currently held in `binaryCache`.
 *
 * When `gpgpu.parallelCompilationExtension` is available, each binary is
 * polled through `checkCompletionAsync_`. Otherwise each binary is checked
 * immediately with the synchronous `checkCompletion_`.
 *
 * @returns A promise that resolves to one `true` per cached binary, or
 *     rejects with the first link/compile error raised by a check.
 */
async checkCompileCompletionAsync(): Promise<boolean[]> {
  const binaries: GPGPUBinary[] = Object.values(this.binaryCache);

  if (this.gpgpu.parallelCompilationExtension) {
    return Promise.all(
        binaries.map(binary => this.checkCompletionAsync_(binary)));
  }

  // The `async` callback runs `checkCompletion_` synchronously and turns a
  // thrown error into a rejected promise, so every binary is still checked
  // even when an earlier one fails.
  return Promise.all(
      binaries.map(async binary => this.checkCompletion_(binary)));
}
1219+
1220+
/**
 * Polls a program's `COMPLETION_STATUS_KHR` once per `nextFrame()` until it
 * reports completion, then validates the result with `checkCompletion_`.
 *
 * @param binary The compiled binary whose `webGLProgram` is polled.
 * @returns `true` once the program has completed and passed
 *     `checkCompletion_`; rejects if that check throws.
 */
private async checkCompletionAsync_(binary: GPGPUBinary): Promise<boolean> {
  while (!this.gpgpu.gl.getProgramParameter(
      binary.webGLProgram,
      this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
    // Not finished yet: yield until the next frame before asking again.
    await nextFrame();
  }
  return this.checkCompletion_(binary);
}
1230+
1231+
/**
 * Verifies that a program linked successfully.
 *
 * If `LINK_STATUS` is `false`, the program info log is printed. When the
 * fragment shader's `COMPILE_STATUS` is also `false`, the shader source and
 * its info log are printed and a compile error is thrown; otherwise a link
 * error is thrown.
 *
 * @param binary The compiled binary to inspect.
 * @returns `true` when the link status is anything other than `false`.
 * @throws Error if the fragment shader failed to compile or the program
 *     failed to link.
 */
private checkCompletion_(binary: GPGPUBinary): boolean {
  const gl = this.gpgpu.gl;

  const linkStatus = gl.getProgramParameter(binary.webGLProgram, gl.LINK_STATUS);
  if (linkStatus !== false) {
    return true;
  }

  console.log(gl.getProgramInfoLog(binary.webGLProgram));

  const compileStatus =
      gl.getShaderParameter(binary.fragmentShader, gl.COMPILE_STATUS);
  if (compileStatus === false) {
    webgl_util.logShaderSourceAndInfoLog(
        binary.source, gl.getShaderInfoLog(binary.fragmentShader));
    throw new Error('Failed to compile fragment shader.');
  }

  throw new Error('Failed to link vertex and fragment shaders.');
}
1246+
1247+
/**
 * Looks up the uniform locations of every program in `binaryCache` and
 * stores them on the corresponding binary.
 *
 * The lookup is delegated to `getUniformLocations` from `gpgpu_math`; all
 * nine location fields it returns overwrite the values on the binary.
 */
getUniformLocations() {
  Object.values(this.binaryCache).forEach(binary => {
    const locations =
        getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);

    binary.uniformLocations = locations.uniformLocations;
    binary.customUniformLocations = locations.customUniformLocations;
    binary.infLoc = locations.infLoc;
    binary.nanLoc = locations.nanLoc;
    binary.inShapesLocations = locations.inShapesLocations;
    binary.inTexShapesLocations = locations.inTexShapesLocations;
    binary.outShapeLocation = locations.outShapeLocation;
    binary.outShapeStridesLocation = locations.outShapeStridesLocation;
    binary.outTexShapeLocation = locations.outTexShapeLocation;
  });
}
11831271
}
11841272

11851273
function float32ToTypedArray<D extends NumericDataType>(

tfjs-backend-webgl/src/backend_webgl_test.ts

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,3 +1033,84 @@ describeWithFlags('custom canvas ', WEBGL_ENVS, () => {
10331033
tf.removeBackend(customBackendName);
10341034
});
10351035
});
1036+
describeWithFlags('Parallel compilation', WEBGL_ENVS, () => {
  // TODO(lina128): Also test async after parallel compilation flag is
  // implemented in context object. We have to keep the test sync for now,
  // because it's a global flag, the async test will affect other tests.
  it('does not have memory leak.', () => {
    // Both flags are global. Save them up front so they can be restored even
    // if a step below throws (e.g. checkCompileCompletion on a link failure).
    const savedWebGLCPUForward = tf.env().get('WEBGL_CPU_FORWARD');
    const savedEngineCompileOnly = tf.env().get('ENGINE_COMPILE_ONLY');
    tf.env().set('WEBGL_CPU_FORWARD', false);

    try {
      // Baseline: run the op without parallel compilation and record how many
      // binaries end up in the cache.
      const customWebGLBackendName = 'my-webgl';
      tf.copyRegisteredKernels('webgl', customWebGLBackendName);
      tf.registerBackend(customWebGLBackendName, () => {
        return new MathBackendWebGL();
      });
      tf.setBackend(customWebGLBackendName);

      const a0 = tf.tensor1d([1, 1, 1]);
      const b0 = tf.tensor1d([1, 1, 1]);
      const c0 = tf.add(a0, b0);
      const data = c0.dataSync();
      const numOfBinaryCacheNoParallelCompilation =
          Object.keys(getBinaryCache(tf.ENV.getNumber('WEBGL_VERSION'))).length;
      expectArraysClose(data, [2, 2, 2]);
      tf.dispose([a0, b0, c0]);
      tf.removeBackend(customWebGLBackendName);

      // TODO(lina128): Also test use an existing backend after parallel
      // compilation flag is implemented in context object. The current
      // approach assumes there's no binary cache, and it doesn't check
      // existing cache.
      const customWebGLBackendName1 = 'my-webgl1';
      tf.copyRegisteredKernels('webgl', customWebGLBackendName1);
      tf.registerBackend(customWebGLBackendName1, () => {
        return new MathBackendWebGL();
      });
      tf.setBackend(customWebGLBackendName1);
      const webGLBackend = tf.backend() as MathBackendWebGL;

      const startNumBytes = (tf.memory() as WebGLMemoryInfo).numBytesInGPU;
      const startTensor = tf.memory().numTensors;
      const startDataBuckets = webGLBackend.numDataIds();

      const a1 = tf.tensor1d([1, 1, 1]);
      const b1 = tf.tensor1d([1, 1, 1]);

      // Pre-compile round.
      tf.env().set('ENGINE_COMPILE_ONLY', true);
      const c1 = tf.add(a1, b1);
      webGLBackend.checkCompileCompletion();
      webGLBackend.getUniformLocations();

      // Warm-up upload and download round.
      tf.env().set('ENGINE_COMPILE_ONLY', false);
      const c2 = tf.add(a1, b1);
      c2.dataSync();

      // Actual inference.
      const c3 = tf.add(a1, b1);
      expectArraysEqual(c3.dataSync(), [2, 2, 2]);

      tf.dispose([a1, b1, c1, c2, c3]);
      const endNumBytes = (tf.memory() as WebGLMemoryInfo).numBytesInGPU;
      const endTensor = tf.memory().numTensors;
      const endDataBuckets = webGLBackend.numDataIds();

      // We only check numBytesInGPU. For parallel compilation,
      // numBytesInGPUAllocated will be more because of the two pass
      // uploadToGPU, but they will all be freed, resulting in endNumBytes
      // equal to startNumBytes.
      expect(startNumBytes).toEqual(endNumBytes);
      expect(startTensor).toEqual(endTensor);
      expect(endDataBuckets).toEqual(startDataBuckets);

      const numOfBinaryCacheWithParallelCompilation =
          Object.keys(getBinaryCache(tf.ENV.getNumber('WEBGL_VERSION'))).length;
      expect(numOfBinaryCacheWithParallelCompilation)
          .toEqual(numOfBinaryCacheNoParallelCompilation);

      tf.removeBackend(customWebGLBackendName1);
    } finally {
      // Never leave the engine in compile-only mode for subsequent tests.
      tf.env().set('ENGINE_COMPILE_ONLY', savedEngineCompileOnly);
      tf.env().set('WEBGL_CPU_FORWARD', savedWebGLCPUForward);
    }
  });
});

tfjs-backend-webgl/src/gpgpu_context.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import {getWebGLContext, setWebGLContext} from './canvas_util';
2121
import * as gpgpu_util from './gpgpu_util';
2222
import * as tex_util from './tex_util';
2323
import {Texture, TextureConfig} from './tex_util';
24-
import {WebGL1DisjointQueryTimerExtension, WebGL2DisjointQueryTimerExtension} from './webgl_types';
24+
import {WebGL1DisjointQueryTimerExtension, WebGL2DisjointQueryTimerExtension, WebGLParallelCompilationExtension} from './webgl_types';
2525
import * as webgl_util from './webgl_util';
2626

2727
export interface FenceContext {
@@ -37,6 +37,7 @@ export class GPGPUContext {
3737
colorBufferHalfFloatExtension: {};
3838
disjointQueryTimerExtension: WebGL2DisjointQueryTimerExtension|
3939
WebGL1DisjointQueryTimerExtension;
40+
parallelCompilationExtension: WebGLParallelCompilationExtension;
4041
vertexBuffer: WebGLBuffer;
4142
indexBuffer: WebGLBuffer;
4243
framebuffer: WebGLFramebuffer;
@@ -58,6 +59,8 @@ export class GPGPUContext {
5859
// WebGL 2.0 enables texture floats without an extension.
5960
let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
6061
const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
62+
this.parallelCompilationExtension =
63+
this.gl.getExtension('KHR_parallel_shader_compile');
6164
if (env().getNumber('WEBGL_VERSION') === 1) {
6265
const TEXTURE_FLOAT = 'OES_texture_float';
6366
const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';

tfjs-backend-webgl/src/gpgpu_math.ts

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ export interface GPGPUBinary {
6464
outTexShapeLocation?: WebGLUniformLocation;
6565
}
6666

67+
/**
 * The set of uniform locations resolved for one compiled WebGL program, as
 * returned by `getUniformLocations`.
 */
export interface GPGPUBinaryLocations {
  /** Locations of the program's input uniforms, keyed by variable name. */
  uniformLocations: Record<string, WebGLUniformLocation>;
  /** Locations of the program's custom uniforms, in `customUniforms` order. */
  customUniformLocations?: WebGLUniformLocation[];
  /** Location of the `INFINITY` uniform; only looked up for WebGL 1. */
  infLoc: WebGLUniformLocation;
  /** Location of the `NAN` uniform. */
  nanLoc: WebGLUniformLocation;
  /** Per-input shape uniform locations (presumably keyed by name). */
  inShapesLocations?: Record<string, WebGLUniformLocation>;
  /** Per-input texture-shape uniform locations (presumably keyed by name). */
  inTexShapesLocations?: Record<string, WebGLUniformLocation>;
  /** Location of `outShape`; set when the program enables shape uniforms. */
  outShapeLocation?: WebGLUniformLocation;
  /** Location of the output-shape strides uniform, when present. */
  outShapeStridesLocation?: WebGLUniformLocation;
  /** Location of `outTexShape`; set when the program enables shape uniforms. */
  outTexShapeLocation?: WebGLUniformLocation;
}
78+
6779
export interface TensorData {
6880
shape: number[];
6981
texData: TextureData;
@@ -101,18 +113,58 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
101113
const fragmentShader = createFragmentShader(gpgpu.gl, source);
102114
const webGLProgram = gpgpu.createProgram(fragmentShader);
103115

104-
// Add special uniforms (NAN, INFINITY)
116+
if (!env().get('ENGINE_COMPILE_ONLY')) {
117+
return {
118+
program,
119+
fragmentShader,
120+
source,
121+
webGLProgram,
122+
inShapeInfos,
123+
outShapeInfo,
124+
...getUniformLocations(gpgpu, program, webGLProgram)
125+
};
126+
} else {
127+
return {
128+
program,
129+
fragmentShader,
130+
source,
131+
webGLProgram,
132+
inShapeInfos,
133+
outShapeInfo,
134+
uniformLocations: null,
135+
customUniformLocations: null,
136+
infLoc: null,
137+
nanLoc: null,
138+
inShapesLocations: null,
139+
inTexShapesLocations: null,
140+
outShapeLocation: null,
141+
outShapeStridesLocation: null,
142+
outTexShapeLocation: null
143+
};
144+
}
145+
}
146+
147+
export function getUniformLocations(
148+
gpgpu: GPGPUContext, program: GPGPUProgram,
149+
webGLProgram: WebGLProgram): GPGPUBinaryLocations {
150+
const uniformLocations: {[name: string]: WebGLUniformLocation} = {};
151+
const inShapesLocations: {[name: string]: WebGLUniformLocation} = {};
152+
const inTexShapesLocations: {[name: string]: WebGLUniformLocation} = {};
153+
const customUniformLocations: WebGLUniformLocation[] = [];
154+
let outShapeLocation: WebGLUniformLocation;
155+
let outTexShapeLocation: WebGLUniformLocation;
156+
let outShapeStridesLocation: WebGLUniformLocation;
105157
let infLoc: WebGLUniformLocation = null;
106-
const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
158+
let nanLoc: WebGLUniformLocation = null;
159+
160+
// Add special uniforms (NAN, INFINITY)
161+
nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
107162
if (env().getNumber('WEBGL_VERSION') === 1) {
108163
infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
109164
}
110165

111166
// Add user-defined uniforms
112167
const shouldThrow = false;
113-
const uniformLocations: {[name: string]: WebGLUniformLocation} = {};
114-
const inShapesLocations: {[name: string]: WebGLUniformLocation} = {};
115-
const inTexShapesLocations: {[name: string]: WebGLUniformLocation} = {};
116168
for (let i = 0; i < program.variableNames.length; i++) {
117169
const varName = program.variableNames[i];
118170
uniformLocations[varName] =
@@ -127,9 +179,6 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
127179
}
128180
}
129181

130-
let outShapeLocation: WebGLUniformLocation;
131-
let outTexShapeLocation: WebGLUniformLocation;
132-
let outShapeStridesLocation: WebGLUniformLocation;
133182
if (program.enableShapeUniforms) {
134183
outShapeLocation =
135184
gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
@@ -139,7 +188,6 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
139188
gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
140189
}
141190

142-
const customUniformLocations: WebGLUniformLocation[] = [];
143191
if (program.customUniforms) {
144192
program.customUniforms.forEach((d, i) => {
145193
customUniformLocations[i] =
@@ -148,14 +196,8 @@ export function compileProgram<T extends Tensor, K extends Tensor>(
148196
}
149197

150198
return {
151-
program,
152-
fragmentShader,
153-
source,
154-
webGLProgram,
155199
uniformLocations,
156200
customUniformLocations,
157-
inShapeInfos,
158-
outShapeInfo,
159201
infLoc,
160202
nanLoc,
161203
inShapesLocations,

0 commit comments

Comments
 (0)