|
26 | 26 | from_array_ml_dtypes, |
27 | 27 | dtype_to_tensor_dtype, |
28 | 28 | string_diff, |
| 29 | + rename_dynamic_dimensions, |
| 30 | + rename_dynamic_expression, |
29 | 31 | ) |
30 | 32 |
|
31 | 33 | TFLOAT = onnx.TensorProto.FLOAT |
@@ -241,6 +243,89 @@ def test_string_signature(self): |
241 | 243 | def test_make_hash(self): |
242 | 244 | self.assertIsInstance(make_hash([]), str) |
243 | 245 |
|
| 246 | + def test_string_type_one(self): |
| 247 | + self.assertEqual(string_type(None), "None") |
| 248 | + self.assertEqual(string_type([4]), "#1[int]") |
| 249 | + self.assertEqual(string_type((4, 5)), "(int,int)") |
| 250 | + self.assertEqual(string_type([4] * 100), "#100[int,...]") |
| 251 | + self.assertEqual(string_type((4,) * 100), "#100(int,...)") |
| 252 | + |
| 253 | + def test_string_type_at(self): |
| 254 | + self.assertEqual(string_type(None), "None") |
| 255 | + a = np.array([4, 5], dtype=np.float32) |
| 256 | + t = torch.tensor([4, 5], dtype=torch.float32) |
| 257 | + self.assertEqual(string_type([a]), "#1[A1r1]") |
| 258 | + self.assertEqual(string_type([t]), "#1[T1r1]") |
| 259 | + self.assertEqual(string_type((a,)), "(A1r1,)") |
| 260 | + self.assertEqual(string_type((t,)), "(T1r1,)") |
| 261 | + self.assertEqual(string_type([a] * 100), "#100[A1r1,...]") |
| 262 | + self.assertEqual(string_type([t] * 100), "#100[T1r1,...]") |
| 263 | + self.assertEqual(string_type((a,) * 100), "#100(A1r1,...)") |
| 264 | + self.assertEqual(string_type((t,) * 100), "#100(T1r1,...)") |
| 265 | + |
| 266 | + def test_string_type_at_with_shape(self): |
| 267 | + self.assertEqual(string_type(None), "None") |
| 268 | + a = np.array([4, 5], dtype=np.float32) |
| 269 | + t = torch.tensor([4, 5], dtype=torch.float32) |
| 270 | + self.assertEqual(string_type([a], with_shape=True), "#1[A1s2]") |
| 271 | + self.assertEqual(string_type([t], with_shape=True), "#1[T1s2]") |
| 272 | + self.assertEqual(string_type((a,), with_shape=True), "(A1s2,)") |
| 273 | + self.assertEqual(string_type((t,), with_shape=True), "(T1s2,)") |
| 274 | + self.assertEqual(string_type([a] * 100, with_shape=True), "#100[A1s2,...]") |
| 275 | + self.assertEqual(string_type([t] * 100, with_shape=True), "#100[T1s2,...]") |
| 276 | + self.assertEqual(string_type((a,) * 100, with_shape=True), "#100(A1s2,...)") |
| 277 | + self.assertEqual(string_type((t,) * 100, with_shape=True), "#100(T1s2,...)") |
| 278 | + |
| 279 | + def test_string_type_at_with_shape_min_max(self): |
| 280 | + self.assertEqual(string_type(None), "None") |
| 281 | + a = np.array([4, 5], dtype=np.float32) |
| 282 | + t = torch.tensor([4, 5], dtype=torch.float32) |
| 283 | + self.assertEqual( |
| 284 | + string_type([a], with_shape=True, with_min_max=True), "#1[A1s2[4.0,5.0:A4.5]]" |
| 285 | + ) |
| 286 | + self.assertEqual( |
| 287 | + string_type([t], with_shape=True, with_min_max=True), "#1[T1s2[4.0,5.0:A4.5]]" |
| 288 | + ) |
| 289 | + self.assertEqual( |
| 290 | + string_type((a,), with_shape=True, with_min_max=True), "(A1s2[4.0,5.0:A4.5],)" |
| 291 | + ) |
| 292 | + self.assertEqual( |
| 293 | + string_type((t,), with_shape=True, with_min_max=True), "(T1s2[4.0,5.0:A4.5],)" |
| 294 | + ) |
| 295 | + self.assertEqual( |
| 296 | + string_type([a] * 100, with_shape=True, with_min_max=True), |
| 297 | + "#100[A1s2[4.0,5.0:A4.5],...]", |
| 298 | + ) |
| 299 | + self.assertEqual( |
| 300 | + string_type([t] * 100, with_shape=True, with_min_max=True), |
| 301 | + "#100[T1s2[4.0,5.0:A4.5],...]", |
| 302 | + ) |
| 303 | + self.assertEqual( |
| 304 | + string_type((a,) * 100, with_shape=True, with_min_max=True), |
| 305 | + "#100(A1s2[4.0,5.0:A4.5],...)", |
| 306 | + ) |
| 307 | + self.assertEqual( |
| 308 | + string_type((t,) * 100, with_shape=True, with_min_max=True), |
| 309 | + "#100(T1s2[4.0,5.0:A4.5],...)", |
| 310 | + ) |
| 311 | + |
| 312 | + def test_pretty_onnx_att(self): |
| 313 | + node = oh.make_node("Cast", ["xm2c"], ["xm2"], to=1) |
| 314 | + pretty_onnx(node.attribute[0]) |
| 315 | + |
| 316 | + def test_rename_dimension(self): |
| 317 | + res = rename_dynamic_dimensions( |
| 318 | + {"a": {"B", "C"}}, |
| 319 | + { |
| 320 | + "B", |
| 321 | + }, |
| 322 | + ) |
| 323 | + self.assertEqual(res, {"B": "B", "a": "B", "C": "B"}) |
| 324 | + |
| 325 | + def test_rename_dynamic_expression(self): |
| 326 | + text = rename_dynamic_expression("a * 10 - a", {"a": "x"}) |
| 327 | + self.assertEqual(text, "x * 10 - x") |
| 328 | + |
244 | 329 |
|
245 | 330 | if __name__ == "__main__": |
246 | 331 | unittest.main(verbosity=2) |
0 commit comments