|
64 | 64 | parametrize, |
65 | 65 | run_tests, |
66 | 66 | set_default_dtype, |
| 67 | + skipIfTorchDynamo, |
67 | 68 | skipIfTorchInductor, |
68 | 69 | slowTest, |
69 | 70 | suppress_warnings, |
@@ -120,6 +121,121 @@ def reduction_dtype_filter(op): |
120 | 121 |
|
121 | 122 | aten = torch.ops.aten |
122 | 123 |
|
| 124 | +meta_consistency_out_dtype_mismatch_xfails = { |
| 125 | + xfail("abs"), |
| 126 | + xfail("addbmm"), |
| 127 | + xfail("addmm"), |
| 128 | + xfail("addmm", "decomposed"), |
| 129 | + xfail("addmv"), |
| 130 | + xfail("alias_copy"), |
| 131 | + xfail("all"), |
| 132 | + xfail("amax"), |
| 133 | + xfail("amin"), |
| 134 | + xfail("aminmax"), |
| 135 | + xfail("any"), |
| 136 | + xfail("as_strided_copy"), |
| 137 | + xfail("baddbmm"), |
| 138 | + xfail("bucketize"), |
| 139 | + xfail("ceil"), |
| 140 | + xfail("conj_physical"), |
| 141 | + xfail("cross"), |
| 142 | + xfail("cummax"), |
| 143 | + xfail("cummin"), |
| 144 | + xfail("diag"), |
| 145 | + xfail("diagonal_copy"), |
| 146 | + xfail("dot"), |
| 147 | + xfail("expand_copy"), |
| 148 | + xfail("fft.ihfft2"), |
| 149 | + xfail("fft.ihfftn"), |
| 150 | + xfail("floor"), |
| 151 | + xfail("frac"), |
| 152 | + xfail("frexp"), |
| 153 | + xfail("geqrf"), |
| 154 | + xfail("heaviside"), |
| 155 | + xfail("histc"), |
| 156 | + xfail("index_add"), |
| 157 | + xfail("index_copy"), |
| 158 | + xfail("index_select"), |
| 159 | + xfail("isin"), |
| 160 | + xfail("isneginf"), |
| 161 | + xfail("isposinf"), |
| 162 | + xfail("kthvalue"), |
| 163 | + xfail("lerp"), |
| 164 | + xfail("linalg.cross"), |
| 165 | + xfail("linalg.eigh"), |
| 166 | + xfail("linalg.eigvalsh"), |
| 167 | + xfail("linalg.ldl_factor"), |
| 168 | + xfail("linalg.ldl_factor_ex"), |
| 169 | + xfail("linalg.ldl_solve"), |
| 170 | + xfail("linalg.lu"), |
| 171 | + xfail("linalg.lu_factor"), |
| 172 | + xfail("linalg.lu_factor_ex"), |
| 173 | + xfail("linalg.lu_solve"), |
| 174 | + xfail("linalg.matrix_power"), |
| 175 | + xfail("linalg.qr"), |
| 176 | + xfail("linalg.slogdet"), |
| 177 | + xfail("linalg.solve"), |
| 178 | + xfail("linalg.solve_ex"), |
| 179 | + xfail("linalg.solve_triangular"), |
| 180 | + xfail("log_softmax"), |
| 181 | + xfail("logcumsumexp"), |
| 182 | + xfail("lu_solve"), |
| 183 | + xfail("lu_unpack"), |
| 184 | + xfail("matmul"), |
| 185 | + xfail("mean"), |
| 186 | + xfail("mm"), |
| 187 | + xfail("mode"), |
| 188 | + xfail("msort"), |
| 189 | + xfail("multinomial"), |
| 190 | + xfail("mv"), |
| 191 | + xfail("nan_to_num"), |
| 192 | + xfail("nanmean"), |
| 193 | + xfail("narrow_copy"), |
| 194 | + xfail("native_batch_norm"), |
| 195 | + xfail("neg"), |
| 196 | + xfail("nn.functional.avg_pool3d"), |
| 197 | + xfail("nn.functional.gelu"), |
| 198 | + xfail("nn.functional.hardshrink"), |
| 199 | + xfail("nn.functional.linear"), |
| 200 | + xfail("nn.functional.logsigmoid"), |
| 201 | + xfail("nn.functional.softplus"), |
| 202 | + xfail("nn.functional.softshrink"), |
| 203 | + xfail("ormqr"), |
| 204 | + xfail("qr"), |
| 205 | + xfail("renorm"), |
| 206 | + xfail("round"), |
| 207 | + xfail("round", "decimals_0"), |
| 208 | + xfail("scatter_reduce", "amax"), |
| 209 | + xfail("scatter_reduce", "amin"), |
| 210 | + xfail("scatter_reduce", "mean"), |
| 211 | + xfail("scatter_reduce", "prod"), |
| 212 | + xfail("scatter_reduce", "sum"), |
| 213 | + xfail("searchsorted"), |
| 214 | + xfail("sgn"), |
| 215 | + xfail("sign"), |
| 216 | + xfail("signbit"), |
| 217 | + xfail("slice_scatter"), |
| 218 | + xfail("softmax"), |
| 219 | + xfail("sort"), |
| 220 | + xfail("sparse.sampled_addmm"), |
| 221 | + xfail("square"), |
| 222 | + xfail("squeeze_copy"), |
| 223 | + xfail("t_copy"), |
| 224 | + xfail("take"), |
| 225 | + xfail("transpose_copy"), |
| 226 | + xfail("tril"), |
| 227 | + xfail("triu"), |
| 228 | + xfail("trunc"), |
| 229 | + xfail("unfold_copy"), |
| 230 | + xfail("unsqueeze_copy"), |
| 231 | + xfail("vdot"), |
| 232 | + xfail("view_copy"), |
| 233 | + xfail("where"), |
|     | 234 |     +    # The output has a dynamic (data-dependent) shape, so the operation |
|     | 235 |     +    # does not have a meta kernel implementation. |
| 236 | + skip("linalg.lstsq"), |
| 237 | +} |
| 238 | + |
123 | 239 |
|
124 | 240 | # Tests that apply to all operators and aren't related to any particular |
125 | 241 | # system |
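
A note on the skip-list added above: in PyTorch's OpInfo-based tests, `xfail` entries are still executed and are expected to fail (so a fixed op surfaces as an unexpected success), while `skip` entries are never executed at all, which is why `linalg.lstsq`, lacking a meta kernel, is skipped rather than xfailed. Below is a minimal, self-contained sketch of that distinction; the `Rule`, `find_rule`, and `expect_failure` names are hypothetical stand-ins for illustration, not PyTorch's actual `DecorateInfo` machinery:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Rule:
    op: str
    variant: str = ""
    expect_failure: bool = True  # xfail -> True, skip -> False


def xfail(op: str, variant: str = "") -> Rule:
    # The test still runs, but must fail.
    return Rule(op, variant, expect_failure=True)


def skip(op: str, variant: str = "") -> Rule:
    # The test is not run at all (e.g. no meta kernel exists).
    return Rule(op, variant, expect_failure=False)


def find_rule(rules: set[Rule], op: str, variant: str = "") -> Optional[Rule]:
    return next((r for r in rules if (r.op, r.variant) == (op, variant)), None)


rules = {xfail("addmm"), skip("linalg.lstsq")}
assert find_rule(rules, "addmm").expect_failure             # runs, must fail
assert not find_rule(rules, "linalg.lstsq").expect_failure  # never runs
assert find_rule(rules, "mul") is None                      # runs normally
```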
@@ -1581,6 +1697,86 @@ def test_promotes_int_to_float(self, device, dtype, op): |
1581 | 1697 | f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." |
1582 | 1698 | ) |
1583 | 1699 |
|
|     | 1700 |     +    # Checks whether running the operations on both CPU and meta devices raises an |
|     | 1701 |     +    # error when the out= tensors have mismatching data-types (i.e. data-types that |
|     | 1702 |     +    # differ from the expected ones). |
| 1703 | + # |
|     | 1704 |     +    # The idea is that the meta implementations should faithfully mirror the behavior |
|     | 1705 |     +    # of concrete devices (e.g. CPU and CUDA). |
| 1706 | + @onlyCPU |
| 1707 | + @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,)) |
| 1708 | + @skipOps( |
| 1709 | + "TestCommon", |
| 1710 | + "test_meta_consistency_out_dtype_mismatch", |
| 1711 | + meta_consistency_out_dtype_mismatch_xfails, |
| 1712 | + ) |
| 1713 | + @skipIfTorchDynamo("meta device runs only on eager") |
| 1714 | + def test_meta_consistency_out_dtype_mismatch(self, device, dtype, op): |
| 1715 | + samples = op.sample_inputs(device, dtype) |
| 1716 | + |
|     | 1717 |     +        for sample in samples: |
| 1718 | + input, args, kwargs = (sample.input, sample.args, sample.kwargs) |
| 1719 | + |
| 1720 | + try: |
| 1721 | + # Call the functional version of the operation, using a real device, so that |
| 1722 | + # we get the actual expected result. |
| 1723 | + expected = op(input, *args, **kwargs) |
| 1724 | + |
| 1725 | + if isinstance(expected, tuple): |
|     | 1726 |     +                    # Some operations return named tuples, which pytree does not handle |
|     | 1727 |     +                    # well, so we convert them into plain tuples. |
| 1728 | + expected = tuple(expected) |
| 1729 | + except Exception: |
| 1730 | + # If that doesn't work out, go to the next sample. |
| 1731 | + continue |
| 1732 | + |
| 1733 | + def run_on(dev): |
|     | 1734 |     +                # Create new out= tensors on the desired device, with a mismatching |
|     | 1735 |     +                # data type of the same kind (float64 instead of the expected float32). |
| 1736 | + out = pytree.tree_map_only( |
| 1737 | + torch.Tensor, |
| 1738 | + lambda t: torch.empty_like(t, device=dev, dtype=torch.float64), |
| 1739 | + expected, |
| 1740 | + ) |
| 1741 | + |
| 1742 | + # Move inputs to the desired device. |
| 1743 | + arguments = (input, args, kwargs) |
| 1744 | + arguments = pytree.tree_map_only( |
| 1745 | + torch.Tensor, lambda t: t.to(dev), arguments |
| 1746 | + ) |
|     | 1747 |     +                # Also, replace every device argument, whether a torch.device instance |
|     | 1748 |     +                # or a device string, with the desired device. |
| 1749 | + arguments = pytree.tree_map_only( |
| 1750 | + torch.device, lambda d: torch.device(dev), arguments |
| 1751 | + ) |
| 1752 | + arguments = pytree.tree_map_only( |
| 1753 | + str, lambda v: dev if v == device else v, arguments |
| 1754 | + ) |
| 1755 | + input_, args_, kwargs_ = arguments |
| 1756 | + |
| 1757 | + # Try running the operation, and return the raised error, if any. |
| 1758 | + try: |
| 1759 | + op(input_, *args_, **kwargs_, out=out) |
| 1760 | + except Exception as e: |
| 1761 | + return e |
| 1762 | + |
| 1763 | + # Run the operation with the sample arguments on both CPU and meta devices, capturing |
| 1764 | + # the raised error, if any. |
| 1765 | + device_err = run_on(device) |
| 1766 | + meta_err = run_on("meta") |
| 1767 | + |
|     | 1768 |     +            # Check whether the two devices disagree. |
|     | 1769 |     +            # |
|     | 1770 |     +            # If an error was raised on the real device but not on the meta device |
|     | 1771 |     +            # (or vice versa), we raise a RuntimeError chained with the captured |
|     | 1772 |     +            # error, so the original failure remains visible. |
|     | 1773 |     +            # |
|     | 1774 |     +            # We could just assertEqual here, but chaining the errors is more informative. |
| 1775 | + if device_err is None and meta_err is not None: |
| 1776 | + raise RuntimeError(f"{device} didn't fail, but meta did.") from meta_err |
| 1777 | + elif device_err is not None and meta_err is None: |
| 1778 | + raise RuntimeError(f"{device} failed, but meta didn't.") from device_err |
| 1779 | + |
1584 | 1780 |
|
1585 | 1781 | @unMarkDynamoStrictTest |
1586 | 1782 | class TestCompositeCompliance(TestCase): |
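
The three `pytree.tree_map_only` passes in `run_on` above handle, in turn, tensors, `torch.device` objects, and device strings. A small sketch of that remapping pattern on a made-up argument tree (the tree contents are illustrative only; `torch.utils._pytree` is the module the test aliases as `pytree`):

```python
import torch
import torch.utils._pytree as pytree

dev = "meta"
# Toy argument tree mixing tensors, a torch.device, and a device string.
arguments = (torch.ones(2), {"device": torch.device("cpu"), "loc": "cpu"}, [1.0])

# Move every tensor to the target device.
arguments = pytree.tree_map_only(torch.Tensor, lambda t: t.to(dev), arguments)
# Rewrite explicit torch.device arguments.
arguments = pytree.tree_map_only(torch.device, lambda d: torch.device(dev), arguments)
# Rewrite device strings that name the original device.
arguments = pytree.tree_map_only(str, lambda s: dev if s == "cpu" else s, arguments)
```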
|
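To see the new test's core idea outside the OpInfo harness, here is a standalone sketch. It assumes `torch.add` as a stand-in op, with a float64 `out=` tensor for float32 inputs, mirroring the dtype mismatch the test constructs; the meta device should agree with the real device on whether that combination is accepted:

```python
import torch


def run_on(dev):
    a = torch.ones(3, device=dev, dtype=torch.float32)
    b = torch.ones(3, device=dev, dtype=torch.float32)
    # out= tensor with a mismatching dtype of the same kind (float64 vs float32).
    out = torch.empty(3, device=dev, dtype=torch.float64)
    try:
        torch.add(a, b, out=out)
    except Exception as e:
        return e
    return None


cpu_err = run_on("cpu")
meta_err = run_on("meta")

# The meta kernel should mirror the real device: either both accept the
# float64 out= tensor (a safe upcast) or both reject it.
assert (cpu_err is None) == (meta_err is None), (cpu_err, meta_err)
```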