Commit bee5064

[Wave] Make APLP faster (iree-org#840)
This PR adds an implementation of APLP for software pipelining in Rust to speed up the scheduling computation. --------- Signed-off-by: Harsh Menon <[email protected]>
1 parent fff9966 commit bee5064

12 files changed: 614 additions, 23 deletions


.github/workflows/ci-tk.yaml

Lines changed: 5 additions & 0 deletions
@@ -50,6 +50,11 @@ jobs:
           echo VIRTUAL_ENV=$VIRTUAL_ENV >> "$GITHUB_ENV"
           echo "$VENV_DIR/bin" >> "$GITHUB_PATH"
 
+      - name: "Setting up Rust"
+        uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: stable
+
       - name: Install pip deps
         if: "(!contains(toJSON(matrix.os), 'amdgpu') && !contains(toJSON(matrix.os), 'mi300')) && !cancelled()"
         run: |

.github/workflows/test_build_release.yml

Lines changed: 5 additions & 0 deletions
@@ -39,6 +39,11 @@ jobs:
         with:
           python-version: ${{ matrix.version }}
 
+      - name: "Setting up Rust"
+        uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: stable
+
       - name: Install dependencies
         run: pip install -r ./build_tools/requirements-packaging.txt
 

build_tools/post_build_release_test.sh

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ source "${WHEELHOUSE_DIR}"/test.venv/bin/activate
 # --no-index is required so that we don't pick up different versions from pypi
 pip install --no-index -f "${WHEELHOUSE_DIR}" iree-turbine[testing]
 pip install --no-index -f "${WHEELHOUSE_DIR}" torchvision
+# Install local packages
+pip install ${REPO_ROOT}/iree/turbine/kernel/wave/scheduling/aplp
 pip freeze
 
 # Run tests

docs/wave/aplp.rst

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
+.. default-role:: code
+
+All Pairs Longest Paths for Software Pipelining
+=========================================================
+
+This document explains the Rust library designed for All-Pairs Longest Path (APLP) computation, specifically tailored for software pipelining. The library consists of two main modules: `prune.rs` for optimizing path representations and `lib.rs` for the core APLP logic and Python FFI (Foreign Function Interface) using PyO3.
+
+The "paths" are represented as pairs `(delay, iter_diff)`, which define a line :math:`L(S) = \text{delay} - \text{iter_diff} \cdot S`, where :math:`S` is the symbolic initiation interval. The goal of pruning is to find the upper envelope of these lines.
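To make the line model concrete, here is a minimal Python sketch (not part of the Rust library) that evaluates :math:`L(S)` for a few made-up paths and picks the longest one at a given `S`. The helper names `path_length` and `longest` and the sample values are hypothetical.

```python
def path_length(delay, iter_diff, s):
    """Evaluate L(S) = delay - iter_diff * S for one (delay, iter_diff) path."""
    return delay - iter_diff * s


def longest(paths, s):
    """Value of the upper envelope at S: the maximum over all candidate paths."""
    return max(path_length(delay, iter_diff, s) for delay, iter_diff in paths)


# Hypothetical candidate paths for a single pair of nodes.
paths = [(10.0, 2.0), (6.0, 1.0), (0.0, 0.0)]

for s in (1, 5, 8):
    print(s, longest(paths, s))
```

Each of the three sample paths is the longest for some value of `S`, so all three belong to the upper envelope.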
+
+prune.rs: Path Pruning Logic
+-------------------------------
+
+The `prune.rs` module contains the `prune_envelope` function, which is responsible for taking a list of candidate paths (each represented as a `(delay: f32, iter_diff: f32)` tuple) for a single pair of nodes and returning a minimal list of paths that form the upper envelope. This means that for any given initiation interval `S`, one of the paths in the pruned list will provide the maximum (longest) path value.
+
+**Purpose:**
+To reduce the number of paths that need to be considered for each pair of nodes by eliminating paths that are "dominated" by others across all possible non-negative values of the initiation interval `S`.
+
+**Algorithm (`prune_envelope`):**
+The function implements a variation of Andrew's monotone chain algorithm, which is commonly used for finding convex hulls. The steps are:
+
+1. **Handle Empty Input:** If the input list of paths is empty, return an empty list.
+2. **Initial Filter & Sort:**
+
+   * Filter out exact duplicate paths (considering floating point precision).
+   * Sort paths primarily by delay (descending) and then by iteration difference (ascending), so that later steps see a canonical order.
+3. **Separate Path Types:**
+
+   * Paths with `-infinity` delay are handled separately. If only such paths exist, those with the numerically smallest `iter_diff` are kept (as :math:`L(S) = -\infty - \text{iter_diff} \cdot S` means a smaller `iter_diff` is "less negative" or "longer" when :math:`S > 0`).
+   * Paths with finite delays are processed further. If no finite paths exist, the result from infinite paths is returned.
+4. **Transform to Lines:** Finite paths `(delay, iter_diff)` are transformed into lines represented as `(slope, intercept)`, where:
+
+   * `slope (m) = -iter_diff`
+   * `intercept (c) = delay`
+   So, the line equation becomes :math:`L(S) = c + m \cdot S`.
+5. **Filter Unique Slopes:** For lines with the same slope, only the one with the highest intercept is kept, as it will always be above or equal to others with the same slope. The lines are then sorted by slope `m` in ascending order.
+6. **Build Upper Envelope (Monotone Chain Scan):** This is the core adaptation of Andrew's algorithm.
+
+   * Iterate through the sorted lines `(m, c)`.
+   * Maintain a list (`envelope_lines_mc`) of lines currently forming the upper envelope.
+   * For each `current_line`, check whether adding it to `envelope_lines_mc` would make the last line in the envelope redundant. This is done by checking the "turn" direction formed by three lines: the last two in the envelope and the current one. If they don't form a "right turn" (i.e., the middle line never rises strictly above the other two, so it is not part of the convex upper envelope), the last line is popped from `envelope_lines_mc`. The check repeats until the turn condition holds or fewer than two lines remain.
+   * The `current_line` is then added to `envelope_lines_mc`.
+7. **Convert Back:** The lines in `envelope_lines_mc` are converted back from `(slope, intercept)` to `(delay, iter_diff)` format. `iter_diff` is rounded as it typically represents an integer count.
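The seven steps above can be sketched in Python as follows. This is an illustrative restatement, not the Rust code in `prune.rs`: the use of exact equality for duplicate detection, the choice to return a single entry in the `-inf`-only case, and dropping `-inf` paths whenever a finite path exists are assumptions made for brevity.

```python
import math


def prune_envelope(paths):
    """Return the (delay, iter_diff) paths on the upper envelope (sketch)."""
    if not paths:                                   # step 1: empty input
        return []
    unique = set(paths)                             # step 2: drop exact duplicates
    finite = [(d, i) for d, i in unique if d != -math.inf]
    if not finite:                                  # step 3: only -inf paths exist
        return [(-math.inf, min(i for _, i in unique))]
    # Steps 4-5: lines (m, c) = (-iter_diff, delay), max intercept per slope.
    best_by_slope = {}
    for delay, iter_diff in finite:
        m = -iter_diff
        if m not in best_by_slope or delay > best_by_slope[m]:
            best_by_slope[m] = delay
    lines = sorted(best_by_slope.items())           # ascending slope
    # Step 6: monotone chain scan, popping the last line while it is redundant.
    envelope = []
    for m, c in lines:
        while len(envelope) >= 2:
            m1, c1 = envelope[-2]
            m2, c2 = envelope[-1]
            if (c1 - c2) * (m - m2) >= (c2 - c) * (m2 - m1):
                envelope.pop()
            else:
                break
        envelope.append((m, c))
    # Step 7: back to (delay, iter_diff), rounding iter_diff.
    return [(c, float(round(-m))) for m, c in envelope]


print(prune_envelope([(10.0, 2.0), (4.0, 1.0), (0.0, 0.0)]))
```

In the printed example the path `(4.0, 1.0)` is dropped because it lies below the other two lines for every `S`.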
+
+**Detailed Explanation: Andrew's Monotone Chain for Upper Envelope**
+
+Andrew's monotone chain algorithm is typically used to find the convex hull of a set of 2D points. It works by first sorting the points (usually by x-coordinate, then y-coordinate) and then constructing the upper and lower hulls in separate passes. For our purpose of finding the upper envelope of lines :math:`L(S) = c + mS` (where :math:`c=\text{delay}`, :math:`m=-\text{iter_diff}`), we adapt this:
+
+1. **Point Representation:** We consider the lines in their dual form or by their parameters. In our case, we sort lines by their slopes `m` (which is `-iter_diff`). If slopes are equal, we only keep the line with the highest intercept `c` (delay), as it dominates others with the same slope.
+2. **Monotonicity:** The algorithm relies on processing points (or in our case, lines sorted by slope) in a specific order.
+3. **Building the Hull (Upper Envelope):**
+
+   * We iterate through the unique-slope lines, sorted by slope `m`.
+   * We maintain a candidate list for the upper envelope (e.g., `envelope_lines_mc`).
+   * When considering adding a new line (`current_line`) to the envelope:
+
+     * Let the last two lines in the envelope be `L1` (second to last) and `L2` (last).
+     * We check if the sequence `L1, L2, current_line` maintains the convexity required for an upper envelope. For an upper envelope of lines :math:`y = mx + c` where lines are sorted by increasing slope :math:`m`, we need the intersection point of :math:`(L1, L2)` to be to the left of the intersection point of :math:`(L2, \text{current_line})`.
+     * This is equivalent to checking the "turn" direction. If adding `current_line` causes a "non-right turn" (i.e., a left turn or collinearity that makes `L2` redundant), `L2` is removed from the envelope. This check is repeated until the condition is met or the envelope has fewer than two lines.
+     * The geometric check can be performed using a cross-product like condition without explicitly calculating intersection points:
+       For lines :math:`L_1=(m_1, c_1)`, :math:`L_2=(m_2, c_2)`, and :math:`L_3=(m_3, c_3)` with :math:`m_1 < m_2 < m_3`:
+       :math:`L_2` is part of the upper envelope if the intersection of :math:`L_1, L_2` occurs at an :math:`S`-value less than the intersection of :math:`L_2, L_3`.
+       The intersection :math:`S`-value for :math:`L_a, L_b` is :math:`S_{ab} = (c_b - c_a) / (m_a - m_b)`.
+       So, we need :math:`(c_2 - c_1) / (m_1 - m_2) < (c_3 - c_2) / (m_2 - m_3)`.
+       Rearranging to avoid division: :math:`m_1 < m_2 < m_3` implies :math:`m_1-m_2 < 0` and :math:`m_2-m_3 < 0`, so multiplying both sides by their (positive) product preserves the direction of the inequality:
+       :math:`(c_2 - c_1)(m_2 - m_3) < (c_3 - c_2)(m_1 - m_2)` (for strict convexity).
+       The implementation uses `(c1 - c2)*(current_m - m2) >= (c2 - current_c)*(m2 - m1)` to pop `L2` if it's redundant. With `current_line` as :math:`L_3`, this is exactly the negation of the strict condition above: it holds when `L2` lies below the envelope formed by `L1` and `current_line`, or passes through their intersection so that it is non-essential for the upper envelope.
+   * After the check, `current_line` is added to the envelope.
+4. **Result:** The final list `envelope_lines_mc` contains the lines that constitute the upper envelope.
+
+This process ensures that only the lines that are maximal for some range of `S` are kept. The overall time complexity is dominated by the initial sort, making it :math:`O(N \log N)` where :math:`N` is the number of initial lines.
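As a worked check of the redundancy test, the sketch below compares the intersection-based view with the division-free pop condition quoted from the implementation. The function names `intersection_s` and `is_redundant` and the sample lines are hypothetical; lines are `(m, c)` pairs with strictly increasing slopes.

```python
def intersection_s(line_a, line_b):
    """S-value where c_a + m_a*S equals c_b + m_b*S (slopes must differ)."""
    (m_a, c_a), (m_b, c_b) = line_a, line_b
    return (c_b - c_a) / (m_a - m_b)


def is_redundant(l1, l2, l3):
    """Division-free pop condition for the middle line l2, with m1 < m2 < m3."""
    (m1, c1), (m2, c2), (m3, c3) = l1, l2, l3
    return (c1 - c2) * (m3 - m2) >= (c2 - c3) * (m2 - m1)


l1, l3 = (-2.0, 10.0), (0.0, 0.0)

# Middle line that is maximal on 4 < S < 6: intersections are in order.
keep = (-1.0, 6.0)
print(intersection_s(l1, keep), intersection_s(keep, l3), is_redundant(l1, keep, l3))

# Middle line that is never maximal: intersections are out of order.
drop = (-1.0, 4.0)
print(intersection_s(l1, drop), intersection_s(drop, l3), is_redundant(l1, drop, l3))
```

For `keep`, the intersections fall at `S = 4` and `S = 6`, so the line stays. For `drop`, they fall at `S = 6` and `S = 4`, so the pop condition fires.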
+
+**Visualization of Upper Envelope Construction (Monotone Chain Idea):**
+
+.. mermaid::
+
+    graph TD
+        Start["Start with lines sorted by slope m: L1, L2, L3, ..."] --> P1["Initialize Envelope_List = []"];
+        P1 --> ForEach["For each Line_current (Lc) in sorted lines:"];
+        ForEach --> CheckSize{"len(Envelope_List) < 2?"};
+        CheckSize -- Yes --> AddLc1["Add Lc to Envelope_List"];
+        AddLc1 --> ForEach;
+        CheckSize -- No --> GetPrevLines["L2 = Envelope_List.last()\nL1 = Envelope_List.second_last()"];
+        GetPrevLines --> TurnCheck{"Is L1-L2-Lc a 'right turn' (maintains upper convexity)?"};
+        TurnCheck -- No (L2 is redundant) --> PopL2["Pop L2 from Envelope_List"];
+        PopL2 --> CheckSize2{"len(Envelope_List) < 2?"};
+        CheckSize2 -- Yes --> AddLc2["Add Lc to Envelope_List"];
+        AddLc2 --> ForEach;
+        CheckSize2 -- No --> GetPrevLines;
+        TurnCheck -- Yes --> AddLc3["Add Lc to Envelope_List"];
+        AddLc3 --> ForEach;
+        ForEach -- All lines processed --> End["End: Envelope_List contains the upper envelope lines"];
+
+**References for Convex Hull Algorithms:**
+
+* A.M. Andrew, "Another efficient algorithm for convex hulls in two dimensions", Info. Proc. Letters 9, 216-219 (1979).
+* Joseph O'Rourke, "Computational Geometry in C", 2nd Edition, Cambridge University Press (1998). (Chapter on Convex Hulls)
+* Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein, "Introduction to Algorithms", 3rd Edition, MIT Press (2009). (Chapter 33: Computational Geometry)
+
+**Mermaid Diagram for `prune_envelope` (Overall Flow):**
+
+.. mermaid::
+
+    graph TD
+        A["Input: List of (delay, iter_diff) paths"] --> B{Handle -INF paths};
+        B -- Finite Paths --> C["Transform to lines: (m=-iter_diff, c=delay)"];
+        B -- Only -INF Paths --> D["Keep paths with min iter_diff"];
+        C --> E["Filter unique slopes, keeping max intercept"];
+        E --> F["Sort lines by slope 'm'"];
+        F --> G["Build upper envelope (Monotone Chain Scan - see detailed diagram above)"];
+        G --> H["Convert envelope lines back to (delay, iter_diff)"];
+        H --> Z["Output: Pruned list of paths"];
+        D --> Z;
+
+lib.rs: APLP Computation and Python Interface
+---------------------------------------------
+
+The `lib.rs` module orchestrates the All-Pairs Longest Path computation and exposes the functionality to Python using PyO3.
+
+**Data Structures:**
+
+* **`PyRawEdge(u32, u32, f32, f32)`:** A Rust tuple struct that maps directly to Python tuples `(from_node_idx, to_node_idx, delay, iter_diff)` passed from Python. It uses `#[derive(FromPyObject)]` for automatic conversion.
+* **Internal Path Representation:** Within Rust, paths for each pair of nodes `(u,v)` are stored as `Vec<(f32, f32)>`, representing the list of `(delay, iter_diff)` tuples that form the upper envelope for that pair.
+
+**Core Logic (`compute_aplp_internal`):**
+
+This function implements the Floyd-Warshall algorithm to compute APLP.
+
+1. **Initialization:**
+   * A node-by-node matrix `d_current_vec_vec` (three levels of nested `Vec`) is initialized. Each element `d_current_vec_vec[u_idx][v_idx]` stores a `Vec<(f32, f32)>` representing the pruned paths from node `u` to node `v`.
+   * For self-paths: `d_current_vec_vec[i][i]` is initialized to `[(0.0, 0.0)]` (a zero-delay, zero-iteration-difference path from a node to itself) after pruning.
+   * For direct edges `(u,v)` from the input `raw_edges`: the tuple `(edge.delay, edge.iter_diff)` is added to the list in `d_current_vec_vec[u_idx][v_idx]`, which is then pruned.
+2. **Floyd-Warshall Iteration:**
+   * The algorithm iterates `k_idx` from `0` to `node_count - 1` (representing the intermediate node).
+   * **Parallelization (Rayon):** For each `k_idx`, the computation of rows `i_idx` is parallelized using Rayon's `into_par_iter()`.
+     * An `Arc` (Atomically Reference Counted pointer), `d_current_arc`, is used to safely share the `d_current_vec_vec` matrix (from the previous `k` iteration) among worker threads.
+     * Each worker thread processes one or more rows `i_idx`.
+   * **Inner Loops (Worker Thread):** For each pair of nodes `(i_idx, j_idx)`:
+     * It considers paths from `i_idx` to `k_idx` and from `k_idx` to `j_idx`.
+     * If such sub-paths exist, they are combined:
+       `new_delay = d_ik + d_kj`
+       `new_iter_diff = id_ik + id_kj`
+     * These newly formed paths are added to the existing list of paths for `(i_idx, j_idx)`.
+     * The combined list is then pruned using `prune::prune_envelope`.
+     * The result is stored in a `d_next_rows` structure.
+   * After all rows `i_idx` are processed for the current `k_idx`, `d_current_arc` is updated to point to the newly computed matrix (from `d_next_rows`).
+3. **Result:** After all `k_idx` iterations, the final `d_current_arc` contains the APLP results.
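The iteration above can be sketched sequentially in Python. This is a simplified illustration under stated assumptions, not the Rust implementation: it omits the Rayon parallelism and the `Arc` sharing, handles only finite delays, and uses a compact stand-in `prune` in place of `prune::prune_envelope`. The names `prune` and `compute_aplp` and the sample edges are hypothetical.

```python
def prune(paths):
    """Compact stand-in for prune_envelope (finite delays only)."""
    best = {}
    for delay, iter_diff in paths:
        m = -iter_diff
        if m not in best or delay > best[m]:
            best[m] = delay
    env = []
    for m, c in sorted(best.items()):
        while len(env) >= 2:
            (m1, c1), (m2, c2) = env[-2], env[-1]
            if (c1 - c2) * (m - m2) >= (c2 - c) * (m2 - m1):
                env.pop()
            else:
                break
        env.append((m, c))
    return [(c, -m) for m, c in env]


def compute_aplp(node_count, edges):
    """Floyd-Warshall over lists of (delay, iter_diff) paths."""
    d = [[[] for _ in range(node_count)] for _ in range(node_count)]
    for i in range(node_count):
        d[i][i] = [(0.0, 0.0)]                      # self-path
    for u, v, delay, iter_diff in edges:
        d[u][v].append((delay, iter_diff))          # direct edges
    d = [[prune(cell) for cell in row] for row in d]
    for k in range(node_count):                     # intermediate node
        nxt = []
        for i in range(node_count):                 # rows (parallel in Rust)
            row = []
            for j in range(node_count):
                candidates = list(d[i][j])
                for d_ik, id_ik in d[i][k]:
                    for d_kj, id_kj in d[k][j]:
                        candidates.append((d_ik + d_kj, id_ik + id_kj))
                row.append(prune(candidates))
            nxt.append(row)
        d = nxt                                     # next iteration reads the new matrix
    return d


# Three-node cycle; the back edge 2 -> 0 crosses one iteration.
edges = [(0, 1, 2.0, 0.0), (1, 2, 3.0, 0.0), (2, 0, 1.0, 1.0)]
result = compute_aplp(3, edges)
print(result[0][2])
```

Building each `k` step from the previous matrix mirrors the `d_current` / `d_next` split described above, which is what lets the rows be computed independently.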
+
+**FFI Function (`perform_aplp_pyo3`):**
+
+This function is exposed to Python using the `#[pyfunction]` macro.
+
+1. **Input:** Takes `node_count: usize` and `edges_pylist: &PyList` (a Python list of edge tuples) as input.
+2. **Conversion:** Converts the Python list of edge tuples into a `Vec<PyRawEdge>` using `extract()` which leverages the `FromPyObject` derive on `PyRawEdge`.
+3. **Computation:** Calls `compute_aplp_internal` to perform the APLP. The `py.allow_threads(|| ...)` block releases the Python Global Interpreter Lock (GIL) during the potentially long computation, allowing Rust's Rayon parallelism to be effective.
+4. **Output Conversion:** Converts the resulting Rust matrix `Vec<Vec<Vec<(f32,f32)>>>` into a Python dictionary.
+   * The dictionary keys are Python tuples `(u_idx, v_idx)`.
+   * The dictionary values are Python lists of Python tuples `[(delay, iter_diff), ...]`.
+5. **Return:** Returns the Python dictionary to the Python caller.
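The shape of the returned dictionary can be illustrated with a small pure-Python stand-in for the output conversion step. The helper name `matrix_to_dict` is hypothetical, and skipping node pairs with no path is an assumption: the text above does not say whether the Rust code emits empty lists for them.

```python
def matrix_to_dict(matrix):
    """Convert matrix[u][v] = [(delay, iter_diff), ...] into {(u, v): [...]}."""
    result = {}
    for u, row in enumerate(matrix):
        for v, paths in enumerate(row):
            if paths:  # assumption: pairs with no path are left out
                result[(u, v)] = [tuple(p) for p in paths]
    return result


# Hypothetical 2-node result: one edge 0 -> 1 plus the self-paths.
matrix = [
    [[(0.0, 0.0)], [(2.0, 1.0)]],
    [[], [(0.0, 0.0)]],
]
print(matrix_to_dict(matrix))
```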
+
+**Python Module Definition (`aplp_rs_lib`):**
+The `#[pymodule]` macro defines the Python module. The `perform_aplp_pyo3` function is added to this module, making it callable from Python as `aplp_rs_lib.perform_aplp_pyo3(...)`.
+
+**Mermaid Diagram for APLP Computation Flow:**
+
+.. mermaid::
+
+    graph TD
+        subgraph PythonSide [Python Caller]
+            PyInput["Input: node_count, list_of_edge_tuples"]
+        end
+
+        subgraph RustFFI [Rust: perform_aplp_pyo3]
+            direction LR
+            ConvertPyInput["Convert Python list of edge tuples to Vec<PyRawEdge>"]
+            CallInternal["Call compute_aplp_internal(node_count, rust_edges)"]
+            ConvertRustOutput["Convert Rust D_matrix to Python Dict"]
+        end
+
+        subgraph RustInternalCompute [Rust: compute_aplp_internal]
+            direction TB
+            InitD["Initialize D matrix: D[i][i] = [(0,0)], direct edges + prune"]
+            LoopK["Loop k from 0 to V-1 (Intermediate Node)"]
+            subgraph ParallelForRowI [For each k: Parallelize 'i' Loop]
+                direction TB
+                MapI["map_with(d_current_arc, i_idx)"]
+                subgraph WorkerForRowI ["Worker for row 'i'"]
+                    direction TB
+                    LoopJ["Loop j from 0 to V-1 (Destination Node)"]
+                    CombinePaths["Combine D[i][k] + D[k][j]"]
+                    AddToExisting["Add to existing D[i][j] paths"]
+                    Prune["Call prune_envelope()"]
+                    StoreResult["Store pruned D_next[i][j]"]
+                end
+                MapI --> WorkerForRowI
+            end
+            InitD --> LoopK
+            LoopK --> ParallelForRowI
+            ParallelForRowI --> CollectResults["Collect results into D_next matrix"]
+            CollectResults --> UpdateD["Update D_current = D_next"]
+            UpdateD --> LoopK
+            LoopK -- After all k --> FinalDMatrix["Final D_matrix"]
+        end
+
+        subgraph FinalDMatrix
+            PyOutput["Output: Python Dict {(u,v): [(d,id), ...]}"]
+        end
+
+        PyInput --> RustFFI;
+        ConvertPyInput --> CallInternal;
+        CallInternal --> RustInternalCompute;

docs/wave/getting_started.rst

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+Getting Started with Wave
+=========================
+
+This guide will help you get up and running with Wave, a high-performance machine learning programming language designed for accelerating ML kernel development.
+
+Prerequisites
+-------------
+
+Before installing Wave, ensure you have the following prerequisites:
+
+1. Python 3.10 or later
+2. PyTorch
+3. ROCm (for AMD GPU support)
+4. A compatible AMD GPU with ROCm support (MI250, MI300, etc.)
+5. Rust 1.70 or later
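As a quick sanity check before installing, something like the following Python sketch can confirm two of the prerequisites listed above. It is not part of Wave; the function name `check_prerequisites` is hypothetical, and it only checks that `rustc` is on `PATH`, not that its version is 1.70 or later.

```python
import shutil
import sys


def check_prerequisites(min_python=(3, 10)):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if sys.version_info < min_python:
        wanted = ".".join(str(part) for part in min_python)
        problems.append(f"Python {wanted} or later is required")
    if shutil.which("rustc") is None:
        problems.append("rustc not found on PATH (install Rust via rustup)")
    return problems


for problem in check_prerequisites():
    print("missing prerequisite:", problem)
```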
+
+Installation
+------------
+
+1. Install PyTorch with ROCm support:
+
+   .. code-block:: bash
+
+      pip install -r pytorch-rocm-requirements.txt
+
+2. Install Rust:
+
+   .. code-block:: bash
+
+      curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+
+3. Install Wave and its dependencies:
+
+   .. code-block:: bash
+
+      pip install -r requirements.txt
+      pip install -r requirements-wave-runtime.txt
+
+
+Next Steps
+----------
+
+- Read the :doc:`system_architecture` guide to understand Wave's compilation pipeline
+- Check out the :doc:`gemm_tutorial` for a more complex example
+- Explore :doc:`shared_memory` for optimization techniques
+- Learn about the :doc:`runtime` for advanced usage
+
+For more detailed information about Wave's architecture and optimization passes, see the :doc:`system_architecture` documentation.

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+[package]
+name = "aplp"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "aplp_lib"
+crate-type = ["cdylib"] # Compile to a C-compatible dynamic library
+
+[dependencies]
+rayon = "1.5" # For parallelism
+pyo3 = { version = "0.21", features = ["extension-module"] }

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+[build-system]
+requires = ["maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[project]
+name = "aplp"
+version = "0.1.0"
+description = "A Python package with a Rust extension for APLP computation."
+requires-python = ">=3.7"
