Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/kernels/attention/test_flashinfer_trtllm_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The workspace buffer is created with torch.int8 dtype, while the main implementation in vllm/v1/attention/backends/flashinfer.py uses torch.uint8. While this might not cause issues with a zero-initialized buffer, using an inconsistent data type can lead to subtle bugs if the underlying kernel has specific expectations about the data being signed or unsigned. For consistency and to prevent potential correctness issues, it's recommended to use torch.uint8 here.

Suggested change
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8)

wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The workspace buffer here is created with torch.int8, which is inconsistent with the torch.uint8 used in the main implementation. To ensure consistency across the codebase and avoid potential issues related to signed versus unsigned byte interpretation by the FlashInfer kernel, it is advisable to use torch.uint8 for this buffer as well.

Suggested change
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8)

wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout)
wrapper.plan(q_indptr,
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(self, runner):

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
self._workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
self._workspace_buffer = torch.zeros(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to update in vllm/attention/backends/flashinfer.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks for your review!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for accidentally pushing to another PR. It's added now.

FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
Expand Down