Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit c254d95

Browse files
authored
Migrate binary op to use logical sampling and add broadcasting support (#40)
* migrate binary op to logical and add broadcasting * fix lint and build error * optimize the getAAtOutputCoord sampler and getFlat sampler
1 parent edb0eba commit c254d95

17 files changed

+502
-688
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
"editor.tabSize": 2,
1313
"editor.insertSpaces": true,
1414
"files.insertFinalNewline": true,
15-
"editor.detectIndentation": false
15+
"editor.detectIndentation": false,
16+
"typescript.tsdk": "node_modules/typescript/lib"
1617
}

demos/mnist/mnist.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ export function buildModelMathAPI(
6868

6969
return (x: Array1D): Scalar => {
7070
return math.scope(() => {
71-
const hidden1 =
72-
math.relu(math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B));
73-
const hidden2 = math.relu(
74-
math.add(math.vectorTimesMatrix(hidden1, hidden2W), hidden2B));
71+
const hidden1 = math.relu(
72+
math.add(math.vectorTimesMatrix(x, hidden1W), hidden1B)) as Array1D;
73+
const hidden2 = math.relu(math.add(
74+
math.vectorTimesMatrix(hidden1, hidden2W), hidden2B)) as Array1D;
7575
const logits =
7676
math.add(math.vectorTimesMatrix(hidden2, softmaxW), softmaxB);
7777
return math.argMax(logits);

src/math/activation_functions.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ export class SigmoidFunc implements ActivationFunction {
5959
});
6060
}
6161

62-
der<T extends NDArray>(math: NDArrayMath, x: T, y: T) {
62+
der<T extends NDArray>(math: NDArrayMath, x: T, y: T): T {
6363
return math.scope(() => {
6464
// y * (1 - y) = y - y^2
6565
const ySquared = math.elementWiseMul(y, y);
66-
return math.sub(y, ySquared);
66+
return math.subStrict(y, ySquared);
6767
});
6868
}
6969
}

src/math/cost_functions.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export class SquareCostFunc implements ElementWiseCostFunction {
2929
private halfOne = Scalar.new(0.5);
3030

3131
cost<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
32-
const diff = math.sub(x1, x2);
32+
const diff = math.subStrict(x1, x2);
3333
const diffSquared = math.elementWiseMul(diff, diff);
3434
const result = math.scalarTimesArray(this.halfOne, diffSquared);
3535

@@ -40,7 +40,7 @@ export class SquareCostFunc implements ElementWiseCostFunction {
4040
}
4141

4242
der<T extends NDArray>(math: NDArrayMath, x1: T, x2: T): T {
43-
return math.sub(x1, x2);
43+
return math.subStrict(x1, x2);
4444
}
4545

4646
dispose() {

src/math/math.ts

Lines changed: 89 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -558,10 +558,8 @@ export abstract class NDArrayMath {
558558
c.size === 1,
559559
`Error in scalarPlusArray: first argument must be rank 0, but got ` +
560560
`rank ${c.rank}.`);
561-
return this.track(this.scalarPlusArrayInternal(c, a));
561+
return this.add(c, a) as T;
562562
}
563-
protected abstract scalarPlusArrayInternal<T extends NDArray>(
564-
c: Scalar, a: T): T;
565563

566564
/**
567565
* Computes a scalar minus NDArray, c - A.
@@ -573,25 +571,21 @@ export abstract class NDArrayMath {
573571
c.size === 1,
574572
`Error in scalarMinusArray: first argument must be rank 0, but got ` +
575573
`rank ${c.rank}.`);
576-
return this.track(this.scalarMinusArrayInternal(c, a));
574+
return this.sub(c, a) as T;
577575
}
578-
protected abstract scalarMinusArrayInternal<T extends NDArray>(
579-
c: Scalar, a: T): T;
580576

581577
/**
582-
* Computes a scalar minus NDArray, A - c.
578+
* Computes A - c. A is NDArray, c is Scalar.
583579
* @param a The NDArray A in A - c.
584-
* @param c The scalar c in A - c.
580+
* @param c The Scalar c in A - c.
585581
*/
586582
arrayMinusScalar<T extends NDArray>(a: T, c: Scalar): T {
587583
util.assert(
588584
c.size === 1,
589585
`Error in arrayMinusScalar: second argument must be rank 0, but ` +
590586
`got rank ${c.rank}.`);
591-
return this.track(this.arrayMinusScalarInternal(a, c));
587+
return this.sub(a, c) as T;
592588
}
593-
protected abstract arrayMinusScalarInternal<T extends NDArray>(
594-
a: T, c: Scalar): T;
595589

596590
/**
597591
* Computes -1 * A element-wise.
@@ -603,50 +597,111 @@ export abstract class NDArrayMath {
603597
protected abstract negInternal<T extends NDArray>(a: T): T;
604598

605599
/**
606-
* Adds two NDArrays element-wise, A + B. Inputs must be the same shape.
600+
* Adds two NDArrays element-wise, A + B. Supports broadcasting.
601+
* For a stricter version without broadcasting use math.addStrict().
602+
*
607603
* @param a The first NDArray to add element-wise.
608604
* @param b The second NDArray to add element-wise.
609605
*/
610-
add<T extends NDArray>(a: T, b: T): T {
611-
util.assertShapesMatch(a.shape, b.shape, 'Error in add: ');
606+
add(a: NDArray, b: NDArray): NDArray {
607+
util.assertAndGetBroadcastedShape(a.shape, b.shape);
612608
return this.track(this.addInternal(a, b));
613609
}
614-
protected abstract addInternal<T extends NDArray>(a: T, b: T): T;
610+
protected abstract addInternal(a: NDArray, b: NDArray): NDArray;
611+
612+
/**
613+
* Adds two NDArrays element-wise, A + B. Inputs must
614+
* be the same shape. For broadcasting support, use math.add() instead.
615+
*
616+
* @param a The first NDArray to multiply element-wise.
617+
* @param b The second NDArray to multiply element-wise.
618+
*/
619+
addStrict<T extends NDArray>(a: T, b: T): T {
620+
util.assertShapesMatch(a.shape, b.shape, 'Error in addStrict: ');
621+
return this.add(a, b) as T;
622+
}
615623

616624
/**
617-
* Subtracts two NDArrays element-wise, A - B. Inputs must be the same shape.
625+
* Subtracts two NDArrays element-wise, A - B. Supports broadcasting.
626+
* For a stricter version without broadcasting use math.subStrict().
627+
*
618628
* @param a The first NDArray to subtract element-wise.
619629
* @param b The second NDArray to subtract element-wise.
620630
*/
621-
sub<T extends NDArray>(a: T, b: T): T {
622-
util.assertShapesMatch(a.shape, b.shape, 'Error in sub: ');
631+
sub(a: NDArray, b: NDArray): NDArray {
632+
util.assertAndGetBroadcastedShape(a.shape, b.shape);
623633
return this.track(this.subInternal(a, b));
624634
}
625-
protected abstract subInternal<T extends NDArray>(a: T, b: T): T;
635+
protected abstract subInternal(a: NDArray, b: NDArray): NDArray;
626636

627637
/**
628-
* Multiplies two NDArrays element-wise (hadamard product), A * B. Inputs must
629-
* be the same shape.
638+
* Subtracts two NDArrays element-wise, A - B. Inputs must
639+
* be the same shape. For broadcasting support, use math.sub() instead.
640+
*
630641
* @param a The first NDArray to multiply element-wise.
631642
* @param b The second NDArray to multiply element-wise.
632643
*/
644+
subStrict<T extends NDArray>(a: T, b: T): T {
645+
util.assertShapesMatch(a.shape, b.shape, 'Error in subStrict: ');
646+
return this.sub(a, b) as T;
647+
}
648+
649+
/**
650+
* Multiplies two NDArrays element-wise, A * B. Supports broadcasting.
651+
* For a stricter version without broadcasting use math.multiplyStrict().
652+
*
653+
* @param a The first NDArray to multiply element-wise.
654+
* @param b The second NDArray to multiply element-wise.
655+
*/
656+
multiply(a: NDArray, b: NDArray): NDArray {
657+
util.assertAndGetBroadcastedShape(a.shape, b.shape);
658+
return this.track(this.multiplyInternal(a, b));
659+
}
660+
protected abstract multiplyInternal<T extends NDArray>(a: T, b: T): T;
661+
662+
/**
663+
* @deprecated Use math.multiplyStrict() instead.
664+
*/
633665
elementWiseMul<T extends NDArray>(a: T, b: T): T {
634-
util.assertShapesMatch(a.shape, b.shape, 'Error in elementWiseMul: ');
635-
return this.track(this.elementWiseMulInternal(a, b));
666+
return this.multiplyStrict(a, b);
667+
}
668+
669+
/**
670+
* Multiplies two NDArrays element-wise, A * B. Inputs must
671+
* be the same shape. For broadcasting support, use math.multiply() instead.
672+
*
673+
* @param a The first NDArray to multiply element-wise.
674+
* @param b The second NDArray to multiply element-wise.
675+
*/
676+
multiplyStrict<T extends NDArray>(a: T, b: T): T {
677+
util.assertShapesMatch(a.shape, b.shape, 'Error in multiplyStrict: ');
678+
return this.multiply(a, b) as T;
636679
}
637-
protected abstract elementWiseMulInternal<T extends NDArray>(a: T, b: T): T;
638680

639681
/**
640-
* Divides two NDArrays element-wise (hadamard product), A / B. Inputs must be
641-
* the same shape.
682+
* Divides two NDArrays element-wise, A / B. Supports broadcasting.
683+
* For a stricter version without broadcasting use math.divideStrict().
684+
*
642685
* @param a The first NDArray to divide element-wise.
643686
* @param b The second NDArray to divide element-wise.
644687
*/
645-
divide<T extends NDArray>(a: T, b: T): T {
646-
util.assertShapesMatch(a.shape, b.shape, 'Error in divide: ');
688+
divide(a: NDArray, b: NDArray): NDArray {
689+
util.assertAndGetBroadcastedShape(a.shape, b.shape);
647690
return this.track(this.divideInternal(a, b));
648691
}
649-
protected abstract divideInternal<T extends NDArray>(a: T, b: T): T;
692+
protected abstract divideInternal(a: NDArray, b: NDArray): NDArray;
693+
694+
/**
695+
* Divides two NDArrays element-wise, A / B. Inputs must
696+
* be the same shape. For broadcasting support, use math.divide() instead.
697+
*
698+
* @param a The first NDArray to multiply element-wise.
699+
* @param b The second NDArray to multiply element-wise.
700+
*/
701+
divideStrict<T extends NDArray>(a: T, b: T): T {
702+
util.assertShapesMatch(a.shape, b.shape, 'Error in divideStrict: ');
703+
return this.divide(a, b) as T;
704+
}
650705

651706
/**
652707
* Computes a scalar divided by an NDArray, broadcasted over the NDArray, c /
@@ -659,10 +714,8 @@ export abstract class NDArrayMath {
659714
c.size === 1,
660715
`Error in scalarDividedByArray: first argument must be rank 0, but ` +
661716
`got NDArray of rank ${c.rank}.`);
662-
return this.track(this.scalarDividedByArrayInternal(c, a));
717+
return this.divide(c, a) as T;
663718
}
664-
protected abstract scalarDividedByArrayInternal<T extends NDArray>(
665-
c: Scalar, a: T): T;
666719

667720
/**
668721
* Computes an NDArray divided by a scalar, broadcasted over the NDArray, A /
@@ -675,10 +728,8 @@ export abstract class NDArrayMath {
675728
c.size === 1,
676729
`Error in arrayDividedByScalar: second argument must be rank 0, ` +
677730
`but got NDArray of rank ${c.rank}.`);
678-
return this.track(this.arrayDividedByScalarInternal(a, c));
731+
return this.divide(a, c) as T;
679732
}
680-
protected abstract arrayDividedByScalarInternal<T extends NDArray>(
681-
a: T, c: Scalar): T;
682733

683734
/**
684735
* Computes exponential of the input NDArray element-wise. y = e ^ x
@@ -778,17 +829,11 @@ export abstract class NDArrayMath {
778829
c.size === 1,
779830
`Error in arrayDividedByScalar: first argument must be rank 0, but ` +
780831
`got rank ${c.rank}.`);
781-
return this.track(this.scalarTimesArrayInternal(c, a));
832+
return this.multiply(c, a) as T;
782833
}
783-
protected abstract scalarTimesArrayInternal<T extends NDArray>(
784-
c: Scalar, a: T): T;
785834

786835
/**
787-
* Computes an element-wise broadcasted multiplication of two matrices A and
788-
* B. Will return a new matrix that is the max of A and B, where the smaller
789-
* matrix will broadcast over the larger matrix.
790-
* @param c The scalar in the operation.
791-
* @param A the NDArray in the operation that will be broadcasted over.
836+
* @deprecated Use math.multiply() instead.
792837
*/
793838
elementWiseMulBroadcast(a: Array2D, b: Array2D): Array2D {
794839
util.assert(
@@ -799,10 +844,8 @@ export abstract class NDArrayMath {
799844
b.rank === 2,
800845
`Error in elementWiseMulBroadcast: second argument must be ` +
801846
`rank 2, but got rank ${b.rank}.`);
802-
return this.track(this.elementWiseMulBroadcastInternal(a, b));
847+
return this.multiply(a, b) as Array2D;
803848
}
804-
protected abstract elementWiseMulBroadcastInternal(a: Array2D, b: Array2D):
805-
Array2D;
806849

807850
/////////////////////
808851
// Convolution ops //

0 commit comments

Comments
 (0)