Skip to content

Commit 980d9b4

Browse files
committed
handle unsupported pydantic types, defered imports
1 parent 2fcf58f commit 980d9b4

File tree

1 file changed

+70
-32
lines changed

1 file changed

+70
-32
lines changed

src/serverless_openapi_generator/pydantic_handler.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,25 @@
99
import yaml
1010
from rich import print as rprint
1111
from pydantic import BaseModel
12+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
13+
from pydantic_core import core_schema
1214

1315
try:
1416
from pydantic.errors import PydanticInvalidForJsonSchema
1517
except ImportError:
1618
PydanticInvalidForJsonSchema = Exception
1719

20+
# --- Custom JSON Schema Generator to handle unsupported types ---
21+
class CustomJsonSchemaGenerator(GenerateJsonSchema):
22+
def core_schema_schema(self, core_schema: core_schema.CoreSchema) -> JsonSchemaValue:
23+
# Handle aws-lambda-powertools types that Pydantic can't process by default
24+
if isinstance(core_schema, core_schema.IsInstanceSchema):
25+
if 'CaseInsensitiveDict' in str(core_schema.cls) or 'Cookie' in str(core_schema.cls):
26+
return {'type': 'object'}
27+
28+
# Fallback to the default behavior for all other types
29+
return super().core_schema_schema(core_schema)
30+
1831

1932
def is_pydantic_model(obj):
2033
"""Checks if an object is a Pydantic model class, excluding BaseModel itself."""
@@ -57,38 +70,61 @@ def generate_dto_schemas(source_dir: Path, output_dir: Path, project_root: Path)
5770
discovered_models = []
5871
processed_dto_files = set()
5972
successfully_generated_schemas = {}
73+
74+
dto_files = list(source_dir.rglob("**/dtos.py"))
75+
76+
for pass_num in range(3): # Try to resolve imports up to 3 times
77+
if not dto_files:
78+
break
79+
80+
rprint(f"\n[bold]Import Pass {pass_num + 1}...[/bold]")
81+
82+
remaining_files = []
83+
84+
for dto_file_path in dto_files:
85+
if dto_file_path in processed_dto_files:
86+
continue
6087

61-
for dto_file_path in source_dir.rglob("**/dtos.py"):
62-
if dto_file_path in processed_dto_files:
63-
continue
64-
processed_dto_files.add(dto_file_path)
65-
66-
rprint(f" [cyan]Processing DTO file: {dto_file_path}[/cyan]")
67-
relative_path = dto_file_path.relative_to(import_root)
68-
module_name_parts = list(relative_path.parts)
69-
if module_name_parts[-1] == "dtos.py":
70-
module_name_parts[-1] = "dtos"
71-
module_name = ".".join(part for part in module_name_parts if part != "__pycache__")
72-
73-
try:
74-
spec = importlib.util.spec_from_file_location(module_name, dto_file_path)
75-
if spec and spec.loader:
76-
module = importlib.util.module_from_spec(spec)
77-
sys.modules[module_name] = module
78-
spec.loader.exec_module(module)
79-
else:
80-
rprint(f"\t[yellow]Could not create module spec for {dto_file_path}[/yellow]")
88+
rprint(f" [cyan]Processing DTO file: {dto_file_path}[/cyan]")
89+
relative_path = dto_file_path.relative_to(import_root)
90+
module_name_parts = list(relative_path.parts)
91+
if module_name_parts[-1] == "dtos.py":
92+
module_name_parts[-1] = "dtos"
93+
module_name = ".".join(part for part in module_name_parts if part != "__pycache__")
94+
95+
try:
96+
spec = importlib.util.spec_from_file_location(module_name, dto_file_path)
97+
if spec and spec.loader:
98+
module = importlib.util.module_from_spec(spec)
99+
sys.modules[module_name] = module
100+
spec.loader.exec_module(module)
101+
else:
102+
rprint(f"\t[yellow]Could not create module spec for {dto_file_path}[/yellow]")
103+
continue
104+
except ImportError as e:
105+
rprint(f"\t[yellow]Deferring import of {dto_file_path} due to ImportError: {e}[/yellow]")
106+
remaining_files.append(dto_file_path)
107+
if module_name in sys.modules:
108+
del sys.modules[module_name]
81109
continue
82-
except Exception as e:
83-
rprint(f"\t[red]Error importing module {module_name} from {dto_file_path}: {e}[/red]")
84-
if module_name in sys.modules:
85-
del sys.modules[module_name]
86-
continue
110+
except Exception as e:
111+
rprint(f"\t[red]Error importing module {module_name} from {dto_file_path}: {e}[/red]")
112+
if module_name in sys.modules:
113+
del sys.modules[module_name]
114+
continue
115+
116+
processed_dto_files.add(dto_file_path)
117+
for name, obj in inspect.getmembers(module):
118+
if is_pydantic_model(obj):
119+
if hasattr(obj, "__module__") and obj.__module__ == module_name:
120+
discovered_models.append((obj, name, module_name))
121+
122+
dto_files = remaining_files
87123

88-
for name, obj in inspect.getmembers(module):
89-
if is_pydantic_model(obj):
90-
if hasattr(obj, "__module__") and obj.__module__ == module_name:
91-
discovered_models.append((obj, name, module_name))
124+
if dto_files:
125+
rprint("\n[bold red]Could not resolve all imports after multiple passes. The following files failed:[/bold red]")
126+
for f in dto_files:
127+
rprint(f" - {f}")
92128

93129
rprint(f"\n[bold]Found {len(discovered_models)} Pydantic models from {len(processed_dto_files)} DTO file(s).[/bold]")
94130

@@ -114,7 +150,7 @@ def generate_dto_schemas(source_dir: Path, output_dir: Path, project_root: Path)
114150
rprint(f" [cyan]Generating schema for: {module_name}.{model_name}[/cyan]")
115151
try:
116152
if hasattr(model_class, "model_json_schema"):
117-
schema = model_class.model_json_schema()
153+
schema = model_class.model_json_schema(schema_generator=CustomJsonSchemaGenerator)
118154
elif hasattr(model_class, "schema_json"):
119155
schema = json.loads(model_class.schema_json())
120156
else:
@@ -185,8 +221,10 @@ def generate_serverless_config(successfully_generated_schemas, project_meta, pro
185221
rprint("\n[bold]Generating Serverless config for OpenAPI in memory...[/bold]")
186222
python_runtime = "python3.12"
187223
try:
188-
main_sls_file = project_root / "serverless-wo-cross-accounts.yml"
189-
if main_sls_file.exists():
224+
# Look for any serverless.yml or serverless.yaml file in the project root
225+
sls_files = list(project_root.glob("serverless.y*ml"))
226+
if sls_files:
227+
main_sls_file = sls_files[0]
190228
with open(main_sls_file, "r") as f:
191229
main_config = yaml.safe_load(f)
192230
if main_config and "provider" in main_config and "runtime" in main_config["provider"]:

0 commit comments

Comments
 (0)