Skip to content

Commit 712a9a8

Browse files
committed
fix: compute singleton strides as if strides computed apriori
--- type: pre_commit_static_analysis_report description: Results of running static analysis checks when committing changes. report: - task: lint_filenames status: passed - task: lint_editorconfig status: passed - task: lint_markdown status: na - task: lint_package_json status: na - task: lint_repl_help status: na - task: lint_javascript_src status: passed - task: lint_javascript_cli status: na - task: lint_javascript_examples status: na - task: lint_javascript_tests status: passed - task: lint_javascript_benchmarks status: na - task: lint_python status: na - task: lint_r status: na - task: lint_c_src status: na - task: lint_c_examples status: na - task: lint_c_benchmarks status: na - task: lint_c_tests_fixtures status: na - task: lint_shell status: na - task: lint_typescript_declarations status: na - task: lint_typescript_tests status: na - task: lint_license_headers status: passed ---
1 parent c634089 commit 712a9a8

File tree

2 files changed

+19
-8
lines changed
  • lib/node_modules/@stdlib/ndarray/base/expand-dimensions

2 files changed

+19
-8
lines changed

lib/node_modules/@stdlib/ndarray/base/expand-dimensions/lib/main.js

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
// MODULES //
2222

23+
var isRowMajor = require( '@stdlib/ndarray/base/assert/is-row-major-string' );
2324
var isReadOnly = require( '@stdlib/ndarray/base/assert/is-read-only' );
2425
var normalizeIndex = require( '@stdlib/ndarray/base/normalize-index' );
2526
var getDType = require( '@stdlib/ndarray/base/dtype' );
@@ -75,6 +76,7 @@ var format = require( '@stdlib/string/format' );
7576
function expandDimensions( x, axis ) {
7677
var strides;
7778
var shape;
79+
var isrm;
7880
var ord;
7981
var sh;
8082
var st;
@@ -85,6 +87,8 @@ function expandDimensions( x, axis ) {
8587
sh = getShape( x, false );
8688
st = getStrides( x, false );
8789
ord = getOrder( x );
90+
91+
isrm = isRowMajor( ord );
8892
N = sh.length;
8993

9094
strides = [];
@@ -97,8 +101,11 @@ function expandDimensions( x, axis ) {
97101
if ( d === 0 ) {
98102
// Prepend singleton dimension...
99103
shape.push( 1 );
100-
strides.push( st[ 0 ] );
101-
104+
if ( isrm ) {
105+
strides.push( sh[ 0 ] * st[ 0 ] );
106+
} else {
107+
strides.push( st[ 0 ] );
108+
}
102109
// Copy remaining dimensions...
103110
for ( i = 0; i < N; i++ ) {
104111
shape.push( sh[ i ] );
@@ -112,13 +119,17 @@ function expandDimensions( x, axis ) {
112119
}
113120
// Append singleton dimension...
114121
shape.push( 1 );
115-
strides.push( st[ N-1 ] );
122+
if ( isrm ) {
123+
strides.push( st[ N-1 ] );
124+
} else {
125+
strides.push( sh[ N-1 ] * st[ N-1 ] );
126+
}
116127
} else {
117128
// Insert a singleton dimension...
118129
for ( i = 0; i < N+1; i++ ) {
119130
if ( i === d ) {
120131
shape.push( 1 );
121-
if ( ord === 'row-major' ) {
132+
if ( isrm ) {
122133
strides.push( st[ i-1 ] );
123134
} else { // ord === 'column-major'
124135
strides.push( st[ i ] );

lib/node_modules/@stdlib/ndarray/base/expand-dimensions/test/test.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ tape( 'the function prepends singleton dimensions (base; row-major)', function t
159159

160160
t.notEqual( y, x, 'returns expected value' );
161161
t.deepEqual( y.shape, [ 1, 2, 2 ], 'returns expected value' );
162-
t.deepEqual( y.strides, [ 2, 2, 1 ], 'returns expected value' );
162+
t.deepEqual( y.strides, [ 4, 2, 1 ], 'returns expected value' );
163163
t.strictEqual( y.data, x.data, 'returns expected value' );
164164

165165
t.end();
@@ -174,7 +174,7 @@ tape( 'the function prepends singleton dimensions (base; row-major; negative axi
174174

175175
t.notEqual( y, x, 'returns expected value' );
176176
t.deepEqual( y.shape, [ 1, 2, 2 ], 'returns expected value' );
177-
t.deepEqual( y.strides, [ 2, 2, 1 ], 'returns expected value' );
177+
t.deepEqual( y.strides, [ 4, 2, 1 ], 'returns expected value' );
178178
t.strictEqual( y.data, x.data, 'returns expected value' );
179179

180180
t.end();
@@ -249,7 +249,7 @@ tape( 'the function appends singleton dimensions (base; column-major)', function
249249

250250
t.notEqual( y, x, 'returns expected value' );
251251
t.deepEqual( y.shape, [ 2, 2, 1 ], 'returns expected value' );
252-
t.deepEqual( y.strides, [ 1, 2, 2 ], 'returns expected value' );
252+
t.deepEqual( y.strides, [ 1, 2, 4 ], 'returns expected value' );
253253
t.strictEqual( y.data, x.data, 'returns expected value' );
254254

255255
t.end();
@@ -264,7 +264,7 @@ tape( 'the function appends singleton dimensions (base; column-major; negative i
264264

265265
t.notEqual( y, x, 'returns expected value' );
266266
t.deepEqual( y.shape, [ 2, 2, 1 ], 'returns expected value' );
267-
t.deepEqual( y.strides, [ 1, 2, 2 ], 'returns expected value' );
267+
t.deepEqual( y.strides, [ 1, 2, 4 ], 'returns expected value' );
268268
t.strictEqual( y.data, x.data, 'returns expected value' );
269269

270270
t.end();

0 commit comments

Comments
 (0)