| 
5 | 5 | import tempfile  | 
6 | 6 | import os  | 
7 | 7 | import json  | 
 | 8 | +import asyncio  | 
 | 9 | +from unittest.mock import patch, AsyncMock, MagicMock  | 
8 | 10 | 
 
  | 
9 | 11 | import tests.auth.test_credentials  | 
10 | 12 | import tests.oauth2_token_exchange  | 
@@ -112,3 +114,152 @@ def serve(s):  | 
112 | 114 |     except Exception:  | 
113 | 115 |         os.remove(cfg_file_name)  | 
114 | 116 |         raise  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +@pytest.mark.asyncio  | 
 | 120 | +async def test_token_lazy_refresh():  | 
 | 121 | +    credentials = ServiceAccountCredentialsForTest(  | 
 | 122 | +        tests.auth.test_credentials.SERVICE_ACCOUNT_ID,  | 
 | 123 | +        tests.auth.test_credentials.ACCESS_KEY_ID,  | 
 | 124 | +        tests.auth.test_credentials.PRIVATE_KEY,  | 
 | 125 | +        "localhost:0",  | 
 | 126 | +    )  | 
 | 127 | + | 
 | 128 | +    credentials._tp.submit = MagicMock()  | 
 | 129 | + | 
 | 130 | +    mock_response = {"access_token": "token_v1", "expires_in": 3600}  | 
 | 131 | +    credentials._make_token_request = AsyncMock(return_value=mock_response)  | 
 | 132 | + | 
 | 133 | +    with patch("time.time") as mock_time:  | 
 | 134 | +        mock_time.return_value = 1000  | 
 | 135 | + | 
 | 136 | +        token1 = await credentials.token()  | 
 | 137 | +        assert token1 == "token_v1"  | 
 | 138 | +        assert credentials._make_token_request.call_count == 1  | 
 | 139 | + | 
 | 140 | +        token2 = await credentials.token()  | 
 | 141 | +        assert token2 == "token_v1"  | 
 | 142 | +        assert credentials._make_token_request.call_count == 1  | 
 | 143 | + | 
 | 144 | +        mock_time.return_value = 1000 + 3600 - 30 + 1  | 
 | 145 | +        credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}  | 
 | 146 | + | 
 | 147 | +        token3 = await credentials.token()  | 
 | 148 | +        assert token3 == "token_v2"  | 
 | 149 | +        assert credentials._make_token_request.call_count == 2  | 
 | 150 | + | 
 | 151 | + | 
 | 152 | +@pytest.mark.asyncio  | 
 | 153 | +async def test_token_double_check_locking():  | 
 | 154 | +    credentials = ServiceAccountCredentialsForTest(  | 
 | 155 | +        tests.auth.test_credentials.SERVICE_ACCOUNT_ID,  | 
 | 156 | +        tests.auth.test_credentials.ACCESS_KEY_ID,  | 
 | 157 | +        tests.auth.test_credentials.PRIVATE_KEY,  | 
 | 158 | +        "localhost:0",  | 
 | 159 | +    )  | 
 | 160 | + | 
 | 161 | +    credentials._tp.submit = MagicMock()  | 
 | 162 | + | 
 | 163 | +    call_count = 0  | 
 | 164 | + | 
 | 165 | +    async def mock_make_request():  | 
 | 166 | +        nonlocal call_count  | 
 | 167 | +        call_count += 1  | 
 | 168 | +        await asyncio.sleep(0.01)  | 
 | 169 | +        return {"access_token": f"token_v{call_count}", "expires_in": 3600}  | 
 | 170 | + | 
 | 171 | +    credentials._make_token_request = mock_make_request  | 
 | 172 | + | 
 | 173 | +    with patch("time.time") as mock_time:  | 
 | 174 | +        mock_time.return_value = 1000  | 
 | 175 | + | 
 | 176 | +        tasks = [credentials.token() for _ in range(10)]  | 
 | 177 | +        results = await asyncio.gather(*tasks)  | 
 | 178 | + | 
 | 179 | +        assert len(set(results)) == 1  | 
 | 180 | +        assert call_count == 1  | 
 | 181 | + | 
 | 182 | + | 
 | 183 | +@pytest.mark.asyncio  | 
 | 184 | +async def test_token_expiration_calculation():  | 
 | 185 | +    credentials = ServiceAccountCredentialsForTest(  | 
 | 186 | +        tests.auth.test_credentials.SERVICE_ACCOUNT_ID,  | 
 | 187 | +        tests.auth.test_credentials.ACCESS_KEY_ID,  | 
 | 188 | +        tests.auth.test_credentials.PRIVATE_KEY,  | 
 | 189 | +        "localhost:0",  | 
 | 190 | +    )  | 
 | 191 | + | 
 | 192 | +    credentials._tp.submit = MagicMock()  | 
 | 193 | + | 
 | 194 | +    with patch("time.time") as mock_time:  | 
 | 195 | +        mock_time.return_value = 1000  | 
 | 196 | + | 
 | 197 | +        credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})  | 
 | 198 | + | 
 | 199 | +        await credentials.token()  | 
 | 200 | + | 
 | 201 | +        expected_expires = 1000 + 3600 - 30  | 
 | 202 | +        assert credentials._expires_in == expected_expires  | 
 | 203 | + | 
 | 204 | + | 
 | 205 | +@pytest.mark.asyncio  | 
 | 206 | +async def test_token_refresh_error_handling():  | 
 | 207 | +    credentials = ServiceAccountCredentialsForTest(  | 
 | 208 | +        tests.auth.test_credentials.SERVICE_ACCOUNT_ID,  | 
 | 209 | +        tests.auth.test_credentials.ACCESS_KEY_ID,  | 
 | 210 | +        tests.auth.test_credentials.PRIVATE_KEY,  | 
 | 211 | +        "localhost:0",  | 
 | 212 | +    )  | 
 | 213 | + | 
 | 214 | +    credentials._tp.submit = MagicMock()  | 
 | 215 | + | 
 | 216 | +    credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))  | 
 | 217 | + | 
 | 218 | +    with pytest.raises(Exception) as exc_info:  | 
 | 219 | +        await credentials.token()  | 
 | 220 | + | 
 | 221 | +    assert "Network error" in str(exc_info.value)  | 
 | 222 | +    assert credentials.last_error == "Network error"  | 
 | 223 | + | 
 | 224 | + | 
 | 225 | +@pytest.mark.asyncio  | 
 | 226 | +async def test_hybrid_background_and_sync_refresh():  | 
 | 227 | +    credentials = ServiceAccountCredentialsForTest(  | 
 | 228 | +        tests.auth.test_credentials.SERVICE_ACCOUNT_ID,  | 
 | 229 | +        tests.auth.test_credentials.ACCESS_KEY_ID,  | 
 | 230 | +        tests.auth.test_credentials.PRIVATE_KEY,  | 
 | 231 | +        "localhost:0",  | 
 | 232 | +    )  | 
 | 233 | + | 
 | 234 | +    call_count = 0  | 
 | 235 | +    background_calls = []  | 
 | 236 | + | 
 | 237 | +    async def mock_make_request():  | 
 | 238 | +        nonlocal call_count  | 
 | 239 | +        call_count += 1  | 
 | 240 | +        return {"access_token": f"token_v{call_count}", "expires_in": 3600}  | 
 | 241 | + | 
 | 242 | +    def mock_submit(callback):  | 
 | 243 | +        background_calls.append(callback)  | 
 | 244 | + | 
 | 245 | +    credentials._make_token_request = mock_make_request  | 
 | 246 | +    credentials._tp.submit = mock_submit  | 
 | 247 | + | 
 | 248 | +    with patch("time.time") as mock_time:  | 
 | 249 | +        mock_time.return_value = 1000  | 
 | 250 | + | 
 | 251 | +        token1 = await credentials.token()  | 
 | 252 | +        assert token1 == "token_v1"  | 
 | 253 | +        assert call_count == 1  | 
 | 254 | +        assert len(background_calls) == 0  | 
 | 255 | + | 
 | 256 | +        mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1  | 
 | 257 | +        token2 = await credentials.token()  | 
 | 258 | +        assert token2 == "token_v1"  | 
 | 259 | +        assert call_count == 1  | 
 | 260 | +        assert len(background_calls) == 1  | 
 | 261 | + | 
 | 262 | +        mock_time.return_value = 1000 + 3600 - 30 + 1  | 
 | 263 | +        token3 = await credentials.token()  | 
 | 264 | +        assert token3 == "token_v2"  | 
 | 265 | +        assert call_count == 2  | 
0 commit comments