11#!/usr/bin/env python3
22"""
3- Pre-commit hook to prevent direct usage of requests and urllib3 calls.
3+ Pre-commit hook to prevent direct usage of requests, urllib3, and aiohttp calls.
44Ensures all HTTP requests go through SessionManager.
55"""
66import argparse
@@ -24,6 +24,9 @@ class ViolationType(Enum):
2424 DIRECT_SESSION_IMPORT = "SNOW008"
2525 STAR_IMPORT = "SNOW010"
2626 URLLIB3_DIRECT_API = "SNOW011"
27+ AIOHTTP_CLIENT_SESSION = "SNOW012"
28+ AIOHTTP_REQUEST = "SNOW013"
29+ DIRECT_AIOHTTP_IMPORT = "SNOW014"
2730
2831
2932@dataclass (frozen = True )
@@ -57,6 +60,7 @@ class ModulePattern:
5760 # Core module names
5861 REQUESTS_MODULES = {"requests" }
5962 URLLIB3_MODULES = {"urllib3" }
63+ AIOHTTP_MODULES = {"aiohttp" }
6064
6165 # HTTP-related symbols
6266 HTTP_METHODS = {
@@ -71,6 +75,8 @@ class ModulePattern:
7175 }
7276 POOL_MANAGERS = {"PoolManager" , "ProxyManager" }
7377 URLLIB3_APIS = {"request" , "urlopen" , "HTTPConnectionPool" , "HTTPSConnectionPool" }
78+ AIOHTTP_SESSIONS = {"ClientSession" }
79+ AIOHTTP_APIS = {"request" }
7480
7581 @classmethod
7682 def is_requests_module (cls , module_or_symbol : str ) -> bool :
@@ -112,6 +118,22 @@ def is_urllib3_module(cls, module_or_symbol: str) -> bool:
112118
113119 return False
114120
121+ @classmethod
122+ def is_aiohttp_module (cls , module_or_symbol : str ) -> bool :
123+ """Check if module or symbol is aiohttp-related."""
124+ if not module_or_symbol :
125+ return False
126+
127+ # Exact match
128+ if module_or_symbol in cls .AIOHTTP_MODULES :
129+ return True
130+
131+ # Dotted path ending in .aiohttp
132+ if module_or_symbol .endswith (".aiohttp" ):
133+ return True
134+
135+ return False
136+
115137 @classmethod
116138 def is_http_method (cls , name : str ) -> bool :
117139 """Check if name is an HTTP method."""
@@ -127,6 +149,16 @@ def is_urllib3_api(cls, name: str) -> bool:
127149 """Check if name is a urllib3 API function."""
128150 return name in cls .URLLIB3_APIS
129151
152+ @classmethod
153+ def is_aiohttp_session (cls , name : str ) -> bool :
154+ """Check if name is an aiohttp session class."""
155+ return name in cls .AIOHTTP_SESSIONS
156+
157+ @classmethod
158+ def is_aiohttp_api (cls , name : str ) -> bool :
159+ """Check if name is an aiohttp API function."""
160+ return name in cls .AIOHTTP_APIS
161+
130162
131163class ImportContext :
132164 """Tracks all import-related information."""
@@ -234,6 +266,29 @@ def is_urllib3_related(self, name: str) -> bool:
234266
235267 return False
236268
269+ def is_aiohttp_related (self , name : str ) -> bool :
270+ """Check if name refers to aiohttp module or its components."""
271+ resolved_name = self .resolve_name (name )
272+
273+ # Direct aiohttp module
274+ if resolved_name == "aiohttp" :
275+ return True
276+
277+ # Check import info
278+ if resolved_name in self .imports :
279+ import_info = self .imports [resolved_name ]
280+ return ModulePattern .is_aiohttp_module (import_info .module ) or (
281+ import_info .imported_name
282+ and ModulePattern .is_aiohttp_module (import_info .imported_name )
283+ )
284+
285+ # Check star imports
286+ for module in self .star_imports :
287+ if ModulePattern .is_aiohttp_module (module ):
288+ return True
289+
290+ return False
291+
237292 def is_runtime (self , name : str ) -> bool :
238293 """Check if name is used at runtime (has actual runtime usage)."""
239294 return (
@@ -380,10 +435,12 @@ def visit_Assign(self, node: ast.Assign):
380435 else :
381436 # Handle v = snowflake.connector.vendored.requests
382437 full_path = "." .join (dotted_chain )
383- # Check if this points to a requests or urllib3 module
384- if ModulePattern .is_requests_module (
385- full_path
386- ) or ModulePattern .is_urllib3_module (full_path ):
438+ # Check if this points to a requests, urllib3, or aiohttp module
439+ if (
440+ ModulePattern .is_requests_module (full_path )
441+ or ModulePattern .is_urllib3_module (full_path )
442+ or ModulePattern .is_aiohttp_module (full_path )
443+ ):
387444 self .context .add_variable_alias (var_name , full_path )
388445
389446 # Handle attribute assignments: self.attr = value
@@ -484,7 +541,7 @@ def _extract_from_string_annotation(self, annotation_str: str):
484541 # Match Python identifiers that could be type names
485542 names = re .findall (r"\b([A-Z][a-zA-Z0-9_]*)\b" , annotation_str )
486543 for name in names :
487- if name in ["Session" , "PoolManager" , "ProxyManager" ]:
544+ if name in ["Session" , "PoolManager" , "ProxyManager" , "ClientSession" ]:
488545 self .context .add_type_hint_usage (name )
489546
490547 def _extract_from_subscript (self , node : ast .Subscript ):
@@ -535,9 +592,11 @@ def analyze_calls(self, tree: ast.AST):
535592 def analyze_star_imports (self ):
536593 """Analyze star import violations."""
537594 for module in self .context .star_imports :
538- if ModulePattern .is_requests_module (
539- module
540- ) or ModulePattern .is_urllib3_module (module ):
595+ if (
596+ ModulePattern .is_requests_module (module )
597+ or ModulePattern .is_urllib3_module (module )
598+ or ModulePattern .is_aiohttp_module (module )
599+ ):
541600 self .violations .append (
542601 HTTPViolation (
543602 self .filename ,
@@ -552,7 +611,7 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
552611 """Check a single import for violations."""
553612 violations = []
554613
555- # Always flag HTTP method imports
614+ # Always flag HTTP method imports from requests
556615 if (
557616 import_info .imported_name
558617 and ModulePattern .is_requests_module (import_info .module )
@@ -568,7 +627,7 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
568627 )
569628 )
570629
571- # Flag Session/PoolManager imports only if used at runtime
630+ # Flag Session/PoolManager/ClientSession imports only if used at runtime
572631 if import_info .imported_name and self .context .is_runtime (
573632 import_info .alias_name
574633 ):
@@ -600,6 +659,19 @@ def _check_import_violation(self, import_info: ImportInfo) -> List[HTTPViolation
600659 )
601660 )
602661
662+ elif ModulePattern .is_aiohttp_module (
663+ import_info .module
664+ ) and ModulePattern .is_aiohttp_session (import_info .imported_name ):
665+ violations .append (
666+ HTTPViolation (
667+ self .filename ,
668+ import_info .line ,
669+ import_info .col ,
670+ ViolationType .DIRECT_AIOHTTP_IMPORT ,
671+ f"Direct import of { import_info .imported_name } from aiohttp for runtime use is forbidden, use SessionManager instead" ,
672+ )
673+ )
674+
603675 return violations
604676
605677
@@ -671,7 +743,7 @@ def _check_direct_call(self, node: ast.Call) -> Optional[HTTPViolation]:
671743 f"Direct use of imported { import_info .imported_name } () is forbidden, use SessionManager instead" ,
672744 )
673745
674- # Session/PoolManager instantiation
746+ # Session/PoolManager/ClientSession instantiation
675747 if (
676748 import_info .imported_name == "Session"
677749 and ModulePattern .is_requests_module (import_info .module )
@@ -697,6 +769,19 @@ def _check_direct_call(self, node: ast.Call) -> Optional[HTTPViolation]:
697769 f"Direct use of imported { import_info .imported_name } () is forbidden, use SessionManager instead" ,
698770 )
699771
772+ if (
773+ import_info .imported_name
774+ and ModulePattern .is_aiohttp_session (import_info .imported_name )
775+ and ModulePattern .is_aiohttp_module (import_info .module )
776+ ):
777+ return HTTPViolation (
778+ self .filename ,
779+ node .lineno ,
780+ node .col_offset ,
781+ ViolationType .AIOHTTP_CLIENT_SESSION ,
782+ f"Direct use of imported { import_info .imported_name } () is forbidden, use SessionManager instead" ,
783+ )
784+
700785 # Check star imports
701786 for module in self .context .star_imports :
702787 if ModulePattern .is_requests_module (
@@ -719,7 +804,7 @@ def _is_chained_call(self, node: ast.Call) -> bool:
719804 )
720805
721806 def _check_chained_calls (self , node : ast .Call ) -> Optional [HTTPViolation ]:
722- """Check for chained calls like requests.Session().get() or urllib3.PoolManager().request()."""
807+ """Check for chained calls like requests.Session().get(), urllib3.PoolManager().request(), or aiohttp.ClientSession().get ()."""
723808 if isinstance (node .func , ast .Attribute ) and isinstance (
724809 node .func .value , ast .Call
725810 ):
@@ -762,6 +847,23 @@ def _check_chained_calls(self, node: ast.Call) -> Optional[HTTPViolation]:
762847 f"Chained call urllib3.{ inner_func } ().{ outer_method } () is forbidden, use SessionManager instead" ,
763848 )
764849
850+ # Check for aiohttp.ClientSession().method()
851+ if (
852+ (
853+ inner_module == "aiohttp"
854+ or self .context .is_aiohttp_related (inner_module )
855+ )
856+ and ModulePattern .is_aiohttp_session (inner_func )
857+ and ModulePattern .is_http_method (outer_method )
858+ ):
859+ return HTTPViolation (
860+ self .filename ,
861+ node .lineno ,
862+ node .col_offset ,
863+ ViolationType .AIOHTTP_CLIENT_SESSION ,
864+ f"Chained call aiohttp.{ inner_func } ().{ outer_method } () is forbidden, use SessionManager instead" ,
865+ )
866+
765867 return None
766868
767869 def _check_two_part_call (
@@ -780,6 +882,10 @@ def _check_two_part_call(
780882 resolved_module
781883 ):
782884 return self ._check_urllib3_call (node , func_name )
885+ elif module_name == "aiohttp" or self .context .is_aiohttp_related (
886+ resolved_module
887+ ):
888+ return self ._check_aiohttp_call (node , func_name )
783889
784890 # Check for aliased module calls (e.g., v = vendored.requests; v.get())
785891 if module_name in self .context .variable_aliases :
@@ -788,13 +894,15 @@ def _check_two_part_call(
788894 return self ._check_requests_call (node , func_name )
789895 elif ModulePattern .is_urllib3_module (aliased_module ):
790896 return self ._check_urllib3_call (node , func_name )
897+ elif ModulePattern .is_aiohttp_module (aliased_module ):
898+ return self ._check_aiohttp_call (node , func_name )
791899
792900 return None
793901
794902 def _check_multi_part_call (
795903 self , node : ast .Call , chain : List [str ]
796904 ) -> Optional [HTTPViolation ]:
797- """Check multi-part calls like requests.sessions.Session or self.req_lib.get."""
905+ """Check multi-part calls like requests.sessions.Session, aiohttp.client.ClientSession or self.req_lib.get."""
798906 if len (chain ) >= 3 :
799907 module_name = chain [0 ]
800908
@@ -820,6 +928,20 @@ def _check_multi_part_call(
820928 f"Direct use of { '.' .join (chain )} () is forbidden, use SessionManager instead" ,
821929 )
822930
931+ elif module_name == "aiohttp" or self .context .is_aiohttp_related (
932+ module_name
933+ ):
934+ # aiohttp.client.ClientSession, etc.
935+ func_name = chain [- 1 ]
936+ if ModulePattern .is_aiohttp_session (func_name ):
937+ return HTTPViolation (
938+ self .filename ,
939+ node .lineno ,
940+ node .col_offset ,
941+ ViolationType .AIOHTTP_CLIENT_SESSION ,
942+ f"Direct use of { '.' .join (chain )} () is forbidden, use SessionManager instead" ,
943+ )
944+
823945 # Check for aliased calls like self.req_lib.get() where req_lib is an alias
824946 elif len (chain ) >= 3 :
825947 # For patterns like self.req_lib.get(), check if req_lib is an alias
@@ -848,6 +970,16 @@ def _check_multi_part_call(
848970 ViolationType .URLLIB3_POOLMANAGER ,
849971 f"Direct use of aliased { chain [0 ]} .{ potential_alias } .{ func_name } () is forbidden, use SessionManager instead" ,
850972 )
973+ elif ModulePattern .is_aiohttp_module (
974+ aliased_module
975+ ) and ModulePattern .is_aiohttp_session (func_name ):
976+ return HTTPViolation (
977+ self .filename ,
978+ node .lineno ,
979+ node .col_offset ,
980+ ViolationType .AIOHTTP_CLIENT_SESSION ,
981+ f"Direct use of aliased { chain [0 ]} .{ potential_alias } .{ func_name } () is forbidden, use SessionManager instead" ,
982+ )
851983
852984 return None
853985
@@ -903,12 +1035,35 @@ def _check_urllib3_call(
9031035 )
9041036 return None
9051037
1038+ def _check_aiohttp_call (
1039+ self , node : ast .Call , func_name : str
1040+ ) -> Optional [HTTPViolation ]:
1041+ """Check aiohttp module calls."""
1042+ if ModulePattern .is_aiohttp_session (func_name ):
1043+ return HTTPViolation (
1044+ self .filename ,
1045+ node .lineno ,
1046+ node .col_offset ,
1047+ ViolationType .AIOHTTP_CLIENT_SESSION ,
1048+ f"Direct use of aiohttp.{ func_name } () is forbidden, use SessionManager instead" ,
1049+ )
1050+ elif ModulePattern .is_aiohttp_api (func_name ):
1051+ return HTTPViolation (
1052+ self .filename ,
1053+ node .lineno ,
1054+ node .col_offset ,
1055+ ViolationType .AIOHTTP_REQUEST ,
1056+ f"Direct use of aiohttp.{ func_name } () is forbidden, use SessionManager instead" ,
1057+ )
1058+ return None
1059+
9061060
9071061class FileChecker :
9081062 """Handles file-level checking logic with proper glob path matching."""
9091063
9101064 EXEMPT_PATTERNS = [
9111065 "**/session_manager.py" ,
1066+ "**/_session_manager.py" ,
9121067 "**/vendored/**/*" ,
9131068 ]
9141069
@@ -1043,8 +1198,11 @@ def main():
10431198 print (
10441199 " - Replace urllib3.PoolManager/ProxyManager() with session from session_manager.use_session()"
10451200 )
1201+ print (
1202+ " - Replace aiohttp.ClientSession() with async SessionManager.use_session()"
1203+ )
10461204 print (" - Replace direct HTTP method imports with SessionManager usage" )
1047- print (" - Use SessionManager for all HTTP operations" )
1205+ print (" - Use SessionManager for all HTTP operations (sync and async) " )
10481206
10491207 print ()
10501208 print (f"Found { len (all_violations )} violation(s)" )
0 commit comments