Commit 006130d

ysiraichi authored and pytorchmergebot committed
Add test for consistency between meta and CPU devices. (pytorch#138515)

Reference: pytorch#138399

This PR introduces an `OpInfo` test that checks whether running each `out=` operation with meta inputs is consistent with running it on concrete (e.g. CPU) inputs. More specifically, it tests the case where the output tensors are not of the expected data type. According to the `out=` specification, some operations should raise an error in that case. XFAIL was added for the operations that currently fail this check.

Pull Request resolved: pytorch#138515
Approved by: https://github.com/ezyang
1 parent 4dd04db commit 006130d
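
For context on the behavior being tested (not part of the diff): on a concrete device such as CPU, an `out=` tensor whose dtype cannot safely hold the result is expected to raise a `RuntimeError`. A minimal sketch, assuming a recent PyTorch build; the operator and dtypes are illustrative only:

```python
import torch

a = torch.ones(2)                        # float32 inputs
b = torch.ones(2)
out = torch.empty(2, dtype=torch.int64)  # deliberately mismatching output dtype

try:
    torch.add(a, b, out=out)  # per the out= spec, this should fail
except RuntimeError as e:
    # On CPU this typically reports that the result type Float can't be
    # cast to the desired output type Long.
    print("raised:", e)
```

The new test performs this kind of dtype-mismatched `out=` call on both the real device and the meta device and asserts that the two agree on whether an error is raised.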

File tree

1 file changed (+196 −0)

test/test_ops.py

Lines changed: 196 additions & 0 deletions
```diff
@@ -64,6 +64,7 @@
     parametrize,
     run_tests,
     set_default_dtype,
+    skipIfTorchDynamo,
     skipIfTorchInductor,
     slowTest,
     suppress_warnings,
```
```diff
@@ -120,6 +121,121 @@ def reduction_dtype_filter(op):
 
 aten = torch.ops.aten
 
+meta_consistency_out_dtype_mismatch_xfails = {
+    xfail("abs"),
+    xfail("addbmm"),
+    xfail("addmm"),
+    xfail("addmm", "decomposed"),
+    xfail("addmv"),
+    xfail("alias_copy"),
+    xfail("all"),
+    xfail("amax"),
+    xfail("amin"),
+    xfail("aminmax"),
+    xfail("any"),
+    xfail("as_strided_copy"),
+    xfail("baddbmm"),
+    xfail("bucketize"),
+    xfail("ceil"),
+    xfail("conj_physical"),
+    xfail("cross"),
+    xfail("cummax"),
+    xfail("cummin"),
+    xfail("diag"),
+    xfail("diagonal_copy"),
+    xfail("dot"),
+    xfail("expand_copy"),
+    xfail("fft.ihfft2"),
+    xfail("fft.ihfftn"),
+    xfail("floor"),
+    xfail("frac"),
+    xfail("frexp"),
+    xfail("geqrf"),
+    xfail("heaviside"),
+    xfail("histc"),
+    xfail("index_add"),
+    xfail("index_copy"),
+    xfail("index_select"),
+    xfail("isin"),
+    xfail("isneginf"),
+    xfail("isposinf"),
+    xfail("kthvalue"),
+    xfail("lerp"),
+    xfail("linalg.cross"),
+    xfail("linalg.eigh"),
+    xfail("linalg.eigvalsh"),
+    xfail("linalg.ldl_factor"),
+    xfail("linalg.ldl_factor_ex"),
+    xfail("linalg.ldl_solve"),
+    xfail("linalg.lu"),
+    xfail("linalg.lu_factor"),
+    xfail("linalg.lu_factor_ex"),
+    xfail("linalg.lu_solve"),
+    xfail("linalg.matrix_power"),
+    xfail("linalg.qr"),
+    xfail("linalg.slogdet"),
+    xfail("linalg.solve"),
+    xfail("linalg.solve_ex"),
+    xfail("linalg.solve_triangular"),
+    xfail("log_softmax"),
+    xfail("logcumsumexp"),
+    xfail("lu_solve"),
+    xfail("lu_unpack"),
+    xfail("matmul"),
+    xfail("mean"),
+    xfail("mm"),
+    xfail("mode"),
+    xfail("msort"),
+    xfail("multinomial"),
+    xfail("mv"),
+    xfail("nan_to_num"),
+    xfail("nanmean"),
+    xfail("narrow_copy"),
+    xfail("native_batch_norm"),
+    xfail("neg"),
+    xfail("nn.functional.avg_pool3d"),
+    xfail("nn.functional.gelu"),
+    xfail("nn.functional.hardshrink"),
+    xfail("nn.functional.linear"),
+    xfail("nn.functional.logsigmoid"),
+    xfail("nn.functional.softplus"),
+    xfail("nn.functional.softshrink"),
+    xfail("ormqr"),
+    xfail("qr"),
+    xfail("renorm"),
+    xfail("round"),
+    xfail("round", "decimals_0"),
+    xfail("scatter_reduce", "amax"),
+    xfail("scatter_reduce", "amin"),
+    xfail("scatter_reduce", "mean"),
+    xfail("scatter_reduce", "prod"),
+    xfail("scatter_reduce", "sum"),
+    xfail("searchsorted"),
+    xfail("sgn"),
+    xfail("sign"),
+    xfail("signbit"),
+    xfail("slice_scatter"),
+    xfail("softmax"),
+    xfail("sort"),
+    xfail("sparse.sampled_addmm"),
+    xfail("square"),
+    xfail("squeeze_copy"),
+    xfail("t_copy"),
+    xfail("take"),
+    xfail("transpose_copy"),
+    xfail("tril"),
+    xfail("triu"),
+    xfail("trunc"),
+    xfail("unfold_copy"),
+    xfail("unsqueeze_copy"),
+    xfail("vdot"),
+    xfail("view_copy"),
+    xfail("where"),
+    # Output has dynamic shape.
+    # Does not have a meta kernel implementation.
+    skip("linalg.lstsq"),
+}
+
 
 # Tests that apply to all operators and aren't related to any particular
 # system
```
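
Background, not part of the diff: a meta tensor carries only metadata such as shape, dtype, and strides, with no storage behind it, so running an operation on meta inputs exercises only that op's meta kernel. A small illustration under that assumption:

```python
import torch

x = torch.randn(4, 4, device="meta")  # no memory is allocated for the values
y = (x @ x).relu()                    # only shapes and dtypes are propagated
print(y.shape, y.dtype, y.device)     # torch.Size([4, 4]) torch.float32 meta
# y.cpu() or y.item() would fail: a meta tensor has no data to materialize
```

Because meta kernels never touch real data, the only way they can match concrete devices on `out=` dtype mismatches is by performing the same error checks, which is what the test in the next hunk verifies; the xfail set above lists the operators where the two currently disagree.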
```diff
@@ -1581,6 +1697,86 @@ def test_promotes_int_to_float(self, device, dtype, op):
             f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}."
         )
 
+    # Checks whether running the operations on both CPU and meta devices raise errors
+    # when the output tensors have mismatching data-types (i.e. data-types that are
+    # different from the expected one).
+    #
+    # The idea is that the meta implementations should correctly reflect on the behavior
+    # of other concrete devices (e.g. CPU and CUDA).
+    @onlyCPU
+    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
+    @skipOps(
+        "TestCommon",
+        "test_meta_consistency_out_dtype_mismatch",
+        meta_consistency_out_dtype_mismatch_xfails,
+    )
+    @skipIfTorchDynamo("meta device runs only on eager")
+    def test_meta_consistency_out_dtype_mismatch(self, device, dtype, op):
+        samples = op.sample_inputs(device, dtype)
+
+        for i, sample in enumerate(samples):
+            input, args, kwargs = (sample.input, sample.args, sample.kwargs)
+
+            try:
+                # Call the functional version of the operation, using a real device, so that
+                # we get the actual expected result.
+                expected = op(input, *args, **kwargs)
+
+                if isinstance(expected, tuple):
+                    # Some operations return named tuples. However, pytree does not work well
+                    # with that, so we turn it into a plain tuple.
+                    expected = tuple(expected)
+            except Exception:
+                # If that doesn't work out, go to the next sample.
+                continue
+
+            def run_on(dev):
+                # Create new outputs in the desired device, with a mismatching data type of
+                # the same kind.
+                out = pytree.tree_map_only(
+                    torch.Tensor,
+                    lambda t: torch.empty_like(t, device=dev, dtype=torch.float64),
+                    expected,
+                )
+
+                # Move inputs to the desired device.
+                arguments = (input, args, kwargs)
+                arguments = pytree.tree_map_only(
+                    torch.Tensor, lambda t: t.to(dev), arguments
+                )
+                # Also, replace every instance of 'cpu' arguments by whatever the desired
+                # device really should be.
+                arguments = pytree.tree_map_only(
+                    torch.device, lambda d: torch.device(dev), arguments
+                )
+                arguments = pytree.tree_map_only(
+                    str, lambda v: dev if v == device else v, arguments
+                )
+                input_, args_, kwargs_ = arguments
+
+                # Try running the operation, and return the raised error, if any.
+                try:
+                    op(input_, *args_, **kwargs_, out=out)
+                except Exception as e:
+                    return e
+
+            # Run the operation with the sample arguments on both CPU and meta devices, capturing
+            # the raised error, if any.
+            device_err = run_on(device)
+            meta_err = run_on("meta")
+
+            # Check whether they disagree on the result.
+            #
+            # In case there is an inconsistency of whether an error was raised using the real device,
+            # but not when using the meta device, we raise a RuntimeError, chaining with the captured
+            # one.
+            #
+            # We could just assertEquals here, but chaining the errors is more informative.
+            if device_err is None and meta_err is not None:
+                raise RuntimeError(f"{device} didn't fail, but meta did.") from meta_err
+            elif device_err is not None and meta_err is None:
+                raise RuntimeError(f"{device} failed, but meta didn't.") from device_err
+
 
 @unMarkDynamoStrictTest
 class TestCompositeCompliance(TestCase):
```
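
A rough standalone sketch of how the same check could be reproduced for a single operator outside the OpInfo machinery; the helper below is hypothetical and only mirrors the structure of `run_on` in the hunk above:

```python
import torch

def out_dtype_mismatch_error(op, dev):
    # Build float32 inputs and a deliberately float64 out= tensor on `dev`,
    # then return the exception raised by the out= call, if any.
    x = torch.ones(3, device=dev, dtype=torch.float32)
    out = torch.empty(3, device=dev, dtype=torch.float64)
    try:
        op(x, out=out)
    except Exception as e:
        return e
    return None

cpu_err = out_dtype_mismatch_error(torch.neg, "cpu")
meta_err = out_dtype_mismatch_error(torch.neg, "meta")
# Consistency means the two devices agree on whether the call failed.
print(cpu_err is None, meta_err is None)
```

Note that `neg` appears in the xfail set above, so the two results may legitimately differ on current builds; the point of the sketch is the comparison itself, not any particular outcome.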
