@@ -30,7 +30,6 @@ class Sampler(nn.Module):
30
30
def __init__ (self , vocab_size : int ) -> None :
31
31
super ().__init__ ()
32
32
self .vocab_size = vocab_size
33
- self ._copy_stream : torch .cuda .Stream = torch .cuda .Stream ()
34
33
35
34
def forward (
36
35
self ,
@@ -51,14 +50,10 @@ def forward(
51
50
# Apply logits processors (if any).
52
51
logits = _apply_logits_processors (logits , sampling_metadata )
53
52
54
- # Prepare sampling tensors in another stream to overlap
55
- # CPU<->GPU data transfer with GPU computation in forward pass.
56
- with torch .cuda .stream (self ._copy_stream ):
57
- (sampling_tensors , do_penalties , do_top_p_top_k ,
58
- do_min_p ) = SamplingTensors .from_sampling_metadata (
59
- sampling_metadata , vocab_size , logits .device , logits .dtype )
60
-
61
- torch .cuda .current_stream ().wait_stream (self ._copy_stream )
53
+ # Prepare sampling tensors with pinned memory to avoid blocking.
54
+ (sampling_tensors , do_penalties , do_top_p_top_k ,
55
+ do_min_p ) = SamplingTensors .from_sampling_metadata (
56
+ sampling_metadata , vocab_size , logits .device , logits .dtype )
62
57
63
58
# Apply presence and frequency penalties.
64
59
if do_penalties :
0 commit comments