Skip to content

Commit 9afba6b

Browse files
committed
support union types
1 parent 575340e commit 9afba6b

File tree

8 files changed

+146
-42
lines changed

8 files changed

+146
-42
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.1.1] - 2024-07-03
9+
10+
### Fixed
11+
- Added support for PEP 604 union type syntax (e.g., `str | None`) in FastAPI endpoints
12+
- Improved type handling in model field generation for newer Python versions (3.10+)
13+
- Fixed compatibility issues with modern type annotations in path parameters, query parameters, and Pydantic models
14+
815
## [0.1.0] - 2024-03-08
916

1017
### Added

examples/sample_app.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
)
1414

1515
class Item(BaseModel):
16-
id: Optional[int] = None
16+
id: int | None = None
1717
name: str = Field(..., description="The name of the item")
18-
description: Optional[str] = Field(None, description="The description of the item")
18+
description: str | None = Field(None, description="The description of the item")
1919
price: float = Field(..., description="The price of the item", gt=0)
20-
tax: Optional[float] = Field(None, description="The tax rate for the item")
20+
tax: float | None = Field(None, description="The tax rate for the item")
2121
tags: List[str] = Field(default_factory=list, description="Tags for the item")
2222

2323
# In-memory database
@@ -87,9 +87,9 @@ def delete_item(item_id: int):
8787

8888
@app.get("/items/search/", response_model=List[Item], tags=["search"])
8989
def search_items(
90-
q: Optional[str] = Query(None, description="Search query string"),
91-
min_price: Optional[float] = Query(None, description="Minimum price"),
92-
max_price: Optional[float] = Query(None, description="Maximum price"),
90+
q: str | None = Query(None, description="Search query string"),
91+
min_price: float | None = Query(None, description="Minimum price"),
92+
max_price: float | None = Query(None, description="Maximum price"),
9393
tags: List[str] = Query([], description="Filter by tags"),
9494
):
9595
"""

fastapi_mcp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
Created by Tadata Inc. (https://github.com/tadata-org)
55
"""
66

7-
__version__ = "0.1.0"
7+
__version__ = "0.1.1"

fastapi_mcp/converter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import inspect
6+
import sys
67
from typing import Any, Callable, Dict, List, Optional, Type, Union, get_type_hints
78

89
from fastapi import FastAPI
@@ -18,6 +19,9 @@
1819

1920
from fastapi_mcp.discovery import Endpoint
2021

22+
# Check if Python version supports PEP 604 (|) union types
23+
PY310_OR_HIGHER = sys.version_info >= (3, 10)
24+
2125

2226
def convert_endpoint_to_mcp_tool(endpoint: Endpoint) -> Dict[str, Any]:
2327
"""
@@ -220,6 +224,20 @@ def _convert_type_annotation(annotation: Any) -> str:
220224
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
221225
return "dict"
222226

227+
# Handle PEP 604 union types (X | Y) in Python 3.10+
228+
if PY310_OR_HIGHER:
229+
# Check if the annotation has the __or__ method, indicating it's a type that can be used with |
230+
# or if it has the __origin__ attribute that matches types.UnionType
231+
if (hasattr(annotation, "__or__") or
232+
(hasattr(annotation, "__origin__") and str(annotation.__origin__) == "types.UnionType")):
233+
# Get the args of the union
234+
args = getattr(annotation, "__args__", [])
235+
# Filter out NoneType to handle Optional types
236+
for arg in args:
237+
if arg is not type(None): # noqa
238+
return _convert_type_annotation(arg)
239+
return "Any"
240+
223241
# Try to handle Union types
224242
if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
225243
for arg in getattr(annotation, "__args__", []):

fastapi_mcp/discovery.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import inspect
6+
import sys
67
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, get_type_hints, Union
78

89
import fastapi
@@ -15,6 +16,9 @@
1516
# Pydantic v1 fallback
1617
from pydantic import BaseModel
1718

19+
# Check if Python version supports PEP 604 (|) union types
20+
PY310_OR_HIGHER = sys.version_info >= (3, 10)
21+
1822

1923
class Endpoint:
2024
"""
@@ -215,6 +219,18 @@ def parse_endpoint_params(route: APIRoute) -> Tuple[Dict[str, Any], Dict[str, An
215219
else:
216220
# Simplify the annotation if it's a complex type
217221
simplified_annotation = annotation
222+
223+
# Handle PEP 604 union types (X | Y) in Python 3.10+
224+
if PY310_OR_HIGHER:
225+
if (hasattr(annotation, "__or__") or
226+
(hasattr(annotation, "__origin__") and str(annotation.__origin__) == "types.UnionType")):
227+
args = getattr(annotation, "__args__", [])
228+
for arg in args:
229+
if arg is not type(None): # noqa
230+
simplified_annotation = arg
231+
break
232+
233+
# Handle traditional Union types
218234
if hasattr(annotation, "__origin__") and annotation.__origin__ is Union:
219235
for arg in getattr(annotation, "__args__", []):
220236
if arg is not type(None): # noqa

fastapi_mcp/generator.py

Lines changed: 96 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import re
7+
import sys
78
from pathlib import Path
89
from typing import Dict, List, Optional, Union, Any, Type, get_type_hints, get_origin, get_args
910

@@ -19,6 +20,9 @@
1920
from fastapi_mcp.converter import convert_endpoint_to_mcp_tool
2021
from fastapi_mcp.discovery import Endpoint
2122

23+
# Check if Python version supports PEP 604 (|) union types
24+
PY310_OR_HIGHER = sys.version_info >= (3, 10)
25+
2226

2327
def generate_mcp_server(
2428
app: FastAPI,
@@ -172,33 +176,27 @@ def generate_server_code(
172176
fields = get_model_fields(model_class)
173177
for field_name, field_info in fields.items():
174178
# Get type name
175-
if hasattr(field_info["type"], "__origin__") and field_info["type"].__origin__ is Union:
176-
args = field_info["type"].__args__
177-
if len(args) == 2 and args[1] is type(None):
178-
type_name = _get_simple_type_name(args[0])
179+
type_name = _get_simple_type_name(field_info["type"])
180+
181+
# Check if the field is optional
182+
is_optional = field_info.get("optional", False)
183+
184+
# Check if the field has a default value
185+
if field_info["default"] is None and field_info["required"]:
186+
# Required field with no default value - use Undefined
187+
model_lines.append(f" {field_name}: {type_name} = Undefined")
188+
elif field_info["default"] is None and not field_info["required"]:
189+
# Optional field with None default
190+
if is_optional:
179191
model_lines.append(f" {field_name}: Optional[{type_name}] = None")
180192
else:
181-
type_name = _get_simple_type_name(field_info["type"])
182-
model_lines.append(f" {field_name}: {type_name}")
193+
model_lines.append(f" {field_name}: {type_name} = None")
194+
elif isinstance(field_info["default"], str):
195+
# String default value
196+
model_lines.append(f' {field_name}: {type_name} = "{field_info["default"]}"')
183197
else:
184-
type_name = _get_simple_type_name(field_info["type"])
185-
186-
# Check if the field has a default value
187-
if field_info["default"] is None and field_info["required"]:
188-
# Required field with no default value - use Undefined
189-
model_lines.append(f" {field_name}: {type_name} = Undefined")
190-
elif field_info["default"] is None and not field_info["required"]:
191-
# Optional field with None default
192-
if type_name.startswith("Optional["):
193-
model_lines.append(f" {field_name}: {type_name} = None")
194-
else:
195-
model_lines.append(f" {field_name}: Optional[{type_name}] = None")
196-
elif isinstance(field_info["default"], str):
197-
# String default value
198-
model_lines.append(f' {field_name}: {type_name} = "{field_info["default"]}"')
199-
else:
200-
# Other default value
201-
model_lines.append(f" {field_name}: {type_name} = {field_info['default']}")
198+
# Other default value
199+
model_lines.append(f" {field_name}: {type_name} = {field_info['default']}")
202200

203201
# Add the model definition
204202
server_code.extend(model_lines)
@@ -469,31 +467,75 @@ def _get_simple_type_name(type_annotation):
469467
Returns:
470468
A string representation of the type.
471469
"""
470+
# Handle primitive types directly
471+
if type_annotation is str:
472+
return "str"
473+
elif type_annotation is int:
474+
return "int"
475+
elif type_annotation is float:
476+
return "float"
477+
elif type_annotation is bool:
478+
return "bool"
479+
elif type_annotation is list:
480+
return "List"
481+
elif type_annotation is dict:
482+
return "Dict"
483+
elif type_annotation is Any:
484+
return "Any"
485+
486+
# Handle PEP 604 union types (X | Y) in Python 3.10+
487+
if PY310_OR_HIGHER:
488+
if (hasattr(type_annotation, "__or__") or
489+
(hasattr(type_annotation, "__origin__") and str(type_annotation.__origin__) == "types.UnionType")):
490+
args = getattr(type_annotation, "__args__", [])
491+
492+
# Check if this is equivalent to Optional[T]
493+
if any(arg is type(None) for arg in args):
494+
# Get the non-None type
495+
non_none_args = [arg for arg in args if arg is not type(None)]
496+
if len(non_none_args) == 1:
497+
return f"Optional[{_get_simple_type_name(non_none_args[0])}]"
498+
499+
# Regular Union
500+
arg_strs = [_get_simple_type_name(arg) for arg in args]
501+
return f"Union[{', '.join(arg_strs)}]"
502+
472503
if hasattr(type_annotation, "__origin__"):
473504
# Handle generics like List, Dict, etc.
474505
origin = get_origin(type_annotation)
475506
args = get_args(type_annotation)
476507

477-
if origin is list or origin is List:
508+
if origin is list or str(origin).endswith("list"):
478509
if args:
479510
return f"List[{_get_simple_type_name(args[0])}]"
480511
return "List"
481-
elif origin is dict or origin is Dict:
512+
elif origin is dict or str(origin).endswith("dict"):
482513
if len(args) == 2:
483514
return f"Dict[{_get_simple_type_name(args[0])}, {_get_simple_type_name(args[1])}]"
484515
return "Dict"
485-
elif origin is Union:
516+
elif origin is Union or str(origin).endswith("Union"):
517+
# Check if this is equivalent to Optional[T]
518+
if len(args) == 2 and args[1] is type(None): # noqa
519+
return f"Optional[{_get_simple_type_name(args[0])}]"
520+
521+
# Regular Union
486522
arg_strs = [_get_simple_type_name(arg) for arg in args]
487523
return f"Union[{', '.join(arg_strs)}]"
488524
else:
489525
# Other generic types
490-
return str(type_annotation).replace("typing.", "")
526+
try:
527+
return str(type_annotation).replace("typing.", "")
528+
except:
529+
return "Any"
491530
elif hasattr(type_annotation, "__name__"):
492531
# Regular classes
493532
return type_annotation.__name__
494533
else:
495534
# Fallback
496-
return str(type_annotation).replace("typing.", "")
535+
try:
536+
return str(type_annotation).replace("typing.", "")
537+
except:
538+
return "Any"
497539

498540

499541
def get_model_fields(model_class):
@@ -531,19 +573,40 @@ def get_model_fields(model_class):
531573
except:
532574
pass
533575

534-
# Check for Optional type
576+
# Check for Optional type (PEP 604 union types)
535577
is_optional = False
578+
clean_type = field_type # Store the cleaned type
579+
580+
# Handle PEP 604 union types (X | Y) in Python 3.10+
581+
if PY310_OR_HIGHER:
582+
if (hasattr(field_type, "__or__") or
583+
(hasattr(field_type, "__origin__") and str(field_type.__origin__) == "types.UnionType")):
584+
args = getattr(field_type, "__args__", [])
585+
if any(arg is type(None) for arg in args):
586+
is_optional = True
587+
required = False
588+
# Extract the non-None type
589+
non_none_args = [arg for arg in args if arg is not type(None)]
590+
if len(non_none_args) == 1:
591+
clean_type = non_none_args[0]
592+
593+
# Check for traditional Union with None type
536594
if hasattr(field_type, "__origin__") and field_type.__origin__ is Union:
537595
args = field_type.__args__
538-
if len(args) == 2 and args[1] is type(None):
596+
if any(arg is type(None) for arg in args):
539597
is_optional = True
540598
required = False
599+
# Extract the non-None type
600+
non_none_args = [arg for arg in args if arg is not type(None)]
601+
if len(non_none_args) == 1:
602+
clean_type = non_none_args[0]
541603

542604
# Add the field info
543605
fields[field_name] = {
544-
"type": field_type,
606+
"type": clean_type,
545607
"required": required and not is_optional,
546-
"default": default
608+
"default": default,
609+
"optional": is_optional
547610
}
548611

549612
return fields

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "fastapi-mcp"
7-
version = "0.1.0"
7+
version = "0.1.1"
88
description = "Automatic MCP server generator for FastAPI applications - converts FastAPI endpoints to MCP tools for LLM integration"
99
readme = "README.md"
1010
requires-python = ">=3.10"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="fastapi-mcp",
5-
version="0.1.0",
5+
version="0.1.1",
66
description="Automatic MCP server generator for FastAPI applications - converts FastAPI endpoints to MCP tools for LLM integration",
77
author="Tadata Inc.",
88
author_email="[email protected]",

0 commit comments

Comments
 (0)