|
| 1 | +use core::f32; |
1 | 2 | use std::{fmt::Display, sync::LazyLock}; |
2 | 3 |
|
3 | 4 | use crate::{self as cubecl, as_type}; |
@@ -30,7 +31,11 @@ pub(crate) fn assert_equals_approx< |
30 | 31 | // account for lower precision at higher values |
31 | 32 | let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon)); |
32 | 33 | assert!( |
33 | | - (*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()), |
| 34 | + (*a - *e).abs() < allowed_error |
| 35 | + || (a.is_nan() && e.is_nan()) |
| 36 | + || (a.is_infinite() |
| 37 | + && e.is_infinite() |
| 38 | + && a.is_sign_positive() == e.is_sign_positive()), |
34 | 39 | "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={} |
35 | 40 | index: {} |
36 | 41 | actual: {:?} |
@@ -184,6 +189,64 @@ test_binary_impl!( |
184 | 189 | ] |
185 | 190 | ); |
186 | 191 |
|
| 192 | +test_binary_impl!( |
| 193 | + test_hypot, |
| 194 | + F, |
| 195 | + F::hypot, |
| 196 | + [ |
| 197 | + { |
| 198 | + input_vectorization: 1, |
| 199 | + out_vectorization: 1, |
| 200 | + lhs: as_type![F: 3., 0., 5., 0.], |
| 201 | + rhs: as_type![F: 4., 5., 0., 0.], |
| 202 | + expected: as_type![F: 5., 5., 5., 0.] |
| 203 | + }, |
| 204 | + { |
| 205 | + input_vectorization: 2, |
| 206 | + out_vectorization: 2, |
| 207 | + lhs: as_type![F: 3., 0., 5., 8.], |
| 208 | + rhs: as_type![F: 4., 5., 0., 15.], |
| 209 | + expected: as_type![F: 5., 5., 5., 17.] |
| 210 | + }, |
| 211 | + { |
| 212 | + input_vectorization: 4, |
| 213 | + out_vectorization: 4, |
| 214 | + lhs: as_type![F: -3., 0., -5., -8.], |
| 215 | + rhs: as_type![F: -4., -5., 0., 15.], |
| 216 | + expected: as_type![F: 5., 5., 5., 17.] |
| 217 | + } |
| 218 | + ] |
| 219 | +); |
| 220 | + |
| 221 | +test_binary_impl!( |
| 222 | + test_rhypot, |
| 223 | + F, |
| 224 | + F::rhypot, |
| 225 | + [ |
| 226 | + { |
| 227 | + input_vectorization: 1, |
| 228 | + out_vectorization: 1, |
| 229 | + lhs: as_type![F: 3., 0., 5., 0.], |
| 230 | + rhs: as_type![F: 4., 5., 0., 0.], |
| 231 | + expected: &[F::new(0.2), F::new(0.2), F::new(0.2), F::INFINITY] |
| 232 | + }, |
| 233 | + { |
| 234 | + input_vectorization: 2, |
| 235 | + out_vectorization: 2, |
| 236 | + lhs: as_type![F: 3., 0., 5., 0.3], |
| 237 | + rhs: as_type![F: 4., 5., 0., 0.4], |
| 238 | + expected: as_type![F: 0.2, 0.2, 0.2, 2.] |
| 239 | + }, |
| 240 | + { |
| 241 | + input_vectorization: 4, |
| 242 | + out_vectorization: 4, |
| 243 | + lhs: as_type![F: 0., 0., -5., -0.3], |
| 244 | + rhs: as_type![F: -1., -5., 0., -0.4], |
| 245 | + expected: as_type![F: 1., 0.2, 0.2, 2.] |
| 246 | + } |
| 247 | + ] |
| 248 | +); |
| 249 | + |
187 | 250 | #[cube(launch_unchecked)] |
188 | 251 | fn test_powi_kernel<F: Float>( |
189 | 252 | lhs: &Array<Line<F>>, |
@@ -354,6 +417,8 @@ macro_rules! testgen_binary { |
354 | 417 |
|
355 | 418 | add_test!(test_dot); |
356 | 419 | add_test!(test_powf); |
| 420 | + add_test!(test_hypot); |
| 421 | + add_test!(test_rhypot); |
357 | 422 | add_test!(test_powi); |
358 | 423 | add_test!(test_atan2); |
359 | 424 | } |
|
0 commit comments