44
55import os
66import re
7+ import sys
78from pathlib import Path
89from typing import Dict , List , Optional , Union , Any , Type , get_type_hints , get_origin , get_args
910
1920from fastapi_mcp .converter import convert_endpoint_to_mcp_tool
2021from fastapi_mcp .discovery import Endpoint
2122
23+ # Check if Python version supports PEP 604 (|) union types
24+ PY310_OR_HIGHER = sys .version_info >= (3 , 10 )
25+
2226
2327def generate_mcp_server (
2428 app : FastAPI ,
@@ -172,33 +176,27 @@ def generate_server_code(
172176 fields = get_model_fields (model_class )
173177 for field_name , field_info in fields .items ():
174178 # Get type name
175- if hasattr (field_info ["type" ], "__origin__" ) and field_info ["type" ].__origin__ is Union :
176- args = field_info ["type" ].__args__
177- if len (args ) == 2 and args [1 ] is type (None ):
178- type_name = _get_simple_type_name (args [0 ])
179+ type_name = _get_simple_type_name (field_info ["type" ])
180+
181+ # Check if the field is optional
182+ is_optional = field_info .get ("optional" , False )
183+
184+ # Check if the field has a default value
185+ if field_info ["default" ] is None and field_info ["required" ]:
186+ # Required field with no default value - use Undefined
187+ model_lines .append (f" { field_name } : { type_name } = Undefined" )
188+ elif field_info ["default" ] is None and not field_info ["required" ]:
189+ # Optional field with None default
190+ if is_optional :
179191 model_lines .append (f" { field_name } : Optional[{ type_name } ] = None" )
180192 else :
181- type_name = _get_simple_type_name (field_info ["type" ])
182- model_lines .append (f" { field_name } : { type_name } " )
193+ model_lines .append (f" { field_name } : { type_name } = None" )
194+ elif isinstance (field_info ["default" ], str ):
195+ # String default value
196+ model_lines .append (f' { field_name } : { type_name } = "{ field_info ["default" ]} "' )
183197 else :
184- type_name = _get_simple_type_name (field_info ["type" ])
185-
186- # Check if the field has a default value
187- if field_info ["default" ] is None and field_info ["required" ]:
188- # Required field with no default value - use Undefined
189- model_lines .append (f" { field_name } : { type_name } = Undefined" )
190- elif field_info ["default" ] is None and not field_info ["required" ]:
191- # Optional field with None default
192- if type_name .startswith ("Optional[" ):
193- model_lines .append (f" { field_name } : { type_name } = None" )
194- else :
195- model_lines .append (f" { field_name } : Optional[{ type_name } ] = None" )
196- elif isinstance (field_info ["default" ], str ):
197- # String default value
198- model_lines .append (f' { field_name } : { type_name } = "{ field_info ["default" ]} "' )
199- else :
200- # Other default value
201- model_lines .append (f" { field_name } : { type_name } = { field_info ['default' ]} " )
198+ # Other default value
199+ model_lines .append (f" { field_name } : { type_name } = { field_info ['default' ]} " )
202200
203201 # Add the model definition
204202 server_code .extend (model_lines )
@@ -469,31 +467,75 @@ def _get_simple_type_name(type_annotation):
469467 Returns:
470468 A string representation of the type.
471469 """
470+ # Handle primitive types directly
471+ if type_annotation is str :
472+ return "str"
473+ elif type_annotation is int :
474+ return "int"
475+ elif type_annotation is float :
476+ return "float"
477+ elif type_annotation is bool :
478+ return "bool"
479+ elif type_annotation is list :
480+ return "List"
481+ elif type_annotation is dict :
482+ return "Dict"
483+ elif type_annotation is Any :
484+ return "Any"
485+
486+ # Handle PEP 604 union types (X | Y) in Python 3.10+
487+ if PY310_OR_HIGHER :
488+ if (hasattr (type_annotation , "__or__" ) or
489+ (hasattr (type_annotation , "__origin__" ) and str (type_annotation .__origin__ ) == "types.UnionType" )):
490+ args = getattr (type_annotation , "__args__" , [])
491+
492+ # Check if this is equivalent to Optional[T]
493+ if any (arg is type (None ) for arg in args ):
494+ # Get the non-None type
495+ non_none_args = [arg for arg in args if arg is not type (None )]
496+ if len (non_none_args ) == 1 :
497+ return f"Optional[{ _get_simple_type_name (non_none_args [0 ])} ]"
498+
499+ # Regular Union
500+ arg_strs = [_get_simple_type_name (arg ) for arg in args ]
501+ return f"Union[{ ', ' .join (arg_strs )} ]"
502+
472503 if hasattr (type_annotation , "__origin__" ):
473504 # Handle generics like List, Dict, etc.
474505 origin = get_origin (type_annotation )
475506 args = get_args (type_annotation )
476507
477- if origin is list or origin is List :
508+ if origin is list or str ( origin ). endswith ( "list" ) :
478509 if args :
479510 return f"List[{ _get_simple_type_name (args [0 ])} ]"
480511 return "List"
481- elif origin is dict or origin is Dict :
512+ elif origin is dict or str ( origin ). endswith ( "dict" ) :
482513 if len (args ) == 2 :
483514 return f"Dict[{ _get_simple_type_name (args [0 ])} , { _get_simple_type_name (args [1 ])} ]"
484515 return "Dict"
485- elif origin is Union :
516+ elif origin is Union or str (origin ).endswith ("Union" ):
517+ # Check if this is equivalent to Optional[T]
518+ if len (args ) == 2 and args [1 ] is type (None ): # noqa
519+ return f"Optional[{ _get_simple_type_name (args [0 ])} ]"
520+
521+ # Regular Union
486522 arg_strs = [_get_simple_type_name (arg ) for arg in args ]
487523 return f"Union[{ ', ' .join (arg_strs )} ]"
488524 else :
489525 # Other generic types
490- return str (type_annotation ).replace ("typing." , "" )
526+ try :
527+ return str (type_annotation ).replace ("typing." , "" )
528+ except :
529+ return "Any"
491530 elif hasattr (type_annotation , "__name__" ):
492531 # Regular classes
493532 return type_annotation .__name__
494533 else :
495534 # Fallback
496- return str (type_annotation ).replace ("typing." , "" )
535+ try :
536+ return str (type_annotation ).replace ("typing." , "" )
537+ except :
538+ return "Any"
497539
498540
499541def get_model_fields (model_class ):
@@ -531,19 +573,40 @@ def get_model_fields(model_class):
531573 except :
532574 pass
533575
534- # Check for Optional type
576+ # Check for Optional type (PEP 604 union types)
535577 is_optional = False
578+ clean_type = field_type # Store the cleaned type
579+
580+ # Handle PEP 604 union types (X | Y) in Python 3.10+
581+ if PY310_OR_HIGHER :
582+ if (hasattr (field_type , "__or__" ) or
583+ (hasattr (field_type , "__origin__" ) and str (field_type .__origin__ ) == "types.UnionType" )):
584+ args = getattr (field_type , "__args__" , [])
585+ if any (arg is type (None ) for arg in args ):
586+ is_optional = True
587+ required = False
588+ # Extract the non-None type
589+ non_none_args = [arg for arg in args if arg is not type (None )]
590+ if len (non_none_args ) == 1 :
591+ clean_type = non_none_args [0 ]
592+
593+ # Check for traditional Union with None type
536594 if hasattr (field_type , "__origin__" ) and field_type .__origin__ is Union :
537595 args = field_type .__args__
538- if len ( args ) == 2 and args [ 1 ] is type (None ):
596+ if any ( arg is type (None ) for arg in args ):
539597 is_optional = True
540598 required = False
599+ # Extract the non-None type
600+ non_none_args = [arg for arg in args if arg is not type (None )]
601+ if len (non_none_args ) == 1 :
602+ clean_type = non_none_args [0 ]
541603
542604 # Add the field info
543605 fields [field_name ] = {
544- "type" : field_type ,
606+ "type" : clean_type ,
545607 "required" : required and not is_optional ,
546- "default" : default
608+ "default" : default ,
609+ "optional" : is_optional
547610 }
548611
549612 return fields
0 commit comments