|
1 | 1 | # flake8: noqa |
2 | 2 | import pytest |
3 | 3 | import sys |
4 | | -from unittest.mock import MagicMock |
| 4 | +from unittest.mock import MagicMock, call |
| 5 | +from sqlite_utils.utils import sqlite3 |
5 | 6 |
|
6 | 7 |
|
7 | 8 | def test_register_function(fresh_db): |
@@ -31,36 +32,41 @@ def to_lower(s): |
31 | 32 | assert result == "bob" |
32 | 33 |
|
33 | 34 |
|
34 | | -@pytest.mark.skipif( |
35 | | - sys.version_info < (3, 8), reason="deterministic=True was added in Python 3.8" |
36 | | -) |
37 | | -@pytest.mark.parametrize( |
38 | | - "fake_sqlite_version,should_use_deterministic", |
39 | | - ( |
40 | | - ("3.36.0", True), |
41 | | - ("3.8.3", True), |
42 | | - ("3.8.2", False), |
43 | | - ), |
44 | | -) |
45 | | -def test_register_function_deterministic_registered( |
46 | | - fresh_db, fake_sqlite_version, should_use_deterministic |
47 | | -): |
| 35 | +def test_register_function_deterministic_tries_again_if_exception_raised(fresh_db): |
48 | 36 | fresh_db.conn = MagicMock() |
49 | 37 | fresh_db.conn.create_function = MagicMock() |
50 | | - fresh_db.conn.execute().fetchall.return_value = [(fake_sqlite_version,)] |
51 | 38 |
|
52 | 39 | @fresh_db.register_function(deterministic=True) |
53 | 40 | def to_lower_2(s): |
54 | 41 | return s.lower() |
55 | 42 |
|
56 | | - expected_kwargs = {} |
57 | | - if should_use_deterministic: |
58 | | - expected_kwargs = dict(deterministic=True) |
59 | | - |
60 | 43 | fresh_db.conn.create_function.assert_called_with( |
61 | | - "to_lower_2", 1, to_lower_2, **expected_kwargs |
| 44 | + "to_lower_2", 1, to_lower_2, deterministic=True |
62 | 45 | ) |
63 | 46 |
|
| 47 | + first = True |
| 48 | + |
| 49 | + def side_effect(*args, **kwargs): |
| 50 | + # Raise exception only first time this is called |
| 51 | + nonlocal first |
| 52 | + if first: |
| 53 | + first = False |
| 54 | + raise sqlite3.NotSupportedError() |
| 55 | + |
| 56 | + # But if sqlite3.NotSupportedError is raised, it tries again |
| 57 | + fresh_db.conn.create_function.reset_mock() |
| 58 | + fresh_db.conn.create_function.side_effect = side_effect |
| 59 | + |
| 60 | + @fresh_db.register_function(deterministic=True) |
| 61 | + def to_lower_3(s): |
| 62 | + return s.lower() |
| 63 | + |
| 64 | + # Should have been called once with deterministic=True and once without |
| 65 | + assert fresh_db.conn.create_function.call_args_list == [ |
| 66 | + call("to_lower_3", 1, to_lower_3, deterministic=True), |
| 67 | + call("to_lower_3", 1, to_lower_3), |
| 68 | + ] |
| 69 | + |
64 | 70 |
|
65 | 71 | def test_register_function_replace(fresh_db): |
66 | 72 | @fresh_db.register_function() |
|
0 commit comments