@@ -119,48 +119,6 @@ def _get_input_arg_device(input_args: tuple) -> torch.device:
   return device
 
 
-# Returns True if all the input args are on a CUDA device.
-def _args_on_cuda(input_args: tuple) -> bool:
-  input_device: torch.device = _get_input_arg_device(input_args)
-  if input_device is None:
-    return False
-
-  return input_device.type == "cuda"
-
-
-# Given an input list, moves the tensors to the given target_device.
-# The output order will be the same as the input. Non tensors will also still
-# be in the list.
-def _maybe_move_tensors_to_device(tensors: tuple,
-                                  target_device: torch.device) -> tuple:
-  assert target_device, "Moving tensors to None device not supported"
-
-  moved_tensors = []
-  for tensor in tensors:
-    if not isinstance(tensor, torch.Tensor):
-      moved_tensors.append(tensor)
-      continue
-
-    if tensor.device == target_device:
-      moved_tensors.append(tensor)
-      continue
-
-    if dynamo_debug:
-      print("Moving Tensor {} to device {}".format(tensor, target_device))
-
-    # Have to move to CPU before moving it to target device.
-    cpu_device: torch.device = torch.device("cpu")
-    moved_tensor = tensor.to(cpu_device)
-    moved_tensor = moved_tensor.to(target_device)
-
-    # Explicitly have to copy requires_grad attribute because it's dropped
-    # with torch.to(..)
-    moved_tensor.requires_grad = tensor.requires_grad
-    moved_tensors.append(moved_tensor)
-
-  return tuple(moved_tensors)
-
-
 def _split_xla_args_tensor_sym_constant(args):
   tensors = deque(maxlen=len(args))
   constants = []
@@ -552,14 +510,6 @@ def optimized_mod(*args: tuple):
      special_return_handler, xla_args_need_update) = extract_graph_helper(
          xla_model, sym_constants_to_graph_vars)
 
-    original_device: torch.device = _get_input_arg_device(args)
-    is_cuda_args: bool = False
-    if original_device:
-      is_cuda_args = original_device.type == "cuda"
-
-    if is_cuda_args:
-      args = _maybe_move_tensors_to_device(args, torch_xla.device())
-
     if not config.skip_input_data_check:
       # torch_xla.sync() needs to be blocking since we want to access args's
       # XLADatas and they can't be placeholder.
@@ -610,11 +560,7 @@ def optimized_mod(*args: tuple):
 
     # First few elements might be xla_args that needs to be in place updated
     result = res[len(xla_args_need_update):]
-
     result = none_remover.add_nones(result)
-    if is_cuda_args:
-      result = _maybe_move_tensors_to_device(tuple(result), original_device)
-
     if len(result) == 1:
       return result[0]
     else:
@@ -802,10 +748,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 
 
 def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
-  if _args_on_cuda(xla_args):
-    xla_args = tuple(
-        _maybe_move_tensors_to_device(xla_args, torch_xla.device()))
-
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args: