@@ -69,37 +69,4 @@ def blk_concat(
69
69
70
70
c = blk_concat (a )
71
71
print ("Concatenated matrix:" )
72
- print (c )
73
-
74
- def compute_weighted_sums (M : Array , vecm : Array , idx : int ) -> Array :
75
- """
76
- Compute the weighted sums of the matrix product of M and vecm,
77
-
78
- Args:
79
- M (Array): array of shape (N, m, m)
80
- Describes the matrix to be multiplied with vecm
81
- vecm (Array): array-like of shape (N, m)
82
- Describes the vector to be multiplied with M
83
- idx (int): index of the last row to be summed over
84
-
85
- Returns:
86
- Array: array of shape (N, m)
87
- The result of the weighted sums. For each i, the result is the sum of the products of M[i, j] and vecm[j] for j from 0 to idx.
88
- """
89
- N = M .shape [0 ]
90
- # Matrix product for each j: (N, m, m) @ (N, m, 1) -> (N, m)
91
- prod = jnp .einsum ("nij,nj->ni" , M , vecm )
92
-
93
- # Triangular mask for partial sum: (N, N)
94
- # mask[i, j] = 1 if j >= i and j <= idx
95
- mask = (jnp .arange (N )[:, None ] <= jnp .arange (N )[None , :]) & (
96
- jnp .arange (N )[None , :] <= idx
97
- )
98
- mask = mask .astype (M .dtype ) # (N, N)
99
-
100
- # Extend 6-dimensional mask (N, N, 1) to apply to (N, m)
101
- masked_prod = mask [:, :, None ] * prod [None , :, :] # (N, N, m)
102
-
103
- # Sum over j for each i : (N, m)
104
- result = masked_prod .sum (axis = 1 ) # (N, m)
105
- return result
72
+ print (c )
0 commit comments