|
| 1 | +import io |
| 2 | +import sys |
1 | 3 | import unittest |
2 | 4 | import decimal |
| 5 | +import simplejson as json |
| 6 | +import singer.messages as messages |
3 | 7 | from singer import transform |
4 | 8 | from singer.transform import * |
5 | 9 |
|
6 | | - |
7 | 10 | class TestTransform(unittest.TestCase): |
8 | 11 | def test_integer_transform(self): |
9 | 12 | schema = {'type': 'integer'} |
@@ -486,3 +489,58 @@ def test_pattern_properties_match_multiple(self): |
486 | 489 | dict_value = {"name": "chicken", "unit_cost": 1.45, "SKU": '123456'} |
487 | 490 | expected = dict(dict_value) |
488 | 491 | self.assertEqual(expected, transform(dict_value, schema)) |
| 492 | + |
| 493 | +class DummyMessage: |
| 494 | + """A dummy message object with an asdict() method.""" |
| 495 | + def __init__(self, value): |
| 496 | + self.value = value |
| 497 | + |
| 498 | + def asdict(self): |
| 499 | + return {"value": self.value} |
| 500 | + |
| 501 | + |
| 502 | +class TestAllowNan(unittest.TestCase): |
| 503 | + """Unit tests for allow_nan support in singer.messages.""" |
| 504 | + |
| 505 | + def test_format_message_allow_nan_true(self): |
| 506 | + """Should serialize NaN successfully when allow_nan=True.""" |
| 507 | + msg = DummyMessage(float("nan")) |
| 508 | + result = messages.format_message(msg, allow_nan=True) |
| 509 | + |
| 510 | + # The output JSON should contain NaN literal (not quoted) |
| 511 | + self.assertIn("NaN", result) |
| 512 | + |
| 513 | + # Replace NaN with null to make it valid JSON for parsing check |
| 514 | + json.loads(result.replace("NaN", "null")) |
| 515 | + |
| 516 | + def test_format_message_allow_nan_false(self): |
| 517 | + """Should raise ValueError when allow_nan=False and value is NaN.""" |
| 518 | + msg = DummyMessage(float("nan")) |
| 519 | + with self.assertRaises(ValueError): |
| 520 | + messages.format_message(msg, allow_nan=False) |
| 521 | + |
| 522 | + def test_write_message_allow_nan_true(self): |
| 523 | + """Should write to stdout successfully when allow_nan=True.""" |
| 524 | + msg = DummyMessage(float("nan")) |
| 525 | + fake_stdout = io.StringIO() |
| 526 | + original_stdout = sys.stdout |
| 527 | + sys.stdout = fake_stdout |
| 528 | + try: |
| 529 | + messages.write_message(msg, allow_nan=True) |
| 530 | + output = fake_stdout.getvalue() |
| 531 | + self.assertIn("NaN", output) |
| 532 | + self.assertTrue(output.endswith("\n")) |
| 533 | + finally: |
| 534 | + sys.stdout = original_stdout |
| 535 | + |
| 536 | + def test_write_message_allow_nan_false(self): |
| 537 | + """Should raise ValueError when allow_nan=False and message has NaN.""" |
| 538 | + msg = DummyMessage(float("nan")) |
| 539 | + fake_stdout = io.StringIO() |
| 540 | + original_stdout = sys.stdout |
| 541 | + sys.stdout = fake_stdout |
| 542 | + try: |
| 543 | + with self.assertRaises(ValueError): |
| 544 | + messages.write_message(msg, allow_nan=False) |
| 545 | + finally: |
| 546 | + sys.stdout = original_stdout |
0 commit comments