diff --git a/test/unit/aio/test_proxies_async.py b/test/unit/aio/test_proxies_async.py index 23c73e4f8..e41f73e65 100644 --- a/test/unit/aio/test_proxies_async.py +++ b/test/unit/aio/test_proxies_async.py @@ -85,24 +85,11 @@ async def test_basic_query_through_proxy_async( finally: await conn.close() - async with aiohttp.ClientSession() as session: - async with session.get( - f"{proxy_wm.http_host_with_port}/__admin/requests" - ) as resp: - proxy_reqs = await resp.json() - assert any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) + # Ensure proxy saw query + assert proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) - async with session.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ) as resp: - target_reqs = await resp.json() - assert any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + # Ensure backend saw query + assert target_wm.saw_urls_matching(["/queries/v1/query-request"]) @pytest.mark.skipolddriver @@ -165,49 +152,37 @@ async def test_large_query_through_proxy_async( async def _execute_large_query(connect_kwargs, row_count: int): + """Execute a large query using connection kwargs. + + Creates a connection, executes the large query, and validates it uses multiple batches. + """ conn = await async_connect(**connect_kwargs) try: cur = conn.cursor() - await cur.execute( - f"select seq4() as n from table(generator(rowcount => {row_count}));" - ) + await _execute_large_query_on_cursor(cur, row_count) + # Verify that the query used multiple batches (remote storage) assert len(cur._result_set.batches) > 1 - _ = [r async for r in cur] finally: await conn.close() -async def _collect_request_flags(proxy_wm, target_wm, storage_wm) -> RequestFlags: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{proxy_wm.http_host_with_port}/__admin/requests" - ) as resp: - proxy_reqs = await resp.json() - async with session.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ) as resp: - target_reqs = await resp.json() - async with session.get( - f"{storage_wm.http_host_with_port}/__admin/requests" - ) as resp: - storage_reqs = await resp.json() - - proxy_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] +async def _execute_large_query_on_cursor(cursor, row_count: int = 100000): + await cursor.execute( + f"SELECT seq4() as n FROM TABLE(GENERATOR(ROWCOUNT => {row_count}))" ) - target_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) - proxy_saw_storage = any( - "/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"] - for r in proxy_reqs["requests"] + return [r async for r in cursor] + + +async def _collect_request_flags(proxy_wm, target_wm, storage_wm) -> RequestFlags: + proxy_saw_db = proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) + target_saw_db = target_wm.saw_urls_matching(["/queries/v1/query-request"]) + proxy_saw_storage = proxy_wm.saw_urls_matching( + ["/amazonaws/test/s3testaccount/stage/results/"] ) - storage_saw_storage = any( - "/amazonaws/test/s3testaccount/stage/results/" in r["request"]["url"] - for r in storage_reqs["requests"] + storage_saw_storage = storage_wm.saw_urls_matching( + ["/amazonaws/test/s3testaccount/stage/results/"] ) + return RequestFlags( proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db, @@ -217,56 +192,22 @@ async def _collect_request_flags(proxy_wm, target_wm, storage_wm) -> RequestFlag async def _collect_db_request_flags_only(proxy_wm, target_wm) -> DbRequestFlags: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{proxy_wm.http_host_with_port}/__admin/requests" - ) as resp: - proxy_reqs = await resp.json() - async with session.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ) as resp: - target_reqs = await resp.json() - proxy_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy_reqs["requests"] - ) - target_saw_db = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + proxy_saw_db = proxy_wm.saw_urls_matching(["/queries/v1/query-request"]) + target_saw_db = target_wm.saw_urls_matching(["/queries/v1/query-request"]) return DbRequestFlags(proxy_saw_db=proxy_saw_db, target_saw_db=target_saw_db) async def _collect_proxy_precedence_flags( proxy1_wm, proxy2_wm, target_wm ) -> ProxyPrecedenceFlags: - """Async version of proxy precedence flags collection using aiohttp.""" - async with aiohttp.ClientSession() as session: - async with session.get( - f"{proxy1_wm.http_host_with_port}/__admin/requests" - ) as resp: - proxy1_reqs = await resp.json() - async with session.get( - f"{proxy2_wm.http_host_with_port}/__admin/requests" - ) as resp: - proxy2_reqs = await resp.json() - async with session.get( - f"{target_wm.http_host_with_port}/__admin/requests" - ) as resp: - target_reqs = await resp.json() - - proxy1_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy1_reqs["requests"] - ) - proxy2_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in proxy2_reqs["requests"] - ) - backend_saw_request = any( - "/queries/v1/query-request" in r["request"]["url"] - for r in target_reqs["requests"] - ) + """Collect flags for proxy precedence tests. + + Checks which proxy (or target) saw query requests, useful for verifying + that connection parameters take precedence over environment variables. + """ + proxy1_saw_request = proxy1_wm.saw_urls_matching(["/queries/v1/query-request"]) + proxy2_saw_request = proxy2_wm.saw_urls_matching(["/queries/v1/query-request"]) + backend_saw_request = target_wm.saw_urls_matching(["/queries/v1/query-request"]) return ProxyPrecedenceFlags( proxy1_saw_request=proxy1_saw_request,