|
1 | 1 | import time |
2 | 2 | from random import randint |
3 | 3 | from unittest import TestCase |
| 4 | +from unittest.mock import patch |
4 | 5 |
|
5 | 6 | from ._catches import CACHES |
6 | 7 |
|
@@ -39,22 +40,57 @@ def echo3(x): |
39 | 40 | echo3(val3) |
40 | 41 |
|
41 | 42 | # 验证缓存命中 |
42 | | - self.assertEqual(echo1(val1), val1) |
43 | | - self.assertEqual(echo2(val2), val2) |
44 | | - self.assertEqual(echo3(val3), val3) |
| 43 | + with patch.object(cache, "get", return_value=cache.serialize(val1)) as mock_get: |
| 44 | + with patch.object(cache, "put") as mock_put: |
| 45 | + self.assertEqual(echo1(val1), val1) |
| 46 | + mock_get.assert_called_once() |
| 47 | + mock_put.assert_not_called() |
| 48 | + |
| 49 | + with patch.object(cache, "get", return_value=cache.serialize(val2)) as mock_get: |
| 50 | + with patch.object(cache, "put") as mock_put: |
| 51 | + self.assertEqual(echo2(val2), val2) |
| 52 | + mock_get.assert_called_once() |
| 53 | + mock_put.assert_not_called() |
| 54 | + |
| 55 | + with patch.object(cache, "get", return_value=cache.serialize(val3)) as mock_get: |
| 56 | + with patch.object(cache, "put") as mock_put: |
| 57 | + self.assertEqual(echo3(val3), val3) |
| 58 | + mock_get.assert_called_once() |
| 59 | + mock_put.assert_not_called() |
45 | 60 |
|
46 | 61 | # 等待超过最小 TTL 时间 |
47 | 62 | time.sleep(min(ttl_values) + 1) |
48 | 63 |
|
49 | 64 | # 验证已过期的缓存 |
50 | | - self.assertNotEqual(echo1(val1), val1) |
51 | | - self.assertEqual(echo2(val2), val2) # 这个应该还未过期 |
52 | | - self.assertEqual(echo3(val3), val3) # 这个应该还未过期 |
| 65 | + with patch.object(cache, "put") as mock_put: |
| 66 | + self.assertEqual(echo1(val1), val1) # 应该触发重新计算 |
| 67 | + mock_put.assert_called_once() # 确认缓存已过期并重新写入 |
| 68 | + |
| 69 | + # 验证其他缓存是否未过期 |
| 70 | + with patch.object(cache, "get", return_value=cache.serialize(val2)) as mock_get: |
| 71 | + with patch.object(cache, "put") as mock_put: |
| 72 | + self.assertEqual(echo2(val2), val2) # 这个应该还未过期 |
| 73 | + mock_get.assert_called_once() |
| 74 | + mock_put.assert_not_called() |
| 75 | + |
| 76 | + with patch.object(cache, "get", return_value=cache.serialize(val3)) as mock_get: |
| 77 | + with patch.object(cache, "put") as mock_put: |
| 78 | + self.assertEqual(echo3(val3), val3) # 这个应该还未过期 |
| 79 | + mock_get.assert_called_once() |
| 80 | + mock_put.assert_not_called() |
53 | 81 |
|
54 | 82 | # 等待超过所有 TTL 时间 |
55 | 83 | time.sleep(max(ttl_values) + 1) |
56 | 84 |
|
57 | 85 | # 验证所有缓存都已过期 |
58 | | - self.assertNotEqual(echo1(val1), val1) |
59 | | - self.assertNotEqual(echo2(val2), val2) |
60 | | - self.assertNotEqual(echo3(val3), val3) |
| 86 | + with patch.object(cache, "put") as mock_put: |
| 87 | + self.assertEqual(echo1(val1), val1) # 应该触发重新计算 |
| 88 | + mock_put.assert_called_once() # 确认缓存已过期并重新写入 |
| 89 | + |
| 90 | + with patch.object(cache, "put") as mock_put: |
| 91 | + self.assertEqual(echo2(val2), val2) # 应该触发重新计算 |
| 92 | + mock_put.assert_called_once() |
| 93 | + |
| 94 | + with patch.object(cache, "put") as mock_put: |
| 95 | + self.assertEqual(echo3(val3), val3) # 应该触发重新计算 |
| 96 | + mock_put.assert_called_once() |
0 commit comments