@@ -558,10 +558,8 @@ export abstract class NDArrayMath {
558558 c . size === 1 ,
559559 `Error in scalarPlusArray: first argument must be rank 0, but got ` +
560560 `rank ${ c . rank } .` ) ;
561- return this . track ( this . scalarPlusArrayInternal ( c , a ) ) ;
561+ return this . add ( c , a ) as T ;
562562 }
563- protected abstract scalarPlusArrayInternal < T extends NDArray > (
564- c : Scalar , a : T ) : T ;
565563
566564 /**
567565 * Computes a scalar minus NDArray, c - A.
@@ -573,25 +571,21 @@ export abstract class NDArrayMath {
573571 c . size === 1 ,
574572 `Error in scalarMinusArray: first argument must be rank 0, but got ` +
575573 `rank ${ c . rank } .` ) ;
576- return this . track ( this . scalarMinusArrayInternal ( c , a ) ) ;
574+ return this . sub ( c , a ) as T ;
577575 }
578- protected abstract scalarMinusArrayInternal < T extends NDArray > (
579- c : Scalar , a : T ) : T ;
580576
581577 /**
582- * Computes a scalar minus NDArray, A - c .
578+ * Computes A - c. A is NDArray, c is Scalar .
583579 * @param a The NDArray A in A - c.
584- * @param c The scalar c in A - c.
580+ * @param c The Scalar c in A - c.
585581 */
586582 arrayMinusScalar < T extends NDArray > ( a : T , c : Scalar ) : T {
587583 util . assert (
588584 c . size === 1 ,
589585 `Error in arrayMinusScalar: second argument must be rank 0, but ` +
590586 `got rank ${ c . rank } .` ) ;
591- return this . track ( this . arrayMinusScalarInternal ( a , c ) ) ;
587+ return this . sub ( a , c ) as T ;
592588 }
593- protected abstract arrayMinusScalarInternal < T extends NDArray > (
594- a : T , c : Scalar ) : T ;
595589
596590 /**
597591 * Computes -1 * A element-wise.
@@ -603,50 +597,111 @@ export abstract class NDArrayMath {
603597 protected abstract negInternal < T extends NDArray > ( a : T ) : T ;
604598
605599 /**
606- * Adds two NDArrays element-wise, A + B. Inputs must be the same shape.
600+ * Adds two NDArrays element-wise, A + B. Supports broadcasting.
601+ * For a stricter version without broadcasting use math.addStrict().
602+ *
607603 * @param a The first NDArray to add element-wise.
608604 * @param b The second NDArray to add element-wise.
609605 */
610- add < T extends NDArray > ( a : T , b : T ) : T {
611- util . assertShapesMatch ( a . shape , b . shape , 'Error in add: ' ) ;
606+ add ( a : NDArray , b : NDArray ) : NDArray {
607+ util . assertAndGetBroadcastedShape ( a . shape , b . shape ) ;
612608 return this . track ( this . addInternal ( a , b ) ) ;
613609 }
614- protected abstract addInternal < T extends NDArray > ( a : T , b : T ) : T ;
610+ protected abstract addInternal ( a : NDArray , b : NDArray ) : NDArray ;
611+
612+ /**
613+ * Adds two NDArrays element-wise, A + B. Inputs must
614+ * be the same shape. For broadcasting support, use math.add() instead.
615+ *
616+ * @param a The first NDArray to multiply element-wise.
617+ * @param b The second NDArray to multiply element-wise.
618+ */
619+ addStrict < T extends NDArray > ( a : T , b : T ) : T {
620+ util . assertShapesMatch ( a . shape , b . shape , 'Error in addStrict: ' ) ;
621+ return this . add ( a , b ) as T ;
622+ }
615623
616624 /**
617- * Subtracts two NDArrays element-wise, A - B. Inputs must be the same shape.
625+ * Subtracts two NDArrays element-wise, A - B. Supports broadcasting.
626+ * For a stricter version without broadcasting use math.subStrict().
627+ *
618628 * @param a The first NDArray to subtract element-wise.
619629 * @param b The second NDArray to subtract element-wise.
620630 */
621- sub < T extends NDArray > ( a : T , b : T ) : T {
622- util . assertShapesMatch ( a . shape , b . shape , 'Error in sub: ' ) ;
631+ sub ( a : NDArray , b : NDArray ) : NDArray {
632+ util . assertAndGetBroadcastedShape ( a . shape , b . shape ) ;
623633 return this . track ( this . subInternal ( a , b ) ) ;
624634 }
625- protected abstract subInternal < T extends NDArray > ( a : T , b : T ) : T ;
635+ protected abstract subInternal ( a : NDArray , b : NDArray ) : NDArray ;
626636
627637 /**
628- * Multiplies two NDArrays element-wise (hadamard product), A * B. Inputs must
629- * be the same shape.
638+ * Subtracts two NDArrays element-wise, A - B. Inputs must
639+ * be the same shape. For broadcasting support, use math.sub() instead.
640+ *
630641 * @param a The first NDArray to multiply element-wise.
631642 * @param b The second NDArray to multiply element-wise.
632643 */
644+ subStrict < T extends NDArray > ( a : T , b : T ) : T {
645+ util . assertShapesMatch ( a . shape , b . shape , 'Error in subStrict: ' ) ;
646+ return this . sub ( a , b ) as T ;
647+ }
648+
649+ /**
650+ * Multiplies two NDArrays element-wise, A * B. Supports broadcasting.
651+ * For a stricter version without broadcasting use math.multiplyStrict().
652+ *
653+ * @param a The first NDArray to multiply element-wise.
654+ * @param b The second NDArray to multiply element-wise.
655+ */
656+ multiply ( a : NDArray , b : NDArray ) : NDArray {
657+ util . assertAndGetBroadcastedShape ( a . shape , b . shape ) ;
658+ return this . track ( this . multiplyInternal ( a , b ) ) ;
659+ }
660+ protected abstract multiplyInternal < T extends NDArray > ( a : T , b : T ) : T ;
661+
662+ /**
663+ * @deprecated Use math.multiplyStrict() instead.
664+ */
633665 elementWiseMul < T extends NDArray > ( a : T , b : T ) : T {
634- util . assertShapesMatch ( a . shape , b . shape , 'Error in elementWiseMul: ' ) ;
635- return this . track ( this . elementWiseMulInternal ( a , b ) ) ;
666+ return this . multiplyStrict ( a , b ) ;
667+ }
668+
669+ /**
670+ * Multiplies two NDArrays element-wise, A * B. Inputs must
671+ * be the same shape. For broadcasting support, use math.multiply() instead.
672+ *
673+ * @param a The first NDArray to multiply element-wise.
674+ * @param b The second NDArray to multiply element-wise.
675+ */
676+ multiplyStrict < T extends NDArray > ( a : T , b : T ) : T {
677+ util . assertShapesMatch ( a . shape , b . shape , 'Error in multiplyStrict: ' ) ;
678+ return this . multiply ( a , b ) as T ;
636679 }
637- protected abstract elementWiseMulInternal < T extends NDArray > ( a : T , b : T ) : T ;
638680
639681 /**
640- * Divides two NDArrays element-wise (hadamard product), A / B. Inputs must be
641- * the same shape.
682+ * Divides two NDArrays element-wise, A / B. Supports broadcasting.
683+ * For a stricter version without broadcasting use math.divideStrict().
684+ *
642685 * @param a The first NDArray to divide element-wise.
643686 * @param b The second NDArray to divide element-wise.
644687 */
645- divide < T extends NDArray > ( a : T , b : T ) : T {
646- util . assertShapesMatch ( a . shape , b . shape , 'Error in divide: ' ) ;
688+ divide ( a : NDArray , b : NDArray ) : NDArray {
689+ util . assertAndGetBroadcastedShape ( a . shape , b . shape ) ;
647690 return this . track ( this . divideInternal ( a , b ) ) ;
648691 }
649- protected abstract divideInternal < T extends NDArray > ( a : T , b : T ) : T ;
692+ protected abstract divideInternal ( a : NDArray , b : NDArray ) : NDArray ;
693+
694+ /**
695+ * Divides two NDArrays element-wise, A / B. Inputs must
696+ * be the same shape. For broadcasting support, use math.divide() instead.
697+ *
698+ * @param a The first NDArray to multiply element-wise.
699+ * @param b The second NDArray to multiply element-wise.
700+ */
701+ divideStrict < T extends NDArray > ( a : T , b : T ) : T {
702+ util . assertShapesMatch ( a . shape , b . shape , 'Error in divideStrict: ' ) ;
703+ return this . divide ( a , b ) as T ;
704+ }
650705
651706 /**
652707 * Computes a scalar divided by an NDArray, broadcasted over the NDArray, c /
@@ -659,10 +714,8 @@ export abstract class NDArrayMath {
659714 c . size === 1 ,
660715 `Error in scalarDividedByArray: first argument must be rank 0, but ` +
661716 `got NDArray of rank ${ c . rank } .` ) ;
662- return this . track ( this . scalarDividedByArrayInternal ( c , a ) ) ;
717+ return this . divide ( c , a ) as T ;
663718 }
664- protected abstract scalarDividedByArrayInternal < T extends NDArray > (
665- c : Scalar , a : T ) : T ;
666719
667720 /**
668721 * Computes an NDArray divided by a scalar, broadcasted over the NDArray, A /
@@ -675,10 +728,8 @@ export abstract class NDArrayMath {
675728 c . size === 1 ,
676729 `Error in arrayDividedByScalar: second argument must be rank 0, ` +
677730 `but got NDArray of rank ${ c . rank } .` ) ;
678- return this . track ( this . arrayDividedByScalarInternal ( a , c ) ) ;
731+ return this . divide ( a , c ) as T ;
679732 }
680- protected abstract arrayDividedByScalarInternal < T extends NDArray > (
681- a : T , c : Scalar ) : T ;
682733
683734 /**
684735 * Computes exponential of the input NDArray element-wise. y = e ^ x
@@ -778,17 +829,11 @@ export abstract class NDArrayMath {
778829 c . size === 1 ,
779830 `Error in arrayDividedByScalar: first argument must be rank 0, but ` +
780831 `got rank ${ c . rank } .` ) ;
781- return this . track ( this . scalarTimesArrayInternal ( c , a ) ) ;
832+ return this . multiply ( c , a ) as T ;
782833 }
783- protected abstract scalarTimesArrayInternal < T extends NDArray > (
784- c : Scalar , a : T ) : T ;
785834
786835 /**
787- * Computes an element-wise broadcasted multiplication of two matrices A and
788- * B. Will return a new matrix that is the max of A and B, where the smaller
789- * matrix will broadcast over the larger matrix.
790- * @param c The scalar in the operation.
791- * @param A the NDArray in the operation that will be broadcasted over.
836+ * @deprecated Use math.multiply() instead.
792837 */
793838 elementWiseMulBroadcast ( a : Array2D , b : Array2D ) : Array2D {
794839 util . assert (
@@ -799,10 +844,8 @@ export abstract class NDArrayMath {
799844 b . rank === 2 ,
800845 `Error in elementWiseMulBroadcast: second argument must be ` +
801846 `rank 2, but got rank ${ b . rank } .` ) ;
802- return this . track ( this . elementWiseMulBroadcastInternal ( a , b ) ) ;
847+ return this . multiply ( a , b ) as Array2D ;
803848 }
804- protected abstract elementWiseMulBroadcastInternal ( a : Array2D , b : Array2D ) :
805- Array2D ;
806849
807850 /////////////////////
808851 // Convolution ops //
0 commit comments