
Commit c027d6a

Authored by Vectorrent, with pyu10055 and mattsoulanille
Add support for GELU and approximate activation functions (#8224)
FEATURE

* add docker configs for isolated testing
* implement gelu and gelu_new as separate activations
* Update activations.ts
* Update activations_test.ts
* Update activations_test.ts
* remove docker files
* fix activation tests
* fix lint errors
* remove extra blank line
* fix gelu_new calc
* fix 1D test

Co-authored-by: Ping Yu <[email protected]>
Co-authored-by: Matthew Soulanille <[email protected]>
1 parent baf2364 commit c027d6a

File tree

4 files changed (+166, −12 lines)


tfjs-layers/src/activations.ts

Lines changed: 67 additions & 7 deletions
@@ -209,23 +209,64 @@ export class LogSoftmax extends Activation
 serialization.registerClass(LogSoftmax);

 /**
- * Swish activation function
+ * Gelu activation function
  */
-export class Swish extends Activation {
+export class Gelu extends Activation {
   /** @nocollapse */
-  static readonly className = 'swish';
+  static readonly className = 'gelu';
   /**
    * Calculate the activation function.
    *
    * @param x Tensor.
-   * @param alpha Scaling factor for the sigmoid function.
    * @returns a Tensor of the same shape as x
    */
-  apply(x: Tensor, alpha = 1): Tensor {
-    return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x));
+  apply(x: Tensor): Tensor {
+    return tidy(() => {
+      return tfc.tidy(() => {
+        const sqrtTwo = Math.sqrt(2);
+        // Compute Φ(x) using the erf function
+        const cdf = tfc.mul(0.5, tfc.add(1, tfc.erf(tfc.div(x, sqrtTwo))));
+        // Compute GELU(x) = x * Φ(x)
+        return tfc.mul(x, cdf);
+      });
+    });
   }
 }
-serialization.registerClass(Swish);
+serialization.registerClass(Gelu);
+
+/**
+ * GeluNew activation function
+ */
+export class GeluNew extends Activation {
+  /** @nocollapse */
+  static readonly className = 'gelu_new';
+  /**
+   * Calculate the activation function.
+   *
+   * @param x Tensor.
+   * @returns a Tensor of the same shape as x
+   */
+  apply(x: Tensor): Tensor {
+    return tidy(() => {
+      return tfc.mul(
+        0.5,
+        tfc.mul(
+          x,
+          tfc.add(
+            1,
+            tfc.tanh(
+              tfc.mul(
+                tfc.sqrt(tfc.div(2, Math.PI)),
+                tfc.add(x, tfc.mul(0.044715, tfc.pow(x, 3)))
+              )
+            )
+          )
+        )
+      );
+    });
+  }
+}
+serialization.registerClass(GeluNew);

 /**
  * Mish activation function

@@ -245,6 +286,25 @@ export class Mish extends Activation {
 }
 serialization.registerClass(Mish);

+/**
+ * Swish activation function
+ */
+export class Swish extends Activation {
+  /** @nocollapse */
+  static readonly className = 'swish';
+  /**
+   * Calculate the activation function.
+   *
+   * @param x Tensor.
+   * @param alpha Scaling factor for the sigmoid function.
+   * @returns a Tensor of the same shape as x
+   */
+  apply(x: Tensor, alpha = 1): Tensor {
+    return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x));
+  }
+}
+serialization.registerClass(Swish);
+
 export function serializeActivation(activation: Activation): string {
   return activation.getClassName();
 }
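The commit adds two variants: Gelu implements the exact form GELU(x) = x * Φ(x), with Φ computed from the error function, while GeluNew implements the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) used by GPT-2-style models. As a rough cross-check outside the layers code, here is a minimal sketch using only @tensorflow/tfjs-core; the helper names geluExact and geluTanh are illustrative and not part of the commit.

import * as tf from '@tensorflow/tfjs-core';

// Exact GELU: x * Φ(x), where Φ is the standard normal CDF (via erf).
function geluExact(x: tf.Tensor): tf.Tensor {
  return tf.tidy(() => {
    const cdf = tf.mul(0.5, tf.add(1, tf.erf(tf.div(x, Math.sqrt(2)))));
    return tf.mul(x, cdf);
  });
}

// Tanh approximation ('gelu_new'):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
function geluTanh(x: tf.Tensor): tf.Tensor {
  return tf.tidy(() => {
    const inner = tf.mul(
        Math.sqrt(2 / Math.PI), tf.add(x, tf.mul(0.044715, tf.pow(x, 3))));
    return tf.mul(0.5, tf.mul(x, tf.add(1, tf.tanh(inner))));
  });
}

const x = tf.tensor1d([0, 1, 3, 9]);
geluExact(x).print();  // ~[0, 0.8413, 2.9960, 9]
geluTanh(x).print();   // ~[0, 0.8412, 2.9964, 9]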

tfjs-layers/src/activations_test.ts

Lines changed: 96 additions & 1 deletion
@@ -13,7 +13,7 @@
  */
 import {scalar, tensor1d, tensor2d, tensor3d} from '@tensorflow/tfjs-core';

-import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish, Mish} from './activations';
+import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish, Mish, Gelu, GeluNew} from './activations';
 import {describeMathCPUAndGPU, expectNoLeakedTensors, expectTensorsClose} from './utils/test_utils';

 describeMathCPUAndGPU('linear activation', () => {

@@ -366,3 +366,98 @@ describeMathCPUAndGPU('mish activation', () => {
     expectNoLeakedTensors(() => mish(initX), 1);
   });
 });
+
+describeMathCPUAndGPU('gelu activation', () => {
+  const gelu = new Gelu().apply;
+  // Setup: Array with initial values.
+  // Execute: Gelu on the last dimension.
+  // Expect: Output array matches size and approximate expected values.
+  it('1D', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    const expectedVals = tensor1d([
+      0,
+      0.8413447141647339,
+      2.995950222015381, 9
+    ]);
+    expectTensorsClose(gelu(initX), expectedVals);
+  });
+  it('1D all equal', () => {
+    const initX = tensor1d([-1, -1, -1, -1]);
+    const expectedVals = tensor1d([
+      -0.15865525603294373,
+      -0.15865525603294373,
+      -0.15865525603294373,
+      -0.15865525603294373
+    ]);
+    expectTensorsClose(gelu(initX), expectedVals);
+  });
+  it('2D', () => {
+    const initX = tensor2d([[0, 1, 3, 9], [0, 1, 3, 9]]);
+    const expectedVals = tensor2d([
+      [0, 0.8413447141647339, 2.995950222015381, 9],
+      [0, 0.8413447141647339, 2.995950222015381, 9]
+    ]);
+    expectTensorsClose(gelu(initX), expectedVals);
+  });
+  it('3D', () => {
+    const initX = tensor3d([[[0, 1, 3, 9], [0, 1, 3, 9]]]);
+    const expectedVals = tensor3d([[
+      [ 0, 0.8413447141647339, 2.995950222015381, 9 ],
+      [ 0, 0.8413447141647339, 2.995950222015381, 9 ]
+    ]]);
+    expectTensorsClose(gelu(initX), expectedVals);
+  });
+  it('Does not leak', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    expectNoLeakedTensors(() => gelu(initX), 1);
+  });
+});
+
+describeMathCPUAndGPU('gelu_new activation', () => {
+  const geluNew = new GeluNew().apply;
+  // Setup: Array with initial values.
+  // Execute: GeluNew on the last dimension.
+  // Expect: Output array matches size and approximate expected values.
+  it('1D', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    const expectedVals = tensor1d([
+      0,
+      0.8411920070648193,
+      2.9963626861572266,
+      9
+    ]);
+    expectTensorsClose(geluNew(initX), expectedVals);
+  });
+  it('1D all equal', () => {
+    const initX = tensor1d([-1, -1, -1, -1]);
+    const expectedVals = tensor1d([
+      -0.15880802273750305,
+      -0.15880802273750305,
+      -0.15880802273750305,
+      -0.15880802273750305
+    ]);
+    expectTensorsClose(geluNew(initX), expectedVals);
+  });
+  it('2D', () => {
+    const initX = tensor2d([[0, 1, 3, 9], [0, 1, 3, 9]]);
+    const expectedVals = tensor2d([
+      [ 0, 0.8411920070648193, 2.9963626861572266, 9 ],
+      [ 0, 0.8411920070648193, 2.9963626861572266, 9 ]
+    ]);
+    expectTensorsClose(geluNew(initX), expectedVals);
+  });
+  it('3D', () => {
+    const initX = tensor3d([[[0, 1, 3, 9], [0, 1, 3, 9]]]);
+    const expectedVals = tensor3d([
+      [
+        [ 0, 0.8411920070648193, 2.9963626861572266, 9 ],
+        [ 0, 0.8411920070648193, 2.9963626861572266, 9 ]
+      ]
+    ]);
+    expectTensorsClose(geluNew(initX), expectedVals);
+  });
+  it('Does not leak', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    expectNoLeakedTensors(() => geluNew(initX), 1);
+  });
+});
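The fixtures are the two formulas evaluated at the inputs: the exact GELU at x = 1 is Φ(1) ≈ 0.8413447, while the tanh approximation gives ≈ 0.8411920; at x = -1 they give ≈ -0.1586553 and ≈ -0.1588080 respectively. A quick way to regenerate the 1D expectations (illustrative only; assumes the same relative import as the test file):

import {tensor1d} from '@tensorflow/tfjs-core';
import {Gelu, GeluNew} from './activations';

const x = tensor1d([0, 1, 3, 9]);
new Gelu().apply(x).print();     // ~[0, 0.8413447, 2.9959502, 9]
new GeluNew().apply(x).print();  // ~[0, 0.8411920, 2.9963627, 9]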

tfjs-layers/src/keras_format/activation_config.ts

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ import {stringLiteralArray} from './utils';
  */
 export const activationOptions = stringLiteralArray([
   'elu', 'hard_sigmoid', 'linear', 'relu', 'relu6', 'selu', 'sigmoid',
-  'softmax', 'softplus', 'softsign', 'tanh', 'swish', 'mish'
+  'softmax', 'softplus', 'softsign', 'tanh', 'swish', 'mish', 'gelu', 'gelu_new'
 ]);

 /**

@@ -28,4 +28,4 @@ export type ActivationSerialization = typeof activationOptions[number];
 // e.g. to src/common.ts. Maybe even duplicate *all* of these to be pedantic?
 /** @docinline */
 export type ActivationIdentifier = 'elu'|'hardSigmoid'|'linear'|'relu'|'relu6'|
-    'selu'|'sigmoid'|'softmax'|'softplus'|'softsign'|'tanh'|'swish'|'mish';
+    'selu'|'sigmoid'|'softmax'|'softplus'|'softsign'|'tanh'|'swish'|'mish'|'gelu'|'gelu_new';
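Because the new names are now in both activationOptions and the ActivationIdentifier union, and the classes are registered for serialization, they can be passed as plain strings in layer configs. A hedged usage sketch, not part of the commit, assuming a tfjs build that includes this change:

import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 8, inputShape: [4], activation: 'gelu'}));
model.add(tf.layers.dense({units: 1, activation: 'gelu_new'}));
(model.predict(tf.ones([2, 4])) as tf.Tensor).print();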

tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts

Lines changed: 1 addition & 2 deletions
@@ -170,8 +170,7 @@ export class GPT2Backbone extends Backbone {
       numHeads: args.numHeads,
       dropout: args.dropout,
       layerNormEpsilon: 1e-05,
-      // TODO(pforderique): Implement gelu.
-      activation: getActivation('relu'),
+      activation: getActivation('gelu'),
       kernelInitializer: gpt2KernelInitializer(0.02),
       normalizeFirst: true,
       name: `transformer_layer_${i}`,
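With 'gelu' registered, getActivation('gelu') resolves to the erf-based Gelu class, so the GPT-2 backbone's TODO can be dropped. A small illustrative check (not from the commit; the relative import path is assumed):

import {getActivation} from '../../../../activations';
import {tensor1d} from '@tensorflow/tfjs-core';

const act = getActivation('gelu');
act.apply(tensor1d([-1, 0, 1])).print();  // ~[-0.1587, 0, 0.8413]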
