-
-
Notifications
You must be signed in to change notification settings - Fork 104
Add complex numbers #849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add complex numbers #849
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,110 @@ pub use amp::AMP; | |
| #[cfg(feature = "f16")] | ||
| pub use half::f16; | ||
|
|
||
| #[cfg(feature = "complex")] | ||
| pub mod complex { | ||
| use core::ops::{Deref, DerefMut}; | ||
|
|
||
| #[cfg(feature = "cuda")] | ||
| use cudarc::driver::{DeviceRepr, ValidAsZeroBits}; | ||
| use num_complex::Complex32; | ||
| use num_traits::{FromPrimitive, ToPrimitive}; | ||
|
|
||
| #[derive(PartialEq, Debug, Default, Clone, Copy)] | ||
| pub struct Complex(Complex32); | ||
| impl Deref for Complex { | ||
| type Target = Complex32; | ||
|
|
||
| fn deref(&self) -> &Self::Target { | ||
| &self.0 | ||
| } | ||
| } | ||
| impl DerefMut for Complex { | ||
| fn deref_mut(&mut self) -> &mut Self::Target { | ||
| &mut self.0 | ||
| } | ||
| } | ||
| const fn c1() -> Complex { | ||
| Complex(Complex32 { re: 1.0, im: 0.0 }) | ||
| } | ||
| impl Complex { | ||
| pub const ONE: Complex = c1(); | ||
| pub fn new(r: f32, i: f32) -> Self { | ||
| Self(num_complex::Complex { re: r, im: i }) | ||
| } | ||
| } | ||
| impl FromPrimitive for Complex { | ||
| fn from_i64(n: i64) -> Option<Self> { | ||
| Some(Complex(Complex32::from_i64(n)?)) | ||
| } | ||
|
|
||
| fn from_u64(n: u64) -> Option<Self> { | ||
| Some(Complex(Complex32::from_u64(n)?)) | ||
| } | ||
| } | ||
| impl ToPrimitive for Complex { | ||
| fn to_i64(&self) -> Option<i64> { | ||
| self.0.to_i64() | ||
| } | ||
|
|
||
| fn to_u64(&self) -> Option<u64> { | ||
| self.0.to_u64() | ||
| } | ||
| } | ||
|
|
||
| impl std::ops::Add<Self> for Complex { | ||
| type Output = Self; | ||
| fn add(self, rhs: Self) -> Self::Output { | ||
| Self(self.0 + rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::Sub<Self> for Complex { | ||
| type Output = Self; | ||
|
|
||
| fn sub(self, rhs: Self) -> Self::Output { | ||
| Self(self.0 - rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::Mul<Self> for Complex { | ||
| type Output = Self; | ||
|
|
||
| fn mul(self, rhs: Self) -> Self::Output { | ||
| Self(self.0 * rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::Div<Self> for Complex { | ||
| type Output = Self; | ||
|
|
||
| fn div(self, rhs: Self) -> Self::Output { | ||
| Self(self.0 / rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::AddAssign for Complex { | ||
| fn add_assign(&mut self, rhs: Self) { | ||
| self.0.add_assign(rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::SubAssign for Complex { | ||
| fn sub_assign(&mut self, rhs: Self) { | ||
| self.0.sub_assign(rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::MulAssign for Complex { | ||
| fn mul_assign(&mut self, rhs: Self) { | ||
| self.0.mul_assign(rhs.0) | ||
| } | ||
| } | ||
| impl std::ops::DivAssign for Complex { | ||
| fn div_assign(&mut self, rhs: Self) { | ||
| self.0.div_assign(rhs.0) | ||
| } | ||
| } | ||
| #[cfg(feature = "cuda")] | ||
| unsafe impl ValidAsZeroBits for Complex {} | ||
| #[cfg(feature = "cuda")] | ||
| unsafe impl DeviceRepr for Complex {} | ||
|
Comment on lines
+115
to
+118
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can add these to cudarc in a PR there behind a feature flag, that should allow us to not need the wrapper type, right?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was unaware that you controlled cudarc. That would work yes. |
||
| } | ||
|
|
||
| /// Represents a type where all 0 bits is a valid pattern. | ||
| #[cfg(not(feature = "cuda"))] | ||
| pub trait SafeZeros {} | ||
|
|
@@ -30,7 +134,7 @@ pub trait Unit: | |
| + Default | ||
| + std::fmt::Debug | ||
| + PartialEq | ||
| + PartialOrd | ||
| // + PartialOrd | ||
| + Send | ||
| + Sync | ||
| + std::marker::Unpin | ||
|
|
@@ -65,6 +169,8 @@ unit!(i128, 1); | |
| unit!(bool, true); | ||
| #[cfg(feature = "f16")] | ||
| unit!(f16, f16::ONE); | ||
| #[cfg(feature = "complex")] | ||
| unit!(complex::Complex, complex::Complex::ONE); | ||
|
|
||
| /// Represents something that has a [Unit]. | ||
| pub trait HasUnitType { | ||
|
|
@@ -105,6 +211,8 @@ impl Dtype for u128 {} | |
| impl Dtype for usize {} | ||
| #[cfg(feature = "f16")] | ||
| impl Dtype for f16 {} | ||
| #[cfg(feature = "complex")] | ||
| impl Dtype for complex::Complex {} | ||
|
|
||
| /// Represents something that has a [Dtype]. | ||
| pub trait HasDtype { | ||
|
|
@@ -129,3 +237,5 @@ impl NotMixedPrecision for u128 {} | |
| impl NotMixedPrecision for usize {} | ||
| #[cfg(feature = "f16")] | ||
| impl NotMixedPrecision for f16 {} | ||
| #[cfg(feature = "complex")] | ||
| impl NotMixedPrecision for complex::Complex {} | ||
Uh oh!
There was an error while loading. Please reload this page.