|  | 
|  | 1 | +#!/usr/bin/env python3 | 
|  | 2 | +""" | 
|  | 3 | +Pre-commit hook to ensure optional dependencies are always imported from .options module. | 
|  | 4 | +This ensures that the connector can operate in environments where these optional libraries are not available. | 
|  | 5 | +""" | 
|  | 6 | +import argparse | 
|  | 7 | +import ast | 
|  | 8 | +import sys | 
|  | 9 | +from dataclasses import dataclass | 
|  | 10 | +from pathlib import Path | 
|  | 11 | +from typing import List | 
|  | 12 | + | 
|  | 13 | +CHECKED_MODULES = ["boto3", "botocore", "pandas", "pyarrow", "keyring"] | 
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +@dataclass(frozen=True) | 
|  | 17 | +class ImportViolation: | 
|  | 18 | +    """Pretty prints a violation import restrictions.""" | 
|  | 19 | + | 
|  | 20 | +    filename: str | 
|  | 21 | +    line: int | 
|  | 22 | +    col: int | 
|  | 23 | +    message: str | 
|  | 24 | + | 
|  | 25 | +    def __str__(self): | 
|  | 26 | +        return f"{self.filename}:{self.line}:{self.col}: {self.message}" | 
|  | 27 | + | 
|  | 28 | + | 
|  | 29 | +class ImportChecker(ast.NodeVisitor): | 
|  | 30 | +    """Checks that optional imports are only imported from .options module.""" | 
|  | 31 | + | 
|  | 32 | +    def __init__(self, filename: str): | 
|  | 33 | +        self.filename = filename | 
|  | 34 | +        self.violations: List[ImportViolation] = [] | 
|  | 35 | + | 
|  | 36 | +    def visit_If(self, node: ast.If): | 
|  | 37 | +        # Always visit the condition, but ignore imports inside "if TYPE_CHECKING:" blocks | 
|  | 38 | +        if getattr(node.test, "id", None) == "TYPE_CHECKING": | 
|  | 39 | +            # Skip the body and orelse for TYPE_CHECKING blocks | 
|  | 40 | +            pass | 
|  | 41 | +        else: | 
|  | 42 | +            self.generic_visit(node) | 
|  | 43 | + | 
|  | 44 | +    def visit_Import(self, node: ast.Import): | 
|  | 45 | +        """Check import statements.""" | 
|  | 46 | +        for alias in node.names: | 
|  | 47 | +            self._check_import(alias.name, node.lineno, node.col_offset) | 
|  | 48 | +        self.generic_visit(node) | 
|  | 49 | + | 
|  | 50 | +    def visit_ImportFrom(self, node: ast.ImportFrom): | 
|  | 51 | +        """Check from...import statements.""" | 
|  | 52 | +        if node.module: | 
|  | 53 | +            # Check if importing from a checked module directly | 
|  | 54 | +            for module in CHECKED_MODULES: | 
|  | 55 | +                if node.module.startswith(module): | 
|  | 56 | +                    self.violations.append( | 
|  | 57 | +                        ImportViolation( | 
|  | 58 | +                            self.filename, | 
|  | 59 | +                            node.lineno, | 
|  | 60 | +                            node.col_offset, | 
|  | 61 | +                            f"Import from '{node.module}' is not allowed. Use 'from .options import {module}' instead", | 
|  | 62 | +                        ) | 
|  | 63 | +                    ) | 
|  | 64 | + | 
|  | 65 | +            # Check if importing checked modules from .options (this is allowed) | 
|  | 66 | +            if node.module == ".options": | 
|  | 67 | +                # This is the correct way to import these modules | 
|  | 68 | +                pass | 
|  | 69 | +        self.generic_visit(node) | 
|  | 70 | + | 
|  | 71 | +    def _check_import(self, module_name: str, line: int, col: int): | 
|  | 72 | +        """Check if a module import is for checked modules and not from .options.""" | 
|  | 73 | +        for module in CHECKED_MODULES: | 
|  | 74 | +            if module_name.startswith(module): | 
|  | 75 | +                self.violations.append( | 
|  | 76 | +                    ImportViolation( | 
|  | 77 | +                        self.filename, | 
|  | 78 | +                        line, | 
|  | 79 | +                        col, | 
|  | 80 | +                        f"Direct import of '{module_name}' is not allowed. Use 'from .options import {module}' instead", | 
|  | 81 | +                    ) | 
|  | 82 | +                ) | 
|  | 83 | +                break | 
|  | 84 | + | 
|  | 85 | + | 
|  | 86 | +def check_file(filename: str) -> List[ImportViolation]: | 
|  | 87 | +    """Check a file for optional import violations.""" | 
|  | 88 | +    try: | 
|  | 89 | +        tree = ast.parse(Path(filename).read_text()) | 
|  | 90 | +    except SyntaxError: | 
|  | 91 | +        # gracefully handle syntax errors | 
|  | 92 | +        return [] | 
|  | 93 | +    checker = ImportChecker(filename) | 
|  | 94 | +    checker.visit(tree) | 
|  | 95 | +    return checker.violations | 
|  | 96 | + | 
|  | 97 | + | 
|  | 98 | +def main(): | 
|  | 99 | +    """Main function for pre-commit hook.""" | 
|  | 100 | +    parser = argparse.ArgumentParser( | 
|  | 101 | +        description="Check that optional imports are only imported from .options module" | 
|  | 102 | +    ) | 
|  | 103 | +    parser.add_argument("filenames", nargs="*", help="Filenames to check") | 
|  | 104 | +    parser.add_argument( | 
|  | 105 | +        "--show-fixes", action="store_true", help="Show suggested fixes" | 
|  | 106 | +    ) | 
|  | 107 | +    args = parser.parse_args() | 
|  | 108 | + | 
|  | 109 | +    all_violations = [] | 
|  | 110 | +    for filename in args.filenames: | 
|  | 111 | +        if not filename.endswith(".py"): | 
|  | 112 | +            continue | 
|  | 113 | +        all_violations.extend(check_file(filename)) | 
|  | 114 | + | 
|  | 115 | +    # Show violations | 
|  | 116 | +    if all_violations: | 
|  | 117 | +        print("Optional import violations found:") | 
|  | 118 | +        print() | 
|  | 119 | + | 
|  | 120 | +        for violation in all_violations: | 
|  | 121 | +            print(f"  {violation}") | 
|  | 122 | + | 
|  | 123 | +        if args.show_fixes: | 
|  | 124 | +            print() | 
|  | 125 | +            print("How to fix:") | 
|  | 126 | +            print("  - Import optional modules only from .options module") | 
|  | 127 | +            print("  - Example:") | 
|  | 128 | +            print("    # CORRECT:") | 
|  | 129 | +            print("    from .options import boto3, botocore, installed_boto") | 
|  | 130 | +            print("    if installed_boto:") | 
|  | 131 | +            print("        SigV4Auth = botocore.auth.SigV4Auth") | 
|  | 132 | +            print() | 
|  | 133 | +            print("    # INCORRECT:") | 
|  | 134 | +            print("    import boto3") | 
|  | 135 | +            print("    from botocore.auth import SigV4Auth") | 
|  | 136 | +            print() | 
|  | 137 | +            print( | 
|  | 138 | +                "  - This ensures the connector works in environments where optional libraries are not installed" | 
|  | 139 | +            ) | 
|  | 140 | + | 
|  | 141 | +        print() | 
|  | 142 | +        print(f"Found {len(all_violations)} violation(s)") | 
|  | 143 | +        return 1 | 
|  | 144 | + | 
|  | 145 | +    return 0 | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +if __name__ == "__main__": | 
|  | 149 | +    sys.exit(main()) | 
0 commit comments