Skip to content

Commit bd29cf3

Browse files
authored
Remove Sampler copy stream (#2209)
1 parent 31bff69 commit bd29cf3

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

vllm/model_executor/layers/sampler.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class Sampler(nn.Module):
3030
def __init__(self, vocab_size: int) -> None:
3131
super().__init__()
3232
self.vocab_size = vocab_size
33-
self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
3433

3534
def forward(
3635
self,
@@ -51,14 +50,10 @@ def forward(
5150
# Apply logits processors (if any).
5251
logits = _apply_logits_processors(logits, sampling_metadata)
5352

54-
# Prepare sampling tensors in another stream to overlap
55-
# CPU<->GPU data transfer with GPU computation in forward pass.
56-
with torch.cuda.stream(self._copy_stream):
57-
(sampling_tensors, do_penalties, do_top_p_top_k,
58-
do_min_p) = SamplingTensors.from_sampling_metadata(
59-
sampling_metadata, vocab_size, logits.device, logits.dtype)
60-
61-
torch.cuda.current_stream().wait_stream(self._copy_stream)
53+
# Prepare sampling tensors with pinned memory to avoid blocking.
54+
(sampling_tensors, do_penalties, do_top_p_top_k,
55+
do_min_p) = SamplingTensors.from_sampling_metadata(
56+
sampling_metadata, vocab_size, logits.device, logits.dtype)
6257

6358
# Apply presence and frequency penalties.
6459
if do_penalties:

0 commit comments

Comments
 (0)