Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/rust/simd/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "simd"
version = "0.1.0"
edition = "2018"

[dependencies]
wit-bindgen-rust = { git = "https://github.com/bytecodealliance/wit-bindgen.git" }

[lib]
crate-type = ["cdylib"]
25 changes: 25 additions & 0 deletions examples/rust/simd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Simd Rust

A simple simd example using [core::arch::wasm32](https://doc.rust-lang.org/core/arch/wasm32/index.html#simd)

## Build

```sh
cargo build --target wasm32-unknown-unknown
```

## Create Function

```sql
CREATE FUNCTION mul AS WASM FROM INFILE 'target/wasm32-unknown-unknown/debug/simd.wasm' WITH WIT FROM INFILE 'simd.wit'
CREATE FUNCTION dot AS WASM FROM INFILE 'target/wasm32-unknown-unknown/debug/simd.wasm' WITH WIT FROM INFILE 'simd.wit'
CREATE FUNCTION `inner` RETURNS TABLE AS WASM FROM INFILE 'target/wasm32-unknown-unknown/debug/simd.wasm' WITH WIT FROM INFILE 'simd.wit'
```

## Example Queries

```sql
SELECT mul(3,4);
SELECT dot([1,2,3], [0,5,6]);
SELECT * FROM `inner`([1,2,3], [3,4,5]);
```
4 changes: 4 additions & 0 deletions examples/rust/simd/simd.wit
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mul: func(a: u64, b: u64) -> u64
dot: func(a: list<u64>, b: list<u64>) -> u64
inner: func(a: list<u64>, b: list<u64>) -> list<u64>
mmul: func(a: list<list<u64>>, b: list<list<u64>>) -> list<list<u64>>
47 changes: 47 additions & 0 deletions examples/rust/simd/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#[cfg(target_arch = "wasm32")]
wit_bindgen_rust::export!("simd.wit");

struct Simd;

use core::arch::wasm32::*;

impl simd::Simd for Simd {
fn mul(a: u64, b: u64) -> u64 {
let va: v128 = u64x2_splat(a);
let vb: v128 = u64x2_splat(b);
let c = u64x2_extract_lane::<1>(i64x2_mul(va, vb));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is technically a scalar multiplication - it fills all lanes with the same value and then extracts just one value out of the result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for your review. I wasn't familiar with simd and made a mistake. Would appreciate feedbacks on the new code.

c
}
fn dot(a: Vec<u64>, b: Vec<u64>) -> u64 {
assert!(a.len() == b.len());
let mut sum: u64 = 0;
for i in 0..a.len() {
sum += Self::mul(a[i], b[i]);
}
sum
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dot product is the smallest unit of work in matrix multiplication that can be implemented in SIMD, it usually works by taking N worth of elements from the first array and second array, multiplying them via SIMD, then adding N results to the intermediate vector sum (N is number of lanes). Intermediate sum is the added up at the end, also for input sizes not divisible by N the remainder needs to be calculated manually.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Please let me know if anything I could do better. I assume floating point implementation should be similar (please let me know if it isn't) so I will update floating point examples once this is ok :)

fn inner(a: Vec<u64>, b: Vec<u64>) -> Vec<u64> {
assert!(a.len() == b.len());
let mut res = vec![0; a.len()];
for i in 0..a.len() {
res[i] = Self::mul(a[i], b[i]);
}
res
}
fn mmul(a: Vec<Vec<u64>>, b: Vec<Vec<u64>>) -> Vec<Vec<u64>> {
if a.len() == 0 && b.len() == 0 {
return Vec::with_capacity(0);
}
assert!(a[0].len() == b.len());

let mut res = vec![vec![0; a.len()]; b[0].len()];
for i in 0..a.len() {
for j in 0..b[0].len() {
for k in 0..b.len() {
res[i][j] += Self::mul(a[i][k], b[k][j]);
}
}
}
res
}
}