Skip to content

Commit 01ada93

Browse files
[async] Applied #2568 session manager implementation - definitions, usages, check_no_native_http.py
1 parent 17fbc24 commit 01ada93

File tree

9 files changed

+317
-53
lines changed

9 files changed

+317
-53
lines changed

ci/pre-commit/check_no_native_http.py

Lines changed: 173 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
"""
3-
Pre-commit hook to prevent direct usage of requests and urllib3 calls.
3+
Pre-commit hook to prevent direct usage of requests, urllib3, and aiohttp calls.
44
Ensures all HTTP requests go through SessionManager.
55
"""
66
import argparse
@@ -24,6 +24,9 @@ class ViolationType(Enum):
2424
DIRECT_SESSION_IMPORT = "SNOW008"
2525
STAR_IMPORT = "SNOW010"
2626
URLLIB3_DIRECT_API = "SNOW011"
27+
AIOHTTP_CLIENT_SESSION = "SNOW012"
28+
AIOHTTP_REQUEST = "SNOW013"
29+
DIRECT_AIOHTTP_IMPORT = "SNOW014"
2730

2831

2932
@dataclass(frozen=True)
@@ -57,6 +60,7 @@ class ModulePattern:
5760
# Core module names
5861
REQUESTS_MODULES = {"requests"}
5962
URLLIB3_MODULES = {"urllib3"}
63+
AIOHTTP_MODULES = {"aiohttp"}
6064

6165
# HTTP-related symbols
6266
HTTP_METHODS = {
@@ -71,6 +75,8 @@ class ModulePattern:
7175
}
7276
POOL_MANAGERS = {"PoolManager", "ProxyManager"}
7377
URLLIB3_APIS = {"request", "urlopen", "HTTPConnectionPool", "HTTPSConnectionPool"}
78+
AIOHTTP_SESSIONS = {"ClientSession"}
79+
AIOHTTP_APIS = {"request"}
7480

7581
@classmethod
7682
def is_requests_module(cls, module_or_symbol: str) -> bool:
@@ -112,6 +118,22 @@ def is_urllib3_module(cls, module_or_symbol: str) -> bool:
112118

113119
return False
114120

121+
@classmethod
122+
def is_aiohttp_module(cls, module_or_symbol: str) -> bool:
123+
"""Check if module or symbol is aiohttp-related."""
124+
if not module_or_symbol:
125+
return False
126+
127+
# Exact match
128+
if module_or_symbol in cls.AIOHTTP_MODULES:
129+
return True
130+
131+
# Dotted path ending in .aiohttp
132+
if module_or_symbol.endswith(".aiohttp"):
133+
return True
134+
135+
return False
136+
115137
@classmethod
116138
def is_http_method(cls, name: str) -> bool:
117139
"""Check if name is an HTTP method."""
@@ -127,6 +149,16 @@ def is_urllib3_api(cls, name: str) -> bool:
127149
"""Check if name is a urllib3 API function."""
128150
return name in cls.URLLIB3_APIS
129151

152+
@classmethod
153+
def is_aiohttp_session(cls, name: str) -> bool:
154+
"""Check if name is an aiohttp session class."""
155+
return name in cls.AIOHTTP_SESSIONS
156+
157+
@classmethod
158+
def is_aiohttp_api(cls, name: str) -> bool:
159+
"""Check if name is an aiohttp API function."""
160+
return name in cls.AIOHTTP_APIS
161+
130162

131163
class ImportContext:
132164
"""Tracks all import-related information."""
@@ -234,6 +266,29 @@ def is_urllib3_related(self, name: str) -> bool:
234266

235267
return False
236268

269+
def is_aiohttp_related(self, name: str) -> bool:
270+
"""Check if name refers to aiohttp module or its components."""
271+
resolved_name = self.resolve_name(name)
272+
273+
# Direct aiohttp module
274+
if resolved_name == "aiohttp":
275+
return True
276+
277+
# Check import info
278+
if resolved_name in self.imports:
279+
import_info = self.imports[resolved_name]
280+
return ModulePattern.is_aiohttp_module(import_info.module) or (
281+
import_info.imported_name
282+
and ModulePattern.is_aiohttp_module(import_info.imported_name)
283+
)
284+
285+
# Check star imports
286+
for module in self.star_imports:
287+
if ModulePattern.is_aiohttp_module(module):
288+
return True
289+
290+
return False
291+
237292
def is_runtime(self, name: str) -> bool:
238293
"""Check if name is used at runtime (has actual runtime usage)."""
239294
return (
@@ -380,10 +435,12 @@ def visit_Assign(self, node: ast.Assign):
380435
else:
381436
# Handle v = snowflake.connector.vendored.requests
382437
full_path = ".".join(dotted_chain)
383-
# Check if this points to a requests or urllib3 module
384-
if ModulePattern.is_requests_module(
385-
full_path
386-
) or ModulePattern.is_urllib3_module(full_path):
438+
# Check if this points to a requests, urllib3, or aiohttp module
439+
if (
440+
ModulePattern.is_requests_module(full_path)
441+
or ModulePattern.is_urllib3_module(full_path)
442+
or ModulePattern.is_aiohttp_module(full_path)
443+
):
387444
self.context.add_variable_alias(var_name, full_path)
388445

389446
# Handle attribute assignments: self.attr = value
@@ -484,7 +541,7 @@ def _extract_from_string_annotation(self, annotation_str: str):
484541
# Match Python identifiers that could be type names
485542
names = re.findall(r"\b([A-Z][a-zA-Z0-9_]*)\b", annotation_str)
486543
for name in names:
487-
if name in ["Session", "PoolManager", "ProxyManager"]:
544+
if name in ["Session", "PoolManager", "ProxyManager", "ClientSession"]:
488545
self.context.add_type_hint_usage(name)
489546

490547
def _extract_from_subscript(self, node: ast.Subscript):
@@ -535,9 +592,11 @@ def analyze_calls(self, tree: ast.AST):
535592
def analyze_star_imports(self):
536593
"""Analyze star import violations."""
537594
for module in self.context.star_imports:
538-
if ModulePattern.is_requests_module(
539-
module
540-
) or ModulePattern.is_urllib3_module(module):
595+
if (
596+
ModulePattern.is_requests_module(module)
597+
or ModulePattern.is_urllib3_module(module)
598+
or ModulePattern.is_aiohttp_module(module)
599+
):
541600
self.violations.append(
542601
HTTPViolation(
543602
self.filename,
@@ -552,7 +611,7 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
552611
"""Check a single import for violations."""
553612
violations = []
554613

555-
# Always flag HTTP method imports
614+
# Always flag HTTP method imports from requests
556615
if (
557616
import_info.imported_name
558617
and ModulePattern.is_requests_module(import_info.module)
@@ -568,7 +627,7 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
568627
)
569628
)
570629

571-
# Flag Session/PoolManager imports only if used at runtime
630+
# Flag Session/PoolManager/ClientSession imports only if used at runtime
572631
if import_info.imported_name and self.context.is_runtime(
573632
import_info.alias_name
574633
):
@@ -600,6 +659,19 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
600659
)
601660
)
602661

662+
elif ModulePattern.is_aiohttp_module(
663+
import_info.module
664+
) and ModulePattern.is_aiohttp_session(import_info.imported_name):
665+
violations.append(
666+
HTTPViolation(
667+
self.filename,
668+
import_info.line,
669+
import_info.col,
670+
ViolationType.DIRECT_AIOHTTP_IMPORT,
671+
f"Direct import of {import_info.imported_name} from aiohttp for runtime use is forbidden, use SessionManager instead",
672+
)
673+
)
674+
603675
return violations
604676

605677

@@ -671,7 +743,7 @@ def _check_direct_call(self, node: ast.Call) -> Optional[HTTPViolation]:
671743
f"Direct use of imported {import_info.imported_name}() is forbidden, use SessionManager instead",
672744
)
673745

674-
# Session/PoolManager instantiation
746+
# Session/PoolManager/ClientSession instantiation
675747
if (
676748
import_info.imported_name == "Session"
677749
and ModulePattern.is_requests_module(import_info.module)
@@ -697,6 +769,19 @@ def _check_direct_call(self, node: ast.Call) -> Optional[HTTPViolation]:
697769
f"Direct use of imported {import_info.imported_name}() is forbidden, use SessionManager instead",
698770
)
699771

772+
if (
773+
import_info.imported_name
774+
and ModulePattern.is_aiohttp_session(import_info.imported_name)
775+
and ModulePattern.is_aiohttp_module(import_info.module)
776+
):
777+
return HTTPViolation(
778+
self.filename,
779+
node.lineno,
780+
node.col_offset,
781+
ViolationType.AIOHTTP_CLIENT_SESSION,
782+
f"Direct use of imported {import_info.imported_name}() is forbidden, use SessionManager instead",
783+
)
784+
700785
# Check star imports
701786
for module in self.context.star_imports:
702787
if ModulePattern.is_requests_module(
@@ -719,7 +804,7 @@ def _is_chained_call(self, node: ast.Call) -> bool:
719804
)
720805

721806
def _check_chained_calls(self, node: ast.Call) -> Optional[HTTPViolation]:
722-
"""Check for chained calls like requests.Session().get() or urllib3.PoolManager().request()."""
807+
"""Check for chained calls like requests.Session().get(), urllib3.PoolManager().request(), or aiohttp.ClientSession().get()."""
723808
if isinstance(node.func, ast.Attribute) and isinstance(
724809
node.func.value, ast.Call
725810
):
@@ -762,6 +847,23 @@ def _check_chained_calls(self, node: ast.Call) -> Optional[HTTPViolation]:
762847
f"Chained call urllib3.{inner_func}().{outer_method}() is forbidden, use SessionManager instead",
763848
)
764849

850+
# Check for aiohttp.ClientSession().method()
851+
if (
852+
(
853+
inner_module == "aiohttp"
854+
or self.context.is_aiohttp_related(inner_module)
855+
)
856+
and ModulePattern.is_aiohttp_session(inner_func)
857+
and ModulePattern.is_http_method(outer_method)
858+
):
859+
return HTTPViolation(
860+
self.filename,
861+
node.lineno,
862+
node.col_offset,
863+
ViolationType.AIOHTTP_CLIENT_SESSION,
864+
f"Chained call aiohttp.{inner_func}().{outer_method}() is forbidden, use SessionManager instead",
865+
)
866+
765867
return None
766868

767869
def _check_two_part_call(
@@ -780,6 +882,10 @@ def _check_two_part_call(
780882
resolved_module
781883
):
782884
return self._check_urllib3_call(node, func_name)
885+
elif module_name == "aiohttp" or self.context.is_aiohttp_related(
886+
resolved_module
887+
):
888+
return self._check_aiohttp_call(node, func_name)
783889

784890
# Check for aliased module calls (e.g., v = vendored.requests; v.get())
785891
if module_name in self.context.variable_aliases:
@@ -788,13 +894,15 @@ def _check_two_part_call(
788894
return self._check_requests_call(node, func_name)
789895
elif ModulePattern.is_urllib3_module(aliased_module):
790896
return self._check_urllib3_call(node, func_name)
897+
elif ModulePattern.is_aiohttp_module(aliased_module):
898+
return self._check_aiohttp_call(node, func_name)
791899

792900
return None
793901

794902
def _check_multi_part_call(
795903
self, node: ast.Call, chain: List[str]
796904
) -> Optional[HTTPViolation]:
797-
"""Check multi-part calls like requests.sessions.Session or self.req_lib.get."""
905+
"""Check multi-part calls like requests.sessions.Session, aiohttp.client.ClientSession or self.req_lib.get."""
798906
if len(chain) >= 3:
799907
module_name = chain[0]
800908

@@ -820,6 +928,20 @@ def _check_multi_part_call(
820928
f"Direct use of {'.'.join(chain)}() is forbidden, use SessionManager instead",
821929
)
822930

931+
elif module_name == "aiohttp" or self.context.is_aiohttp_related(
932+
module_name
933+
):
934+
# aiohttp.client.ClientSession, etc.
935+
func_name = chain[-1]
936+
if ModulePattern.is_aiohttp_session(func_name):
937+
return HTTPViolation(
938+
self.filename,
939+
node.lineno,
940+
node.col_offset,
941+
ViolationType.AIOHTTP_CLIENT_SESSION,
942+
f"Direct use of {'.'.join(chain)}() is forbidden, use SessionManager instead",
943+
)
944+
823945
# Check for aliased calls like self.req_lib.get() where req_lib is an alias
824946
elif len(chain) >= 3:
825947
# For patterns like self.req_lib.get(), check if req_lib is an alias
@@ -848,6 +970,16 @@ def _check_multi_part_call(
848970
ViolationType.URLLIB3_POOLMANAGER,
849971
f"Direct use of aliased {chain[0]}.{potential_alias}.{func_name}() is forbidden, use SessionManager instead",
850972
)
973+
elif ModulePattern.is_aiohttp_module(
974+
aliased_module
975+
) and ModulePattern.is_aiohttp_session(func_name):
976+
return HTTPViolation(
977+
self.filename,
978+
node.lineno,
979+
node.col_offset,
980+
ViolationType.AIOHTTP_CLIENT_SESSION,
981+
f"Direct use of aliased {chain[0]}.{potential_alias}.{func_name}() is forbidden, use SessionManager instead",
982+
)
851983

852984
return None
853985

@@ -903,12 +1035,35 @@ def _check_urllib3_call(
9031035
)
9041036
return None
9051037

1038+
def _check_aiohttp_call(
1039+
self, node: ast.Call, func_name: str
1040+
) -> Optional[HTTPViolation]:
1041+
"""Check aiohttp module calls."""
1042+
if ModulePattern.is_aiohttp_session(func_name):
1043+
return HTTPViolation(
1044+
self.filename,
1045+
node.lineno,
1046+
node.col_offset,
1047+
ViolationType.AIOHTTP_CLIENT_SESSION,
1048+
f"Direct use of aiohttp.{func_name}() is forbidden, use SessionManager instead",
1049+
)
1050+
elif ModulePattern.is_aiohttp_api(func_name):
1051+
return HTTPViolation(
1052+
self.filename,
1053+
node.lineno,
1054+
node.col_offset,
1055+
ViolationType.AIOHTTP_REQUEST,
1056+
f"Direct use of aiohttp.{func_name}() is forbidden, use SessionManager instead",
1057+
)
1058+
return None
1059+
9061060

9071061
class FileChecker:
9081062
"""Handles file-level checking logic with proper glob path matching."""
9091063

9101064
EXEMPT_PATTERNS = [
9111065
"**/session_manager.py",
1066+
"**/_session_manager.py",
9121067
"**/vendored/**/*",
9131068
]
9141069

@@ -1043,8 +1198,11 @@ def main():
10431198
print(
10441199
" - Replace urllib3.PoolManager/ProxyManager() with session from session_manager.use_session()"
10451200
)
1201+
print(
1202+
" - Replace aiohttp.ClientSession() with async SessionManager.use_session()"
1203+
)
10461204
print(" - Replace direct HTTP method imports with SessionManager usage")
1047-
print(" - Use SessionManager for all HTTP operations")
1205+
print(" - Use SessionManager for all HTTP operations (sync and async)")
10481206

10491207
print()
10501208
print(f"Found {len(all_violations)} violation(s)")

src/snowflake/connector/aio/_network.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def add_retry_params(self, full_url: str) -> str:
567567
include_retry_reason = self._connection._enable_retry_reason_in_query_response
568568
include_retry_params = kwargs.pop("_include_retry_params", False)
569569

570-
async with self._use_requests_session(full_url) as session:
570+
async with self._use_session(full_url) as session:
571571
retry_ctx = RetryCtx(
572572
_include_retry_params=include_retry_params,
573573
_include_retry_reason=include_retry_reason,
@@ -848,12 +848,9 @@ async def _request_exec(
848848
errno=ER_FAILED_TO_REQUEST,
849849
) from err
850850

851-
def make_requests_session(self) -> aiohttp.ClientSession:
852-
return self._session_manager.make_session()
853-
854851
@contextlib.asynccontextmanager
855-
async def _use_requests_session(
852+
async def _use_session(
856853
self, url: str | None = None
857854
) -> AsyncGenerator[aiohttp.ClientSession]:
858-
async with self._session_manager.use_requests_session(url) as session:
855+
async with self._session_manager.use_session(url) as session:
859856
yield session

src/snowflake/connector/aio/_result_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ async def download_chunk(http_session):
236236
)
237237
# Try to reuse a connection if possible
238238
if connection and connection._rest is not None:
239-
async with connection._rest._use_requests_session() as session:
239+
async with connection._rest._use_session() as session:
240240
logger.debug(
241241
f"downloading result batch id: {self.id} with existing session {session}"
242242
)

0 commit comments

Comments
 (0)