Skip to content

Commit 7ae45be

Browse files
committed
Merge remote-tracking branch 'nemo/main' into magpietts_evaluation_parallelization
2 parents 8fc8190 + 8cad3a6 commit 7ae45be

File tree

8 files changed

+51
-18
lines changed

8 files changed

+51
-18
lines changed

nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ def _full_graph_compile(self):
586586
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
587587
):
588588
self._before_process_batch()
589-
capture_status, _, graph, _, _, _ = cu_call(
589+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
590+
capture_status, _, graph, *_ = cu_call(
590591
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
591592
)
592593

nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,8 @@ def _full_graph_compile(self):
874874
):
875875
self._before_loop()
876876

877-
capture_status, _, graph, _, _, _ = cu_call(
877+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
878+
capture_status, _, graph, *_ = cu_call(
878879
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
879880
)
880881
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
205205
# Get max sequence length
206206
self.max_out_len_t = self.encoder_output_length.max()
207207

208-
capture_status, _, graph, _, _, _ = cu_call(
208+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
209+
capture_status, _, graph, *_ = cu_call(
209210
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.device).cuda_stream)
210211
)
211212
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,8 @@ def _full_graph_compile(self):
906906
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
907907
):
908908
self._before_loop()
909-
capture_status, _, graph, _, _, _ = cu_call(
909+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
910+
capture_status, _, graph, *_ = cu_call(
910911
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
911912
)
912913

nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,8 @@ def _full_graph_compile(self):
987987
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
988988
):
989989
self._before_loop()
990-
capture_status, _, graph, _, _, _ = cu_call(
990+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
991+
capture_status, _, graph, *_ = cu_call(
991992
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
992993
)
993994

nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,8 @@ def _full_graph_compile(self):
968968
):
969969
self._before_outer_loop()
970970

971-
capture_status, _, graph, _, _, _ = cu_call(
971+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
972+
capture_status, _, graph, *_ = cu_call(
972973
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
973974
)
974975
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,8 @@ def _full_graph_compile(self):
10411041
):
10421042
self._before_outer_loop()
10431043

1044-
capture_status, _, graph, _, _, _ = cu_call(
1044+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
1045+
capture_status, _, graph, *_ = cu_call(
10451046
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream)
10461047
)
10471048
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

nemo/core/utils/cuda_python_utils.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import contextlib
16+
import inspect
1617

1718
import numpy as np
1819
import torch
@@ -125,7 +126,8 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
125126
from cuda.bindings import driver as cuda
126127
from cuda.bindings import runtime as cudart
127128

128-
capture_status, _, graph, _, _, _ = cu_call(
129+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
130+
capture_status, _, graph, *_ = cu_call(
129131
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
130132
)
131133
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -144,7 +146,8 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
144146
0,
145147
)
146148

147-
capture_status, _, graph, dependencies, _, _ = cu_call(
149+
# NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements
150+
capture_status, _, graph, dependencies, *_ = cu_call(
148151
cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
149152
)
150153
assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive
@@ -168,18 +171,41 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi
168171
# Use driver API here because of bug in cuda-python runtime API: https://github.com/NVIDIA/cuda-python/issues/55
169172
# TODO: Change call to this after fix goes in (and we bump minimum cuda-python version to 12.4.0):
170173
# node, = cu_call(cudart.cudaGraphAddNode(graph, dependencies, len(dependencies), driver_params))
171-
(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, None, len(dependencies), driver_params))
174+
# depending on cuda-python version, number of parameters vary
175+
num_cuda_graph_add_node_params = len(inspect.signature(cuda.cuGraphAddNode).parameters)
176+
if num_cuda_graph_add_node_params == 5:
177+
(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, None, len(dependencies), driver_params))
178+
elif num_cuda_graph_add_node_params == 4:
179+
(node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params))
180+
else:
181+
raise NeMoCUDAPythonException("Unexpected number of parameters for `cuGraphAddNode`")
172182
body_graph = driver_params.conditional.phGraph_out[0]
173183

174-
cu_call(
175-
cudart.cudaStreamUpdateCaptureDependencies(
176-
torch.cuda.current_stream(device=device).cuda_stream,
177-
[node],
178-
None,
179-
1,
180-
cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
181-
)
184+
# depending on cuda-python version, number of parameters vary
185+
num_cuda_stream_update_capture_dependencies_params = len(
186+
inspect.signature(cudart.cudaStreamUpdateCaptureDependencies).parameters
182187
)
188+
if num_cuda_stream_update_capture_dependencies_params == 5:
189+
cu_call(
190+
cudart.cudaStreamUpdateCaptureDependencies(
191+
torch.cuda.current_stream(device=device).cuda_stream,
192+
[node],
193+
None,
194+
1,
195+
cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
196+
)
197+
)
198+
elif num_cuda_stream_update_capture_dependencies_params == 4:
199+
cu_call(
200+
cudart.cudaStreamUpdateCaptureDependencies(
201+
torch.cuda.current_stream(device=device).cuda_stream,
202+
[node],
203+
1,
204+
cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
205+
)
206+
)
207+
else:
208+
raise NeMoCUDAPythonException("Unexpected number of parameters for `cudaStreamUpdateCaptureDependencies`")
183209
body_stream = torch.cuda.Stream(device)
184210
previous_stream = torch.cuda.current_stream(device=device)
185211
cu_call(

0 commit comments

Comments
 (0)