diff --git a/arrayfire/library/linear_algebra.py b/arrayfire/library/linear_algebra.py index ee1f09e..67109f2 100644 --- a/arrayfire/library/linear_algebra.py +++ b/arrayfire/library/linear_algebra.py @@ -92,6 +92,7 @@ def gemm( rhs_opts: MatProp = MatProp.NONE, alpha: int | float = 1.0, beta: int | float = 0.0, + accum: Array = None ) -> Array: """ Performs BLAS general matrix multiplication (GEMM) on two Array instances. @@ -125,6 +126,10 @@ def gemm( beta : int | float, optional Scalar multiplier for the existing matrix C in the accumulation. Default is 0.0. + accum: Array, optional + A 2-dimensional, real or complex array representing the matrix C in the accumulation. + Default is None (no accumulation). + Returns ------- Array @@ -135,7 +140,10 @@ def gemm( - The data types of `lhs` and `rhs` must be compatible. - Batch operations are not supported in this version. """ - return cast(Array, wrapper.gemm(lhs.arr, rhs.arr, lhs_opts, rhs_opts, alpha, beta)) + accumulator = None + if isinstance(accum, Array): + accumulator = accum.arr + return cast(Array, wrapper.gemm(lhs.arr, rhs.arr, lhs_opts, rhs_opts, alpha, beta, accumulator)) @afarray_as_array diff --git a/tests/test_library/test_linear_algebra.py b/tests/test_library/test_linear_algebra.py index 19b1cb2..38d7074 100644 --- a/tests/test_library/test_linear_algebra.py +++ b/tests/test_library/test_linear_algebra.py @@ -65,8 +65,8 @@ def test_gemm_basic(matrix_a: af.Array, matrix_b: af.Array) -> None: def test_gemm_alpha_beta(matrix_a: af.Array, matrix_b: af.Array) -> None: alpha = 0.5 beta = 2.0 - result = af.gemm(matrix_a, matrix_b, alpha=alpha, beta=beta) - expected = create_from_2d_nested(10.5, 12.0, 22.5, 26.0) + result = af.gemm(matrix_a, matrix_b, alpha=alpha, beta=beta, accum=matrix_a) + expected = create_from_2d_nested(11.5, 15.0, 27.5, 33.0) assert result == expected, f"Expected {expected}, got {result}"