|
64 | 64 | parametrize, |
65 | 65 | run_tests, |
66 | 66 | set_default_dtype, |
67 | | - skipIfTorchDynamo, |
68 | 67 | skipIfTorchInductor, |
69 | 68 | slowTest, |
70 | 69 | suppress_warnings, |
@@ -121,121 +120,6 @@ def reduction_dtype_filter(op): |
121 | 120 |
|
122 | 121 | aten = torch.ops.aten |
123 | 122 |
|
124 | | -meta_consistency_out_dtype_mismatch_xfails = { |
125 | | - xfail("abs"), |
126 | | - xfail("addbmm"), |
127 | | - xfail("addmm"), |
128 | | - xfail("addmm", "decomposed"), |
129 | | - xfail("addmv"), |
130 | | - xfail("alias_copy"), |
131 | | - xfail("all"), |
132 | | - xfail("amax"), |
133 | | - xfail("amin"), |
134 | | - xfail("aminmax"), |
135 | | - xfail("any"), |
136 | | - xfail("as_strided_copy"), |
137 | | - xfail("baddbmm"), |
138 | | - xfail("bucketize"), |
139 | | - xfail("ceil"), |
140 | | - xfail("conj_physical"), |
141 | | - xfail("cross"), |
142 | | - xfail("cummax"), |
143 | | - xfail("cummin"), |
144 | | - xfail("diag"), |
145 | | - xfail("diagonal_copy"), |
146 | | - xfail("dot"), |
147 | | - xfail("expand_copy"), |
148 | | - xfail("fft.ihfft2"), |
149 | | - xfail("fft.ihfftn"), |
150 | | - xfail("floor"), |
151 | | - xfail("frac"), |
152 | | - xfail("frexp"), |
153 | | - xfail("geqrf"), |
154 | | - xfail("heaviside"), |
155 | | - xfail("histc"), |
156 | | - xfail("index_add"), |
157 | | - xfail("index_copy"), |
158 | | - xfail("index_select"), |
159 | | - xfail("isin"), |
160 | | - xfail("isneginf"), |
161 | | - xfail("isposinf"), |
162 | | - xfail("kthvalue"), |
163 | | - xfail("lerp"), |
164 | | - xfail("linalg.cross"), |
165 | | - xfail("linalg.eigh"), |
166 | | - xfail("linalg.eigvalsh"), |
167 | | - xfail("linalg.ldl_factor"), |
168 | | - xfail("linalg.ldl_factor_ex"), |
169 | | - xfail("linalg.ldl_solve"), |
170 | | - xfail("linalg.lu"), |
171 | | - xfail("linalg.lu_factor"), |
172 | | - xfail("linalg.lu_factor_ex"), |
173 | | - xfail("linalg.lu_solve"), |
174 | | - xfail("linalg.matrix_power"), |
175 | | - xfail("linalg.qr"), |
176 | | - xfail("linalg.slogdet"), |
177 | | - xfail("linalg.solve"), |
178 | | - xfail("linalg.solve_ex"), |
179 | | - xfail("linalg.solve_triangular"), |
180 | | - xfail("log_softmax"), |
181 | | - xfail("logcumsumexp"), |
182 | | - xfail("lu_solve"), |
183 | | - xfail("lu_unpack"), |
184 | | - xfail("matmul"), |
185 | | - xfail("mean"), |
186 | | - xfail("mm"), |
187 | | - xfail("mode"), |
188 | | - xfail("msort"), |
189 | | - xfail("multinomial"), |
190 | | - xfail("mv"), |
191 | | - xfail("nan_to_num"), |
192 | | - xfail("nanmean"), |
193 | | - xfail("narrow_copy"), |
194 | | - xfail("native_batch_norm"), |
195 | | - xfail("neg"), |
196 | | - xfail("nn.functional.avg_pool3d"), |
197 | | - xfail("nn.functional.gelu"), |
198 | | - xfail("nn.functional.hardshrink"), |
199 | | - xfail("nn.functional.linear"), |
200 | | - xfail("nn.functional.logsigmoid"), |
201 | | - xfail("nn.functional.softplus"), |
202 | | - xfail("nn.functional.softshrink"), |
203 | | - xfail("ormqr"), |
204 | | - xfail("qr"), |
205 | | - xfail("renorm"), |
206 | | - xfail("round"), |
207 | | - xfail("round", "decimals_0"), |
208 | | - xfail("scatter_reduce", "amax"), |
209 | | - xfail("scatter_reduce", "amin"), |
210 | | - xfail("scatter_reduce", "mean"), |
211 | | - xfail("scatter_reduce", "prod"), |
212 | | - xfail("scatter_reduce", "sum"), |
213 | | - xfail("searchsorted"), |
214 | | - xfail("sgn"), |
215 | | - xfail("sign"), |
216 | | - xfail("signbit"), |
217 | | - xfail("slice_scatter"), |
218 | | - xfail("softmax"), |
219 | | - xfail("sort"), |
220 | | - xfail("sparse.sampled_addmm"), |
221 | | - xfail("square"), |
222 | | - xfail("squeeze_copy"), |
223 | | - xfail("t_copy"), |
224 | | - xfail("take"), |
225 | | - xfail("transpose_copy"), |
226 | | - xfail("tril"), |
227 | | - xfail("triu"), |
228 | | - xfail("trunc"), |
229 | | - xfail("unfold_copy"), |
230 | | - xfail("unsqueeze_copy"), |
231 | | - xfail("vdot"), |
232 | | - xfail("view_copy"), |
233 | | - xfail("where"), |
234 | | - # Output has dynamic shape. |
235 | | - # Does not have a meta kernel implementation. |
236 | | - skip("linalg.lstsq"), |
237 | | -} |
238 | | - |
239 | 123 |
|
240 | 124 | # Tests that apply to all operators and aren't related to any particular |
241 | 125 | # system |
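
The `meta_consistency_out_dtype_mismatch_xfails` set removed above holds expected-failure and skip markers that the `skipOps` decorator (see the removed test further down) applies to the matching per-op test instances. Here is a minimal sketch of how such a marker set is typically consumed, using hypothetical stand-ins for the `xfail`/`skip` helpers rather than the real ones from `torch.testing._internal`:

```python
import unittest

# Hypothetical stand-ins: each marker records the op name, an optional variant
# name, and whether the test should be expected to fail or skipped entirely.
def xfail(op_name, variant_name=""):
    return (op_name, variant_name, "xfail")

def skip(op_name, variant_name=""):
    return (op_name, variant_name, "skip")

known_mismatches = {
    xfail("round", "decimals_0"),
    skip("linalg.lstsq"),  # no meta kernel; output shape is dynamic
}

def decorator_for(op_name, variant_name, markers):
    # Pick the unittest decorator matching this op, or a no-op passthrough.
    for name, variant, kind in markers:
        if (name, variant) == (op_name, variant_name):
            return unittest.expectedFailure if kind == "xfail" else unittest.skip(name)
    return lambda test_fn: test_fn
```

A `skipOps`-style decorator then wraps each generated per-op test with `decorator_for(op.name, op.variant_test_name, known_mismatches)`, so known CPU/meta mismatches surface as expected failures instead of hard CI breaks.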
@@ -1697,86 +1581,6 @@ def test_promotes_int_to_float(self, device, dtype, op): |
1697 | 1581 | f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." |
1698 | 1582 | ) |
1699 | 1583 |
|
1700 | | -    # Checks whether running the operations on both CPU and meta devices raises |
1701 | | -    # errors when the out= tensors have mismatched data-types (i.e. data-types |
1702 | | -    # that are different from the expected one). |
1703 | | -    # |
1704 | | -    # The idea is that the meta implementations should correctly reflect the |
1705 | | -    # behavior of concrete devices (e.g. CPU and CUDA). |
1706 | | -    @onlyCPU |
1707 | | -    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,)) |
1708 | | -    @skipOps( |
1709 | | -        "TestCommon", |
1710 | | -        "test_meta_consistency_out_dtype_mismatch", |
1711 | | -        meta_consistency_out_dtype_mismatch_xfails, |
1712 | | -    ) |
1713 | | -    @skipIfTorchDynamo("meta device runs only on eager") |
1714 | | -    def test_meta_consistency_out_dtype_mismatch(self, device, dtype, op): |
1715 | | -        samples = op.sample_inputs(device, dtype) |
1716 | | - |
1717 | | -        for i, sample in enumerate(samples): |
1718 | | -            input, args, kwargs = (sample.input, sample.args, sample.kwargs) |
1719 | | - |
1720 | | -            try: |
1721 | | -                # Call the functional version of the operation, using a real device, |
1722 | | -                # so that we get the actual expected result. |
1723 | | -                expected = op(input, *args, **kwargs) |
1724 | | - |
1725 | | -                if isinstance(expected, tuple): |
1726 | | -                    # Some operations return named tuples. However, pytree does not |
1727 | | -                    # work well with those, so we turn them into plain tuples. |
1728 | | -                    expected = tuple(expected) |
1729 | | -            except Exception: |
1730 | | -                # If that doesn't work out, go to the next sample. |
1731 | | -                continue |
1732 | | - |
1733 | | -            def run_on(dev): |
1734 | | -                # Create new outputs on the desired device, with a mismatched data |
1735 | | -                # type of the same kind. |
1736 | | -                out = pytree.tree_map_only( |
1737 | | -                    torch.Tensor, |
1738 | | -                    lambda t: torch.empty_like(t, device=dev, dtype=torch.float64), |
1739 | | -                    expected, |
1740 | | -                ) |
1741 | | - |
1742 | | -                # Move inputs to the desired device. |
1743 | | -                arguments = (input, args, kwargs) |
1744 | | -                arguments = pytree.tree_map_only( |
1745 | | -                    torch.Tensor, lambda t: t.to(dev), arguments |
1746 | | -                ) |
1747 | | -                # Also, replace every device argument (torch.device objects and |
1748 | | -                # device strings) with the desired device. |
1749 | | -                arguments = pytree.tree_map_only( |
1750 | | -                    torch.device, lambda d: torch.device(dev), arguments |
1751 | | -                ) |
1752 | | -                arguments = pytree.tree_map_only( |
1753 | | -                    str, lambda v: dev if v == device else v, arguments |
1754 | | -                ) |
1755 | | -                input_, args_, kwargs_ = arguments |
1756 | | - |
1757 | | -                # Try running the operation, and return the raised error, if any. |
1758 | | -                try: |
1759 | | -                    op(input_, *args_, **kwargs_, out=out) |
1760 | | -                except Exception as e: |
1761 | | -                    return e |
1762 | | - |
1763 | | -            # Run the operation with the sample arguments on both the real and meta |
1764 | | -            # devices, capturing the raised error, if any. |
1765 | | -            device_err = run_on(device) |
1766 | | -            meta_err = run_on("meta") |
1767 | | - |
1768 | | -            # Check whether the two devices disagree on the result. |
1769 | | -            # |
1770 | | -            # If an error was raised on the real device but not on the meta device |
1771 | | -            # (or vice versa), we raise a RuntimeError chained with the captured one. |
1772 | | -            # |
1773 | | -            # We could just use assertEqual here, but chaining the errors is more |
1774 | | -            # informative. |
1775 | | -            if device_err is None and meta_err is not None: |
1776 | | -                raise RuntimeError(f"{device} didn't fail, but meta did.") from meta_err |
1777 | | -            elif device_err is not None and meta_err is None: |
1778 | | -                raise RuntimeError(f"{device} failed, but meta didn't.") from device_err |
1779 | | - |
1780 | 1584 |
|
1781 | 1585 | @unMarkDynamoStrictTest |
1782 | 1586 | class TestCompositeCompliance(TestCase): |
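
The removed test boils down to one pattern: run an `out=` variant with a deliberately mismatched output dtype on a real device and on the meta device, capture any exception instead of letting it propagate, and require that both devices agree on whether to raise. Below is a minimal standalone sketch of that pattern; `torch.sin` is an arbitrary example op (not taken from the test), and this simplified `run_on` stands in for the richer pytree-based argument handling the deleted code performed:

```python
import torch
import torch.utils._pytree as pytree

def run_on(fn, sample_args, dev):
    # Move every tensor in the (possibly nested) arguments to the target device.
    args = pytree.tree_map_only(torch.Tensor, lambda t: t.to(dev), sample_args)
    # Allocate an out= tensor with a deliberately mismatched dtype: float64
    # instead of the float32 the op would normally produce.
    out = torch.empty(4, device=dev, dtype=torch.float64)
    try:
        fn(*args, out=out)
    except Exception as e:
        return e  # capture the error so the two devices can be compared
    return None

sample = (torch.randn(4, dtype=torch.float32),)
real_err = run_on(torch.sin, sample, "cpu")
meta_err = run_on(torch.sin, sample, "meta")

# The meta kernel should mirror the real device: either both reject the
# mismatched out= dtype, or neither does. Chaining the captured error keeps
# the failure informative, exactly as the removed test did.
if real_err is None and meta_err is not None:
    raise RuntimeError("cpu didn't fail, but meta did.") from meta_err
elif real_err is not None and meta_err is None:
    raise RuntimeError("cpu failed, but meta didn't.") from real_err
```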
|