diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index e87ce520bc66..e5364e3e681a 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
@@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, kv_layout)
     wrapper.plan(q_indptr,
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 78d8a67e37f8..4e7b5c2255f3 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -204,7 +204,7 @@ def __init__(self, runner):
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.runner.device)
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c85d8bce31f5..785f40208ed7 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -251,7 +251,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.empty(
+            self._workspace_buffer = torch.zeros(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
                 dtype=torch.uint8,
                 device=self.device)
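The only functional change, repeated across all four hunks, is swapping torch.empty for torch.zeros when allocating the FlashInfer workspace buffer. Below is a minimal sketch of what that difference means in practice; the 128 MiB size mirrors the diff, while the comment about kernels reading workspace contents (e.g. internal counters or semaphores) before writing them is an assumption about the motivation, not something the diff itself states.

    import torch

    WORKSPACE_BYTES = 128 * 1024 * 1024  # 128 MiB, matching the diff

    # torch.empty() only reserves memory; the bytes hold whatever was already
    # in the allocation, so any kernel that reads the buffer before writing it
    # sees garbage (assumed risk: nondeterministic results or bad indices).
    uninitialized = torch.empty(WORKSPACE_BYTES, dtype=torch.uint8)

    # torch.zeros() additionally clears the allocation, so any counter or
    # semaphore region inside the workspace starts in a well-defined state.
    zeroed = torch.zeros(WORKSPACE_BYTES, dtype=torch.uint8)

    assert int(zeroed.sum()) == 0   # guaranteed by torch.zeros
    # int(uninitialized.sum()) may be anything -- the contents are undefined

The cost is a one-time memset per buffer, and since _get_workspace_buffer() caches the allocation in self._workspace_buffer, the zero-fill happens only on first use.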