Skip to content

Commit a3f25ad

Browse files
committed
feat: add blas/tools/dot-factory
1 parent 982980c commit a3f25ad

File tree

11 files changed

+1922
-0
lines changed

11 files changed

+1922
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
<!--
2+
3+
@license Apache-2.0
4+
5+
Copyright (c) 2024 The Stdlib Authors.
6+
7+
Licensed under the Apache License, Version 2.0 (the "License");
8+
you may not use this file except in compliance with the License.
9+
You may obtain a copy of the License at
10+
11+
http://www.apache.org/licenses/LICENSE-2.0
12+
13+
Unless required by applicable law or agreed to in writing, software
14+
distributed under the License is distributed on an "AS IS" BASIS,
15+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
See the License for the specific language governing permissions and
17+
limitations under the License.
18+
19+
-->
20+
21+
# factory
22+
23+
> Return a function which computes the dot product.
24+
25+
<section class="intro">
26+
27+
</section>
28+
29+
<!-- /.intro -->
30+
31+
<section class="usage">
32+
33+
## Usage
34+
35+
```javascript
36+
var factory = require( '@stdlib/blas/tools/dot-factory' );
37+
```
38+
39+
#### factory( base, dtype )
40+
41+
Returns a function which computes the dot product.
42+
43+
```javascript
44+
var ddot = require( '@stdlib/blas/base/ddot' ).ndarray;
45+
46+
var dot = factory( ddot, 'float64' );
47+
```
48+
49+
The function has the following parameters:
50+
51+
- **base**: "base" function which computes the dot product. Must have an `ndarray` function signature (i.e., must support index offsets).
52+
- **dtype**: array data type. The function assumes that the data type of all provided arrays is the same.
53+
54+
#### dot( x, y\[, dim] )
55+
56+
Computes the dot product of two floating-point vectors.
57+
58+
```javascript
59+
var Float64Array = require( '@stdlib/array/float64' );
60+
var array = require( '@stdlib/ndarray/array' );
61+
var ddot = require( '@stdlib/blas/base/ddot' ).ndarray;
62+
63+
var dot = factory( ddot, 'float64' );
64+
65+
var x = array( new Float64Array( [ 4.0, 2.0, -3.0, 5.0, -1.0 ] ) );
66+
var y = array( new Float64Array( [ 2.0, 6.0, -1.0, -4.0, 8.0 ] ) );
67+
68+
var z = dot( x, y );
69+
// returns <ndarray>
70+
71+
var v = z.get();
72+
// returns -5.0
73+
```
74+
75+
The returned function has the following parameters:
76+
77+
- **x**: a non-zero-dimensional [`ndarray`][@stdlib/ndarray/ctor]. Must be [broadcast-compatible][@stdlib/ndarray/base/broadcast-shapes] with `y`.
78+
- **y**: a non-zero-dimensional [`ndarray`][@stdlib/ndarray/ctor]. Must be [broadcast-compatible][@stdlib/ndarray/base/broadcast-shapes] with `x`.
79+
- **dim**: dimension for which to compute the dot product. Must be a negative integer. Negative indices are resolved relative to the last array dimension, with the last dimension corresponding to `-1`. Default: `-1`.
80+
81+
If provided at least one input [`ndarray`][@stdlib/ndarray/ctor] having more than one dimension, the input [`ndarrays`][@stdlib/ndarray/ctor] are [broadcasted][@stdlib/ndarray/base/broadcast-shapes] to a common shape. For multi-dimensional input [`ndarrays`][@stdlib/ndarray/ctor], the function performs batched computation, such that the function computes the dot product for each pair of vectors in `x` and `y` according to the specified dimension index.
82+
83+
```javascript
84+
var Float64Array = require( '@stdlib/array/float64' );
85+
var array = require( '@stdlib/ndarray/array' );
86+
var ddot = require( '@stdlib/blas/base/ddot' ).ndarray;
87+
88+
var dot = factory( ddot, 'float64' );
89+
90+
var opts = {
91+
'shape': [ 2, 3 ]
92+
};
93+
var x = array( new Float64Array( [ 4.0, 2.0, -3.0, 5.0, -1.0, 3.0 ] ), opts );
94+
var y = array( new Float64Array( [ 2.0, 6.0, -1.0, -4.0, 8.0, 2.0 ] ), opts );
95+
96+
var z = dot( x, y );
97+
// returns <ndarray>
98+
99+
var v1 = z.get( 0 );
100+
// returns 23.0
101+
102+
var v2 = z.get( 1 );
103+
// returns -22.0
104+
```
105+
106+
</section>
107+
108+
<!-- /.usage -->
109+
110+
<section class="notes">
111+
112+
## Notes
113+
114+
For the returned function,
115+
116+
- The size of the contracted dimension must be the same for both input [`ndarrays`][@stdlib/ndarray/ctor].
117+
- The function resolves the dimension index for which to compute the dot product **before** broadcasting.
118+
- Negative indices are resolved relative to the last [`ndarray`][@stdlib/ndarray/ctor] dimension, with the last dimension corresponding to `-1`.
119+
- The output [`ndarray`][@stdlib/ndarray/ctor] has the same data type as the input [`ndarrays`][@stdlib/ndarray/ctor] and has a shape which is determined by broadcasting and excludes the contracted dimension.
120+
- If provided empty vectors, the dot product is `0`.
121+
122+
</section>
123+
124+
<!-- /.notes -->
125+
126+
<section class="examples">
127+
128+
## Examples
129+
130+
<!-- eslint no-undef: "error" -->
131+
132+
```javascript
133+
var discreteUniform = require( '@stdlib/random/array/discrete-uniform' );
134+
var ndarray2array = require( '@stdlib/ndarray/to-array' );
135+
var array = require( '@stdlib/ndarray/array' );
136+
var ddot = require( '@stdlib/blas/base/ddot' ).ndarray;
137+
var factory = require( '@stdlib/blas/tools/dot-factory' );
138+
139+
var dot = factory( ddot, 'float64' );
140+
141+
var opts = {
142+
'dtype': 'float64'
143+
};
144+
145+
var x = array( discreteUniform( 10, 0, 100, opts ), {
146+
'shape': [ 5, 2 ]
147+
});
148+
console.log( ndarray2array( x ) );
149+
150+
var y = array( discreteUniform( 10, 0, 10, opts ), {
151+
'shape': x.shape
152+
});
153+
console.log( ndarray2array( y ) );
154+
155+
var z = dot( x, y, -1 );
156+
console.log( ndarray2array( z ) );
157+
```
158+
159+
</section>
160+
161+
<!-- /.examples -->
162+
163+
<!-- Section for related `stdlib` packages. Do not manually edit this section, as it is automatically populated. -->
164+
165+
<section class="related">
166+
167+
</section>
168+
169+
<!-- /.related -->
170+
171+
<!-- Section for all links. Make sure to keep an empty line after the `section` element and another before the `/section` close. -->
172+
173+
<section class="links">
174+
175+
[@stdlib/ndarray/ctor]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/ctor
176+
177+
[@stdlib/ndarray/base/broadcast-shapes]: https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/ndarray/base/broadcast-shapes
178+
179+
</section>
180+
181+
<!-- /.links -->
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2024 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
// MODULES //
22+
23+
var bench = require( '@stdlib/bench' );
24+
var isnan = require( '@stdlib/math/base/assert/is-nan' );
25+
var pow = require( '@stdlib/math/base/special/pow' );
26+
var uniform = require( '@stdlib/random/array/uniform' );
27+
var array = require( '@stdlib/ndarray/array' );
28+
var ddot = require( '@stdlib/blas/base/ddot' ).ndarray;
29+
var pkg = require( './../package.json' ).name;
30+
var factory = require( './../lib' );
31+
32+
33+
// VARIABLES //
34+
35+
var opts = {
36+
'dtype': 'float64'
37+
};
38+
var dot = factory( ddot, opts.dtype );
39+
40+
41+
// FUNCTIONS //
42+
43+
/**
44+
* Creates a benchmark function.
45+
*
46+
* @private
47+
* @param {PositiveInteger} len - array length
48+
* @returns {Function} benchmark function
49+
*/
50+
function createBenchmark( len ) {
51+
var x = array( uniform( len, -100.0, 100.0, opts ) );
52+
var y = array( uniform( len, -100.0, 100.0, opts ) );
53+
return benchmark;
54+
55+
/**
56+
* Benchmark function.
57+
*
58+
* @private
59+
* @param {Benchmark} b - benchmark instance
60+
*/
61+
function benchmark( b ) {
62+
var d;
63+
var i;
64+
65+
b.tic();
66+
for ( i = 0; i < b.iterations; i++ ) {
67+
d = dot( x, y );
68+
if ( isnan( d.get() ) ) {
69+
b.fail( 'should not return NaN' );
70+
}
71+
}
72+
b.toc();
73+
if ( isnan( d.get() ) ) {
74+
b.fail( 'should not return NaN' );
75+
}
76+
b.pass( 'benchmark finished' );
77+
b.end();
78+
}
79+
}
80+
81+
82+
// MAIN //
83+
84+
/**
85+
* Main execution sequence.
86+
*
87+
* @private
88+
*/
89+
function main() {
90+
var len;
91+
var min;
92+
var max;
93+
var f;
94+
var i;
95+
96+
min = 1; // 10^min
97+
max = 6; // 10^max
98+
99+
for ( i = min; i <= max; i++ ) {
100+
len = pow( 10, i );
101+
f = createBenchmark( len );
102+
bench( pkg+'::vectors:len='+len, f );
103+
}
104+
}
105+
106+
main();

0 commit comments

Comments
 (0)