Skip to content

Commit 97bba2a

Browse files
aman-095kgryte
andauthored
feat!: add support for stacks in blas/sdot
BREAKING CHANGE: return an ndarray, rather than a scalar This commit changes the return value from a scalar to an ndarray. Previously, the function only operated on one-dimensional ndarrays and returned a scalar value. Now, the function operates on ndarrays of arbitrary rank and always returns an ndarray. In order to migrate, for one-dimensional input ndarrays, users should call `out.get()` in order to retrieve the scalar dot product. PR-URL: #2895 Co-authored-by: Athan Reines <[email protected]> Reviewed-by: Athan Reines <[email protected]>
1 parent e8fd916 commit 97bba2a

File tree

13 files changed

+811
-159
lines changed

13 files changed

+811
-159
lines changed

lib/node_modules/@stdlib/blas/ddot/lib/main.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ var format = require( '@stdlib/string/format' );
4141
*
4242
* @param {ndarrayLike} x - first input array
4343
* @param {ndarrayLike} y - second input array
44-
* @param {NegativeInteger} [dim] - dimension for which to compute the dot product
44+
* @param {NegativeInteger} [dim=-1] - dimension for which to compute the dot product
4545
* @throws {TypeError} first argument must be a ndarray containing double-precision floating-point numbers
4646
* @throws {TypeError} first argument must have at least one dimension
4747
* @throws {TypeError} second argument must be a ndarray containing double-precision floating-point numbers

lib/node_modules/@stdlib/blas/ddot/package.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
}
1515
],
1616
"main": "./lib",
17-
"browser": "./lib/main.js",
1817
"directories": {
1918
"benchmark": "./benchmark",
2019
"doc": "./docs",

lib/node_modules/@stdlib/blas/sdot/README.md

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ The [dot product][dot-product] (or scalar product) is defined as
3333
```
3434

3535
<!-- <div class="equation" align="center" data-raw-text="\mathbf{x}\cdot\mathbf{y} = \sum_{i=0}^{N-1} x_i y_i = x_0 y_0 + x_1 y_1 + \ldots + x_{N-1} y_{N-1}" data-equation="eq:dot_product">
36-
<img src="https://cdn.jsdelivr.net/gh/stdlib-js/stdlib@03fff24f5a7ba807a292f08cfef75ed0748e40de/lib/node_modules/@stdlib/blas/sdot/docs/img/equation_dot_product.svg" alt="Dot product definition.">
36+
<img src="https://cdn.jsdelivr.net/gh/stdlib-js/stdlib@d0afc603cdda35b11d5bd1633dd4dddb0d59e117/lib/node_modules/@stdlib/blas/sdot/docs/img/equation_dot_product.svg" alt="Dot product definition.">
3737
<br>
3838
</div> -->
3939

@@ -51,9 +51,9 @@ The [dot product][dot-product] (or scalar product) is defined as
5151
var sdot = require( '@stdlib/blas/sdot' );
5252
```
5353

54-
#### sdot( x, y )
54+
#### sdot( x, y\[, dim] )
5555

56-
Calculates the dot product of vectors `x` and `y`.
56+
Calculates the dot product of two single-precision floating-point vectors `x` and `y`.
5757

5858
```javascript
5959
var Float32Array = require( '@stdlib/array/float32' );
@@ -63,25 +63,38 @@ var x = array( new Float32Array( [ 4.0, 2.0, -3.0, 5.0, -1.0 ] ) );
6363
var y = array( new Float32Array( [ 2.0, 6.0, -1.0, -4.0, 8.0 ] ) );
6464

6565
var z = sdot( x, y );
66+
// returns <ndarray>
67+
68+
var v = z.get();
6669
// returns -5.0
6770
```
6871

6972
The function has the following parameters:
7073

71-
- **x**: a 1-dimensional [`ndarray`][@stdlib/ndarray/array] whose underlying data type is `float32`.
72-
- **y**: a 1-dimensional [`ndarray`][@stdlib/ndarray/array] whose underlying data type is `float32`.
74+
- **x**: a non-zero-dimensional [`ndarray`][@stdlib/ndarray/ctor] whose underlying data type is `float32`. Must be [broadcast-compatible][@stdlib/ndarray/base/broadcast-shapes] with `y`.
75+
- **y**: a non-zero-dimensional [`ndarray`][@stdlib/ndarray/ctor] whose underlying data type is `float32`. Must be [broadcast-compatible][@stdlib/ndarray/base/broadcast-shapes] with `x`.
76+
- **dim**: dimension for which to compute the dot product. Must be a negative integer. Negative indices are resolved relative to the last array dimension, with the last dimension corresponding to `-1`. Default: `-1`.
7377

74-
If provided empty vectors, the function returns `0.0`.
78+
If provided at least one input [`ndarray`][@stdlib/ndarray/ctor] having more than one dimension, the input [`ndarrays`][@stdlib/ndarray/ctor] are [broadcasted][@stdlib/ndarray/base/broadcast-shapes] to a common shape. For multi-dimensional input [`ndarrays`][@stdlib/ndarray/ctor], the function performs batched computation, such that the function computes the dot product for each pair of vectors in `x` and `y` according to the specified dimension index.
7579

7680
```javascript
7781
var Float32Array = require( '@stdlib/array/float32' );
7882
var array = require( '@stdlib/ndarray/array' );
7983

80-
var x = array( new Float32Array() );
81-
var y = array( new Float32Array() );
84+
var opts = {
85+
'shape': [ 2, 3 ]
86+
};
87+
var x = array( new Float32Array( [ 4.0, 2.0, -3.0, 5.0, -1.0, 3.0 ] ), opts );
88+
var y = array( new Float32Array( [ 2.0, 6.0, -1.0, -4.0, 8.0, 2.0 ] ), opts );
8289

8390
var z = sdot( x, y );
84-
// returns 0.0
91+
// returns <ndarray>
92+
93+
var v1 = z.get( 0 );
94+
// returns 23.0
95+
96+
var v2 = z.get( 1 );
97+
// returns -22.0
8598
```
8699

87100
</section>
@@ -92,6 +105,11 @@ var z = sdot( x, y );
92105

93106
## Notes
94107

108+
- The size of the contracted dimension must be the same for both input [`ndarrays`][@stdlib/ndarray/ctor].
109+
- The function resolves the dimension index for which to compute the dot product **before** broadcasting.
110+
- Negative indices are resolved relative to the last [`ndarray`][@stdlib/ndarray/ctor] dimension, with the last dimension corresponding to `-1`.
111+
- The output [`ndarray`][@stdlib/ndarray/ctor] has the same data type as the input [`ndarrays`][@stdlib/ndarray/ctor] and has a shape which is determined by broadcasting and excludes the contracted dimension.
112+
- If provided empty vectors, the dot product is `0`.
95113
- `sdot()` provides a higher-level interface to the [BLAS][blas] level 1 function [`sdot`][@stdlib/blas/base/sdot].
96114

97115
</section>
@@ -105,27 +123,27 @@ var z = sdot( x, y );
105123
<!-- eslint no-undef: "error" -->
106124

107125
```javascript
108-
var discreteUniform = require( '@stdlib/random/base/discrete-uniform' );
109-
var Float32Array = require( '@stdlib/array/float32' );
126+
var discreteUniform = require( '@stdlib/random/array/discrete-uniform' );
127+
var ndarray2array = require( '@stdlib/ndarray/to-array' );
110128
var array = require( '@stdlib/ndarray/array' );
111129
var sdot = require( '@stdlib/blas/sdot' );
112130

113-
var x = array( new Float32Array( 10 ) );
114-
var y = array( new Float32Array( 10 ) );
131+
var opts = {
132+
'dtype': 'float32'
133+
};
115134

116-
var rand1 = discreteUniform.factory( 0, 100 );
117-
var rand2 = discreteUniform.factory( 0, 10 );
135+
var x = array( discreteUniform( 10, 0, 100, opts ), {
136+
'shape': [ 5, 2 ]
137+
});
138+
console.log( ndarray2array( x ) );
118139

119-
var i;
120-
for ( i = 0; i < x.length; i++ ) {
121-
x.set( i, rand1() );
122-
y.set( i, rand2() );
123-
}
124-
console.log( x.toString() );
125-
console.log( y.toString() );
140+
var y = array( discreteUniform( 10, 0, 10, opts ), {
141+
'shape': x.shape
142+
});
143+
console.log( ndarray2array( y ) );
126144

127-
var z = sdot( x, y );
128-
console.log( z );
145+
var z = sdot( x, y, -1 );
146+
console.log( ndarray2array( z ) );
129147
```
130148

131149
</section>
@@ -136,14 +154,6 @@ console.log( z );
136154

137155
<section class="related">
138156

139-
* * *
140-
141-
## See Also
142-
143-
- <span class="package-name">[`@stdlib/blas/base/sdot`][@stdlib/blas/base/sdot]</span><span class="delimiter">: </span><span class="description">calculate the dot product of two single-precision floating-point vectors.</span>
144-
- <span class="package-name">[`@stdlib/blas/ddot`][@stdlib/blas/ddot]</span><span class="delimiter">: </span><span class="description">calculate the dot product of two double-precision floating-point vectors.</span>
145-
- <span class="package-name">[`@stdlib/blas/gdot`][@stdlib/blas/gdot]</span><span class="delimiter">: </span><span class="description">calculate the dot product of two vectors.</span>
146-
147157
</section>
148158

149159
<!-- /.related -->
@@ -156,18 +166,12 @@ console.log( z );
156166

157167
[blas]: http://www.netlib.org/blas
158168

159-
[@stdlib/ndarray/array]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/array
169+
[@stdlib/ndarray/ctor]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/ctor
160170

161-
<!-- <related-links> -->
171+
[@stdlib/ndarray/base/broadcast-shapes]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/base/broadcast-shapes
162172

163173
[@stdlib/blas/base/sdot]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/blas/base/sdot
164174

165-
[@stdlib/blas/ddot]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/blas/ddot
166-
167-
[@stdlib/blas/gdot]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/blas/gdot
168-
169-
<!-- </related-links> -->
170-
171175
</section>
172176

173177
<!-- /.links -->

lib/node_modules/@stdlib/blas/sdot/benchmark/benchmark.js

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,21 @@
2121
// MODULES //
2222

2323
var bench = require( '@stdlib/bench' );
24-
var randu = require( '@stdlib/random/base/randu' );
2524
var isnan = require( '@stdlib/math/base/assert/is-nan' );
2625
var pow = require( '@stdlib/math/base/special/pow' );
27-
var Float32Array = require( '@stdlib/array/float32' );
26+
var uniform = require( '@stdlib/random/array/uniform' );
2827
var array = require( '@stdlib/ndarray/array' );
2928
var pkg = require( './../package.json' ).name;
3029
var sdot = require( './../lib/main.js' );
3130

3231

32+
// VARIABLES //
33+
34+
var opts = {
35+
'dtype': 'float32'
36+
};
37+
38+
3339
// FUNCTIONS //
3440

3541
/**
@@ -40,34 +46,29 @@ var sdot = require( './../lib/main.js' );
4046
* @returns {Function} benchmark function
4147
*/
4248
function createBenchmark( len ) {
43-
var x;
44-
var y;
45-
var i;
46-
47-
x = new Float32Array( len );
48-
y = new Float32Array( len );
49-
for ( i = 0; i < len; i++ ) {
50-
x[ i ] = ( randu()*10.0 ) - 20.0;
51-
y[ i ] = ( randu()*10.0 ) - 20.0;
52-
}
53-
x = array( x );
54-
y = array( y );
55-
49+
var x = array( uniform( len, -100.0, 100.0, opts ) );
50+
var y = array( uniform( len, -100.0, 100.0, opts ) );
5651
return benchmark;
5752

53+
/**
54+
* Benchmark function.
55+
*
56+
* @private
57+
* @param {Benchmark} b - benchmark instance
58+
*/
5859
function benchmark( b ) {
5960
var d;
6061
var i;
6162

6263
b.tic();
6364
for ( i = 0; i < b.iterations; i++ ) {
6465
d = sdot( x, y );
65-
if ( isnan( d ) ) {
66+
if ( isnan( d.get() ) ) {
6667
b.fail( 'should not return NaN' );
6768
}
6869
}
6970
b.toc();
70-
if ( isnan( d ) ) {
71+
if ( isnan( d.get() ) ) {
7172
b.fail( 'should not return NaN' );
7273
}
7374
b.pass( 'benchmark finished' );
@@ -96,7 +97,7 @@ function main() {
9697
for ( i = min; i <= max; i++ ) {
9798
len = pow( 10, i );
9899
f = createBenchmark( len );
99-
bench( pkg+':len='+len, f );
100+
bench( pkg+'::vectors:len='+len, f );
100101
}
101102
}
102103

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2020 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
// MODULES //
22+
23+
var bench = require( '@stdlib/bench' );
24+
var isnan = require( '@stdlib/math/base/assert/is-nan' );
25+
var pow = require( '@stdlib/math/base/special/pow' );
26+
var uniform = require( '@stdlib/random/array/uniform' );
27+
var numel = require( '@stdlib/ndarray/base/numel' );
28+
var array = require( '@stdlib/ndarray/array' );
29+
var pkg = require( './../package.json' ).name;
30+
var sdot = require( './../lib/main.js' );
31+
32+
33+
// VARIABLES //
34+
35+
var OPTS = {
36+
'dtype': 'float32'
37+
};
38+
39+
40+
// FUNCTIONS //
41+
42+
/**
43+
* Creates a benchmark function.
44+
*
45+
* @private
46+
* @param {PositiveIntegerArray} shape - array shape
47+
* @returns {Function} benchmark function
48+
*/
49+
function createBenchmark( shape ) {
50+
var x;
51+
var y;
52+
var N;
53+
var o;
54+
55+
N = numel( shape );
56+
o = {
57+
'shape': shape
58+
};
59+
x = array( uniform( N, -100.0, 100.0, OPTS ), o );
60+
y = array( uniform( N, -100.0, 100.0, OPTS ), o );
61+
62+
return benchmark;
63+
64+
/**
65+
* Benchmark function.
66+
*
67+
* @private
68+
* @param {Benchmark} b - benchmark instance
69+
*/
70+
function benchmark( b ) {
71+
var d;
72+
var i;
73+
74+
b.tic();
75+
for ( i = 0; i < b.iterations; i++ ) {
76+
d = sdot( x, y );
77+
if ( isnan( d.iget( 0 ) ) ) {
78+
b.fail( 'should not return NaN' );
79+
}
80+
}
81+
b.toc();
82+
if ( isnan( d.iget( 0 ) ) ) {
83+
b.fail( 'should not return NaN' );
84+
}
85+
b.pass( 'benchmark finished' );
86+
b.end();
87+
}
88+
}
89+
90+
91+
// MAIN //
92+
93+
/**
94+
* Main execution sequence.
95+
*
96+
* @private
97+
*/
98+
function main() {
99+
var shape;
100+
var min;
101+
var max;
102+
var N;
103+
var f;
104+
var i;
105+
106+
min = 1; // 10^min
107+
max = 6; // 10^max
108+
109+
for ( i = min; i <= max; i++ ) {
110+
N = pow( 10, i );
111+
112+
shape = [ 2, N/2 ];
113+
f = createBenchmark( shape );
114+
bench( pkg+'::stacks:size='+N+',ndims='+shape.length+',shape=('+shape.join( ',' )+')', f );
115+
116+
shape = [ N/2, 2 ];
117+
f = createBenchmark( shape );
118+
bench( pkg+'::stacks:size='+N+',ndims='+shape.length+',shape=('+shape.join( ',' )+')', f );
119+
}
120+
}
121+
122+
main();

0 commit comments

Comments
 (0)