Skip to content

Commit de6b9d4

Browse files
committed
Refactor to avoid unwrap
1 parent b0c75b5 commit de6b9d4

File tree

1 file changed

+139
-130
lines changed

1 file changed

+139
-130
lines changed

src/operations.rs

Lines changed: 139 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -53,48 +53,49 @@ fn count_array_multi_axis<T: Element>(
5353
axes: &[usize],
5454
missing: Option<Missing<T>>,
5555
) -> (Vec<i64>, Vec<usize>) {
56-
let result = if axes.is_empty() {
57-
// Emulate numpy semantics of axis = () being
58-
// equivalent to a 'reduction over no axes'
59-
array.map(|val| {
60-
if let Some(missing) = &missing {
61-
if !missing.is_missing(val) {
62-
1
63-
} else {
64-
0
65-
}
66-
} else {
67-
1
68-
}
69-
})
70-
} else {
71-
// Should never panic here due to axis.is_empty() branch above
72-
let first_axis = axes.first().expect("axes list to be non-empty");
73-
// Count non-missing over first axis
74-
let mut result = array
75-
.fold_axis(Axis(*first_axis), 0, |running_count, val| {
56+
let result = match axes.first() {
57+
None => {
58+
// Emulate numpy semantics of axis = () being
59+
// equivalent to a 'reduction over no axes'
60+
array.map(|val| {
7661
if let Some(missing) = &missing {
7762
if !missing.is_missing(val) {
78-
running_count + 1
63+
1
7964
} else {
80-
*running_count
65+
0
8166
}
8267
} else {
83-
running_count + 1
68+
1
8469
}
8570
})
86-
.into_dyn();
87-
// Sum counts over remaining axes
88-
if let Some(remaining_axes) = axes.get(1..) {
89-
for (n, axis) in remaining_axes.iter().enumerate() {
90-
result = result
91-
.fold_axis(Axis(axis - n - 1), 0, |total_count, count| {
92-
total_count + count
93-
})
94-
.into_dyn();
71+
}
72+
Some(first_axis) => {
73+
// Count non-missing over first axis
74+
let mut result = array
75+
.fold_axis(Axis(*first_axis), 0, |running_count, val| {
76+
if let Some(missing) = &missing {
77+
if !missing.is_missing(val) {
78+
running_count + 1
79+
} else {
80+
*running_count
81+
}
82+
} else {
83+
running_count + 1
84+
}
85+
})
86+
.into_dyn();
87+
// Sum counts over remaining axes
88+
if let Some(remaining_axes) = axes.get(1..) {
89+
for (n, axis) in remaining_axes.iter().enumerate() {
90+
result = result
91+
.fold_axis(Axis(axis - n - 1), 0, |total_count, count| {
92+
total_count + count
93+
})
94+
.into_dyn();
95+
}
9596
}
97+
result
9698
}
97-
result
9899
};
99100

100101
// Convert result to owned vec
@@ -233,45 +234,48 @@ fn max_array_multi_axis<T: Element>(
233234
missing: Option<Missing<T>>,
234235
order: &Option<Order>,
235236
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
236-
let (result, shape) = if axes.is_empty() {
237-
// Emulate numpy behaviour of 'reduction over no axes'
238-
let result = reduction_over_zero_axes(&array, missing, order);
239-
(result, array.shape().to_owned())
240-
} else {
241-
// Find maximum over first axis and count elements operated on
242-
let init = T::min_value();
243-
let mut result = array
244-
.fold_axis(Axis(axes[0]), (init, 0), |(running_max, count), val| {
245-
if let Some(missing) = &missing {
246-
if !missing.is_missing(val) {
237+
let (result, shape) = match axes.first() {
238+
None => {
239+
// Emulate numpy behaviour of 'reduction over no axes'
240+
let result = reduction_over_zero_axes(&array, missing, order);
241+
(result, array.shape().to_owned())
242+
}
243+
Some(first_axis) => {
244+
// Find maximum over first axis and count elements operated on
245+
let init = T::min_value();
246+
let mut result = array
247+
.fold_axis(Axis(*first_axis), (init, 0), |(running_max, count), val| {
248+
if let Some(missing) = &missing {
249+
if !missing.is_missing(val) {
250+
let new_max = max_by(running_max, val, max_element_pairwise);
251+
(*new_max, count + 1)
252+
} else {
253+
(*running_max, *count)
254+
}
255+
} else {
247256
let new_max = max_by(running_max, val, max_element_pairwise);
248257
(*new_max, count + 1)
249-
} else {
250-
(*running_max, *count)
251258
}
252-
} else {
253-
let new_max = max_by(running_max, val, max_element_pairwise);
254-
(*new_max, count + 1)
259+
})
260+
.into_dyn();
261+
// Find max over remaining axes (where total count is now sum of counts)
262+
if let Some(remaining_axes) = axes.get(1..) {
263+
for (n, axis) in remaining_axes.iter().enumerate() {
264+
result = result
265+
.fold_axis(
266+
Axis(axis - n - 1),
267+
(init, 0),
268+
|(global_max, total_count), (running_max, count)| {
269+
let new_max = max_by(global_max, running_max, max_element_pairwise);
270+
(*new_max, total_count + count)
271+
},
272+
)
273+
.into_dyn();
255274
}
256-
})
257-
.into_dyn();
258-
// Find max over remaining axes (where total count is now sum of counts)
259-
if let Some(remaining_axes) = axes.get(1..) {
260-
for (n, axis) in remaining_axes.iter().enumerate() {
261-
result = result
262-
.fold_axis(
263-
Axis(axis - n - 1),
264-
(init, 0),
265-
|(global_max, total_count), (running_max, count)| {
266-
let new_max = max_by(global_max, running_max, max_element_pairwise);
267-
(*new_max, total_count + count)
268-
},
269-
)
270-
.into_dyn();
271275
}
276+
let shape = result.shape().to_owned();
277+
(result, shape)
272278
}
273-
let shape = result.shape().to_owned();
274-
(result, shape)
275279
};
276280

277281
// Result is array of (max, count) tuples so separate them here
@@ -391,46 +395,48 @@ fn min_array_multi_axis<T: Element>(
391395
missing: Option<Missing<T>>,
392396
order: &Option<Order>,
393397
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
394-
let (result, shape) = if axes.is_empty() {
395-
// Emulate numpy behaviour of 'reduction over no axes'
396-
let result = reduction_over_zero_axes(&array, missing, order);
397-
(result, array.shape().to_owned())
398-
} else {
399-
// Find minimum over first axis and count elements operated on
400-
let init = T::max_value();
401-
let mut result = array
402-
.fold_axis(Axis(axes[0]), (init, 0), |(running_min, count), val| {
403-
if let Some(missing) = &missing {
404-
if !missing.is_missing(val) {
398+
let (result, shape) = match axes.first() {
399+
None => {
400+
// Emulate numpy behaviour of 'reduction over no axes'
401+
let result = reduction_over_zero_axes(&array, missing, order);
402+
(result, array.shape().to_owned())
403+
}
404+
Some(first_axis) => {
405+
// Find minimum over first axis and count elements operated on
406+
let init = T::max_value();
407+
let mut result = array
408+
.fold_axis(Axis(*first_axis), (init, 0), |(running_min, count), val| {
409+
if let Some(missing) = &missing {
410+
if !missing.is_missing(val) {
411+
let new_min = min_by(running_min, val, min_element_pairwise);
412+
(*new_min, count + 1)
413+
} else {
414+
(*running_min, *count)
415+
}
416+
} else {
405417
let new_min = min_by(running_min, val, min_element_pairwise);
406418
(*new_min, count + 1)
407-
} else {
408-
(*running_min, *count)
409419
}
410-
} else {
411-
let new_min = min_by(running_min, val, min_element_pairwise);
412-
(*new_min, count + 1)
420+
})
421+
.into_dyn();
422+
// Find min over remaining axes (where total count is now sum of counts)
423+
if let Some(remaining_axes) = axes.get(1..) {
424+
for (n, axis) in remaining_axes.iter().enumerate() {
425+
result = result
426+
.fold_axis(
427+
Axis(axis - n - 1),
428+
(init, 0),
429+
|(global_min, total_count), (running_min, count)| {
430+
let new_min = min_by(global_min, running_min, min_element_pairwise);
431+
(*new_min, total_count + count)
432+
},
433+
)
434+
.into_dyn();
413435
}
414-
})
415-
.into_dyn();
416-
// Find min over remaining axes (where total count is now sum of counts)
417-
if let Some(remaining_axes) = axes.get(1..) {
418-
for (n, axis) in remaining_axes.iter().enumerate() {
419-
result = result
420-
.fold_axis(
421-
Axis(axis - n - 1),
422-
(init, 0),
423-
|(global_min, total_count), (running_min, count)| {
424-
// (*global_min.min(running_min), total_count + count)
425-
let new_min = min_by(global_min, running_min, min_element_pairwise);
426-
(*new_min, total_count + count)
427-
},
428-
)
429-
.into_dyn();
430436
}
437+
let shape = result.shape().to_owned();
438+
(result, shape)
431439
}
432-
let shape = result.shape().to_owned();
433-
(result, shape)
434440
};
435441

436442
// Result is array of (mins, count) tuples so separate them here
@@ -574,41 +580,44 @@ fn sum_array_multi_axis<T: Element>(
574580
missing: Option<Missing<T>>,
575581
order: &Option<Order>,
576582
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
577-
let (result, shape) = if axes.is_empty() {
578-
// Emulate numpy behaviour of 'reduction over no axes'
579-
let result = reduction_over_zero_axes(&array, missing, order);
580-
(result, array.shape().to_owned())
581-
} else {
582-
// Sum over first axis and count elements operated on
583-
let mut result = array
584-
.fold_axis(Axis(axes[0]), (T::zero(), 0), |(sum, count), val| {
585-
if let Some(missing) = &missing {
586-
if !missing.is_missing(val) {
587-
(*sum + *val, count + 1)
583+
let (result, shape) = match axes.first() {
584+
None => {
585+
// Emulate numpy behaviour of 'reduction over no axes'
586+
let result = reduction_over_zero_axes(&array, missing, order);
587+
(result, array.shape().to_owned())
588+
}
589+
Some(first_axis) => {
590+
// Sum over first axis and count elements operated on
591+
let mut result = array
592+
.fold_axis(Axis(*first_axis), (T::zero(), 0), |(sum, count), val| {
593+
if let Some(missing) = &missing {
594+
if !missing.is_missing(val) {
595+
(*sum + *val, count + 1)
596+
} else {
597+
(*sum, *count)
598+
}
588599
} else {
589-
(*sum, *count)
600+
(*sum + *val, count + 1)
590601
}
591-
} else {
592-
(*sum + *val, count + 1)
602+
})
603+
.into_dyn();
604+
// Sum over remaining axes (where total count is now sum of counts)
605+
if let Some(remaining_axes) = axes.get(1..) {
606+
for (n, axis) in remaining_axes.iter().enumerate() {
607+
result = result
608+
.fold_axis(
609+
Axis(axis - n - 1),
610+
(T::zero(), 0),
611+
|(total_sum, total_count), (sum, count)| {
612+
(*total_sum + *sum, total_count + count)
613+
},
614+
)
615+
.into_dyn();
593616
}
594-
})
595-
.into_dyn();
596-
// Sum over remaining axes (where total count is now sum of counts)
597-
if let Some(remaining_axes) = axes.get(1..) {
598-
for (n, axis) in remaining_axes.iter().enumerate() {
599-
result = result
600-
.fold_axis(
601-
Axis(axis - n - 1),
602-
(T::zero(), 0),
603-
|(total_sum, total_count), (sum, count)| {
604-
(*total_sum + *sum, total_count + count)
605-
},
606-
)
607-
.into_dyn();
608617
}
618+
let shape = result.shape().to_owned();
619+
(result, shape)
609620
}
610-
let shape = result.shape().to_owned();
611-
(result, shape)
612621
};
613622

614623
// Result is array of (sum, count) tuples so separate them here

0 commit comments

Comments
 (0)