-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add selective_scan compilable/exportable custom_op #651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bhack
wants to merge
5
commits into
state-spaces:main
Choose a base branch
from
bhack:selective_scan_custom_op
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,272 @@ | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from einops import rearrange | ||
| from typing import Optional, Tuple | ||
|
|
||
| import selective_scan_cuda | ||
|
|
||
|
|
||
| @torch.library.custom_op( | ||
| "custom_ops::selective_scan_fwd", | ||
| device_types=["cuda"], | ||
| mutates_args=(), | ||
| ) | ||
| def custom_selective_scan_fwd( | ||
| u: torch.Tensor, | ||
| delta: torch.Tensor, | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
| C: torch.Tensor, | ||
| D: Optional[torch.Tensor], | ||
| z: Optional[torch.Tensor], | ||
| delta_bias: Optional[torch.Tensor], | ||
| delta_softplus: bool, | ||
| return_last_state: bool, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool, bool]: | ||
| pass | ||
|
|
||
| @torch.library.register_fake("custom_ops::selective_scan_fwd") | ||
| def custom_selective_scan_fwd_fake( | ||
| u, | ||
| delta, | ||
| A, | ||
| B, | ||
| C, | ||
| D, | ||
| z, | ||
| delta_bias, | ||
| delta_softplus, | ||
| return_last_state, | ||
| ): | ||
| final_out = torch.empty_like(u) | ||
| dstate = A.size(1) * (2 if A.is_complex() else 1) | ||
| last_state_fake = u.new_empty((u.size(0), u.size(1), dstate)) if return_last_state else u.new_empty(0) | ||
| out_fake = torch.empty_like(u) | ||
| x_fake = u.new_empty((u.size(0), u.size(1), u.size(2), 2 * dstate)) | ||
| return final_out, last_state_fake, out_fake, x_fake, False, False, z is not None | ||
|
|
||
| @torch.library.register_kernel("custom_ops::selective_scan_fwd", "cuda") | ||
| def custom_selective_scan_fwd_cuda( | ||
| u: torch.Tensor, | ||
| delta: torch.Tensor, | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
| C: torch.Tensor, | ||
| D: Optional[torch.Tensor], | ||
| z: Optional[torch.Tensor], | ||
| delta_bias: Optional[torch.Tensor], | ||
| delta_softplus: bool, | ||
| return_last_state: bool, | ||
| ): | ||
| if u.stride(-1) != 1: | ||
| u = u.contiguous() | ||
| if delta.stride(-1) != 1: | ||
| delta = delta.contiguous() | ||
| if D is not None: | ||
| D = D.contiguous() | ||
| if B.stride(-1) != 1: | ||
| B = B.contiguous() | ||
| if C.stride(-1) != 1: | ||
| C = C.contiguous() | ||
| if z is not None and z.stride(-1) != 1: | ||
| z = z.contiguous() | ||
|
|
||
| squeeze_B = False | ||
| if B.dim() == 3: | ||
| B = rearrange(B, "b dstate l -> b 1 dstate l").contiguous() | ||
| squeeze_B = True | ||
|
|
||
| squeeze_C = False | ||
| if C.dim() == 3: | ||
| C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous() | ||
| squeeze_C = True | ||
|
|
||
| out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) | ||
| has_z = z is not None | ||
| final_out = rest[0].clone() if has_z else out.clone() | ||
| last_state = x[:, :, -1, 1::2].clone() if return_last_state else u.new_empty(0) | ||
| return final_out, last_state, out, x, squeeze_B, squeeze_C, has_z | ||
|
|
||
| @torch.library.custom_op( | ||
| "custom_ops::selective_scan_bwd", | ||
| device_types=["cuda"], | ||
| mutates_args=(), | ||
| ) | ||
| def custom_selective_scan_bwd( | ||
| dout: torch.Tensor, | ||
| u: torch.Tensor, | ||
| delta: torch.Tensor, | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
| C: torch.Tensor, | ||
| D: Optional[torch.Tensor], | ||
| z: Optional[torch.Tensor], | ||
| delta_bias: Optional[torch.Tensor], | ||
| delta_softplus: bool, | ||
| out: torch.Tensor, | ||
| x: torch.Tensor, | ||
| squeeze_B: bool, | ||
| squeeze_C: bool, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| pass | ||
|
|
||
| @torch.library.register_fake("custom_ops::selective_scan_bwd") | ||
| def custom_selective_scan_bwd_fake( | ||
| dout, | ||
| u, | ||
| delta, | ||
| A, | ||
| B, | ||
| C, | ||
| D, | ||
| z, | ||
| delta_bias, | ||
| delta_softplus, | ||
| out, | ||
| x, | ||
| squeeze_B, | ||
| squeeze_C, | ||
| ): | ||
| du = torch.empty_like(u) | ||
| ddelta = torch.empty_like(delta) | ||
| dA = torch.empty_like(A) | ||
| dB = torch.empty_like(B) | ||
| dC = torch.empty_like(C) | ||
| dD = torch.empty_like(D) if (D is not None and D.numel() > 0) else u.new_empty(0) | ||
| dz = torch.empty_like(z) if (z is not None and z.numel() > 0) else u.new_empty(0) | ||
| ddelta_bias = torch.empty_like(delta_bias) if (delta_bias is not None and delta_bias.numel() > 0) else u.new_empty(0) | ||
| return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias | ||
|
|
||
| @torch.library.register_kernel("custom_ops::selective_scan_bwd", "cuda") | ||
| def custom_selective_scan_bwd_cuda( | ||
| dout: torch.Tensor, | ||
| u: torch.Tensor, | ||
| delta: torch.Tensor, | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
| C: torch.Tensor, | ||
| D: Optional[torch.Tensor], | ||
| z: Optional[torch.Tensor], | ||
| delta_bias: Optional[torch.Tensor], | ||
| delta_softplus: bool, | ||
| out: torch.Tensor, | ||
| x: torch.Tensor, | ||
| squeeze_B: bool, | ||
| squeeze_C: bool, | ||
| ): | ||
| if dout.stride(-1) != 1: | ||
| dout = dout.contiguous() | ||
| B = B.contiguous() | ||
bhack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| C = C.contiguous() | ||
|
|
||
| results = selective_scan_cuda.bwd( | ||
| u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, delta_softplus, False | ||
| ) | ||
| has_z = z is not None | ||
| if has_z: | ||
| du, ddelta, dA, dB, dC, dD, ddelta_bias, dz = results | ||
| else: | ||
| du, ddelta, dA, dB, dC, dD, ddelta_bias = results | ||
| dz = u.new_empty(0) | ||
|
|
||
| if squeeze_B and dB.numel() > 0: | ||
| dB = dB.squeeze(1) | ||
| if squeeze_C and dC.numel() > 0: | ||
| dC = dC.squeeze(1) | ||
|
|
||
| return du, ddelta, dA, dB, dC, dD, dz, ddelta_bias | ||
|
|
||
| def custom_bridge(ctx, *grads): | ||
| dout = grads[0] if grads else ctx.saved_tensors[0].new_empty(0) | ||
| saved = ctx.saved_tensors | ||
| if not ctx.has_z: | ||
| u, delta, A, B, C, D, delta_bias, x, out = saved | ||
| z = None | ||
| else: | ||
| u, delta, A, B, C, D, z, delta_bias, x, out = saved | ||
|
|
||
| du, ddelta, dA, dB, dC, dD, dz, ddelta_bias = torch.ops.custom_ops.selective_scan_bwd( | ||
| dout, | ||
| u, | ||
| delta, | ||
| A, | ||
| B, | ||
| C, | ||
| D, | ||
| z, | ||
| delta_bias, | ||
| ctx.delta_softplus, | ||
| out, | ||
| x, | ||
| ctx.squeeze_B, | ||
| ctx.squeeze_C | ||
| ) | ||
|
|
||
| return ( | ||
| du, | ||
| ddelta, | ||
| dA, | ||
| dB, | ||
| dC, | ||
| dD if D is not None else None, | ||
| dz if z is not None else None, | ||
| ddelta_bias if delta_bias is not None else None, | ||
| None, | ||
| None, | ||
| ) | ||
|
|
||
| def custom_setup_context(ctx, inputs, output): | ||
| (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) = inputs | ||
| (final_out, last_state, out, x, squeeze_B, squeeze_C, has_z) = output | ||
|
|
||
| ctx.delta_softplus = delta_softplus | ||
| ctx.squeeze_B = squeeze_B | ||
| ctx.squeeze_C = squeeze_C | ||
| ctx.has_z = has_z | ||
|
|
||
| B = B.contiguous() | ||
| C = C.contiguous() | ||
| if squeeze_B and B.dim() == 3: | ||
| B = rearrange(B, "b dstate l -> b 1 dstate l").contiguous() | ||
| if squeeze_C and C.dim() == 3: | ||
| C = rearrange(C, "b dstate l -> b 1 dstate l").contiguous() | ||
|
|
||
| if not has_z: | ||
| ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, out) | ||
| else: | ||
| ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) | ||
|
|
||
| torch.library.register_autograd( | ||
| "custom_ops::selective_scan_fwd", custom_bridge, setup_context=custom_setup_context | ||
| ) | ||
|
|
||
| def selective_scan_fn_custom_op( | ||
| u: torch.Tensor, | ||
| delta: torch.Tensor, | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
| C: torch.Tensor, | ||
| D: Optional[torch.Tensor], | ||
| z: Optional[torch.Tensor], | ||
| delta_bias: Optional[torch.Tensor], | ||
| delta_softplus: bool, | ||
| return_last_state: bool, | ||
| ) -> torch.Tensor: | ||
| # Pass all arguments positionally, exactly in schema order: | ||
| final_out, last_state, _, _, _, _, _ = torch.ops.custom_ops.selective_scan_fwd( | ||
| u, | ||
| delta, | ||
| A, | ||
| B, | ||
| C, | ||
| D, | ||
| z, | ||
| delta_bias, | ||
| delta_softplus, | ||
| return_last_state | ||
| ) | ||
|
|
||
| if return_last_state: | ||
| return final_out, last_state | ||
| else: | ||
| return final_out | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you cloning the tensor right here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without the extra clone we get (not in the test but on a real training session)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh that seems weird to me no? In the CPP code we are clearly creating a new tensor for out and out_z which are independent from any input tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is coming from here:
https://github.com/pytorch/pytorch/blob/main/torch/_library/utils.py#L349C5-L372
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I think that one candidate is that the same
final_outreturn is aliasing different buffers right?