|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Pre-commit hook to ensure optional dependencies are always imported from .options module. |
| 4 | +This ensures that the connector can operate in environments where these optional libraries are not available. |
| 5 | +""" |
| 6 | +import argparse |
| 7 | +import ast |
| 8 | +import sys |
| 9 | +from dataclasses import dataclass |
| 10 | +from pathlib import Path |
| 11 | +from typing import List |
| 12 | + |
| 13 | +CHECKED_MODULES = ["boto3", "botocore", "pandas", "pyarrow", "keyring"] |
| 14 | + |
| 15 | + |
| 16 | +@dataclass(frozen=True) |
| 17 | +class ImportViolation: |
| 18 | + """Pretty prints a violation import restrictions.""" |
| 19 | + |
| 20 | + filename: str |
| 21 | + line: int |
| 22 | + col: int |
| 23 | + message: str |
| 24 | + |
| 25 | + def __str__(self): |
| 26 | + return f"{self.filename}:{self.line}:{self.col}: {self.message}" |
| 27 | + |
| 28 | + |
| 29 | +class ImportChecker(ast.NodeVisitor): |
| 30 | + """Checks that optional imports are only imported from .options module.""" |
| 31 | + |
| 32 | + def __init__(self, filename: str): |
| 33 | + self.filename = filename |
| 34 | + self.violations: List[ImportViolation] = [] |
| 35 | + |
| 36 | + def visit_If(self, node: ast.If): |
| 37 | + # Always visit the condition, but ignore imports inside "if TYPE_CHECKING:" blocks |
| 38 | + if getattr(node.test, "id", None) == "TYPE_CHECKING": |
| 39 | + # Skip the body and orelse for TYPE_CHECKING blocks |
| 40 | + pass |
| 41 | + else: |
| 42 | + self.generic_visit(node) |
| 43 | + |
| 44 | + def visit_Import(self, node: ast.Import): |
| 45 | + """Check import statements.""" |
| 46 | + for alias in node.names: |
| 47 | + self._check_import(alias.name, node.lineno, node.col_offset) |
| 48 | + self.generic_visit(node) |
| 49 | + |
| 50 | + def visit_ImportFrom(self, node: ast.ImportFrom): |
| 51 | + """Check from...import statements.""" |
| 52 | + if node.module: |
| 53 | + # Check if importing from a checked module directly |
| 54 | + for module in CHECKED_MODULES: |
| 55 | + if node.module.startswith(module): |
| 56 | + self.violations.append( |
| 57 | + ImportViolation( |
| 58 | + self.filename, |
| 59 | + node.lineno, |
| 60 | + node.col_offset, |
| 61 | + f"Import from '{node.module}' is not allowed. Use 'from .options import {module}' instead", |
| 62 | + ) |
| 63 | + ) |
| 64 | + |
| 65 | + # Check if importing checked modules from .options (this is allowed) |
| 66 | + if node.module == ".options": |
| 67 | + # This is the correct way to import these modules |
| 68 | + pass |
| 69 | + self.generic_visit(node) |
| 70 | + |
| 71 | + def _check_import(self, module_name: str, line: int, col: int): |
| 72 | + """Check if a module import is for checked modules and not from .options.""" |
| 73 | + for module in CHECKED_MODULES: |
| 74 | + if module_name.startswith(module): |
| 75 | + self.violations.append( |
| 76 | + ImportViolation( |
| 77 | + self.filename, |
| 78 | + line, |
| 79 | + col, |
| 80 | + f"Direct import of '{module_name}' is not allowed. Use 'from .options import {module}' instead", |
| 81 | + ) |
| 82 | + ) |
| 83 | + break |
| 84 | + |
| 85 | + |
| 86 | +def check_file(filename: str) -> List[ImportViolation]: |
| 87 | + """Check a file for optional import violations.""" |
| 88 | + try: |
| 89 | + tree = ast.parse(Path(filename).read_text()) |
| 90 | + except SyntaxError: |
| 91 | + # gracefully handle syntax errors |
| 92 | + return [] |
| 93 | + checker = ImportChecker(filename) |
| 94 | + checker.visit(tree) |
| 95 | + return checker.violations |
| 96 | + |
| 97 | + |
| 98 | +def main(): |
| 99 | + """Main function for pre-commit hook.""" |
| 100 | + parser = argparse.ArgumentParser( |
| 101 | + description="Check that optional imports are only imported from .options module" |
| 102 | + ) |
| 103 | + parser.add_argument("filenames", nargs="*", help="Filenames to check") |
| 104 | + parser.add_argument( |
| 105 | + "--show-fixes", action="store_true", help="Show suggested fixes" |
| 106 | + ) |
| 107 | + args = parser.parse_args() |
| 108 | + |
| 109 | + all_violations = [] |
| 110 | + for filename in args.filenames: |
| 111 | + if not filename.endswith(".py"): |
| 112 | + continue |
| 113 | + all_violations.extend(check_file(filename)) |
| 114 | + |
| 115 | + # Show violations |
| 116 | + if all_violations: |
| 117 | + print("Optional import violations found:") |
| 118 | + print() |
| 119 | + |
| 120 | + for violation in all_violations: |
| 121 | + print(f" {violation}") |
| 122 | + |
| 123 | + if args.show_fixes: |
| 124 | + print() |
| 125 | + print("How to fix:") |
| 126 | + print(" - Import optional modules only from .options module") |
| 127 | + print(" - Example:") |
| 128 | + print(" # CORRECT:") |
| 129 | + print(" from .options import boto3, botocore, installed_boto") |
| 130 | + print(" if installed_boto:") |
| 131 | + print(" SigV4Auth = botocore.auth.SigV4Auth") |
| 132 | + print() |
| 133 | + print(" # INCORRECT:") |
| 134 | + print(" import boto3") |
| 135 | + print(" from botocore.auth import SigV4Auth") |
| 136 | + print() |
| 137 | + print( |
| 138 | + " - This ensures the connector works in environments where optional libraries are not installed" |
| 139 | + ) |
| 140 | + |
| 141 | + print() |
| 142 | + print(f"Found {len(all_violations)} violation(s)") |
| 143 | + return 1 |
| 144 | + |
| 145 | + return 0 |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == "__main__": |
| 149 | + sys.exit(main()) |
0 commit comments