4
4
5
5
import os
6
6
import re
7
+ import sys
7
8
from pathlib import Path
8
9
from typing import Dict , List , Optional , Union , Any , Type , get_type_hints , get_origin , get_args
9
10
19
20
from fastapi_mcp .converter import convert_endpoint_to_mcp_tool
20
21
from fastapi_mcp .discovery import Endpoint
21
22
23
+ # Check if Python version supports PEP 604 (|) union types
24
+ PY310_OR_HIGHER = sys .version_info >= (3 , 10 )
25
+
22
26
23
27
def generate_mcp_server (
24
28
app : FastAPI ,
@@ -172,33 +176,27 @@ def generate_server_code(
172
176
fields = get_model_fields (model_class )
173
177
for field_name , field_info in fields .items ():
174
178
# Get type name
175
- if hasattr (field_info ["type" ], "__origin__" ) and field_info ["type" ].__origin__ is Union :
176
- args = field_info ["type" ].__args__
177
- if len (args ) == 2 and args [1 ] is type (None ):
178
- type_name = _get_simple_type_name (args [0 ])
179
+ type_name = _get_simple_type_name (field_info ["type" ])
180
+
181
+ # Check if the field is optional
182
+ is_optional = field_info .get ("optional" , False )
183
+
184
+ # Check if the field has a default value
185
+ if field_info ["default" ] is None and field_info ["required" ]:
186
+ # Required field with no default value - use Undefined
187
+ model_lines .append (f" { field_name } : { type_name } = Undefined" )
188
+ elif field_info ["default" ] is None and not field_info ["required" ]:
189
+ # Optional field with None default
190
+ if is_optional :
179
191
model_lines .append (f" { field_name } : Optional[{ type_name } ] = None" )
180
192
else :
181
- type_name = _get_simple_type_name (field_info ["type" ])
182
- model_lines .append (f" { field_name } : { type_name } " )
193
+ model_lines .append (f" { field_name } : { type_name } = None" )
194
+ elif isinstance (field_info ["default" ], str ):
195
+ # String default value
196
+ model_lines .append (f' { field_name } : { type_name } = "{ field_info ["default" ]} "' )
183
197
else :
184
- type_name = _get_simple_type_name (field_info ["type" ])
185
-
186
- # Check if the field has a default value
187
- if field_info ["default" ] is None and field_info ["required" ]:
188
- # Required field with no default value - use Undefined
189
- model_lines .append (f" { field_name } : { type_name } = Undefined" )
190
- elif field_info ["default" ] is None and not field_info ["required" ]:
191
- # Optional field with None default
192
- if type_name .startswith ("Optional[" ):
193
- model_lines .append (f" { field_name } : { type_name } = None" )
194
- else :
195
- model_lines .append (f" { field_name } : Optional[{ type_name } ] = None" )
196
- elif isinstance (field_info ["default" ], str ):
197
- # String default value
198
- model_lines .append (f' { field_name } : { type_name } = "{ field_info ["default" ]} "' )
199
- else :
200
- # Other default value
201
- model_lines .append (f" { field_name } : { type_name } = { field_info ['default' ]} " )
198
+ # Other default value
199
+ model_lines .append (f" { field_name } : { type_name } = { field_info ['default' ]} " )
202
200
203
201
# Add the model definition
204
202
server_code .extend (model_lines )
@@ -469,31 +467,75 @@ def _get_simple_type_name(type_annotation):
469
467
Returns:
470
468
A string representation of the type.
471
469
"""
470
+ # Handle primitive types directly
471
+ if type_annotation is str :
472
+ return "str"
473
+ elif type_annotation is int :
474
+ return "int"
475
+ elif type_annotation is float :
476
+ return "float"
477
+ elif type_annotation is bool :
478
+ return "bool"
479
+ elif type_annotation is list :
480
+ return "List"
481
+ elif type_annotation is dict :
482
+ return "Dict"
483
+ elif type_annotation is Any :
484
+ return "Any"
485
+
486
+ # Handle PEP 604 union types (X | Y) in Python 3.10+
487
+ if PY310_OR_HIGHER :
488
+ if (hasattr (type_annotation , "__or__" ) or
489
+ (hasattr (type_annotation , "__origin__" ) and str (type_annotation .__origin__ ) == "types.UnionType" )):
490
+ args = getattr (type_annotation , "__args__" , [])
491
+
492
+ # Check if this is equivalent to Optional[T]
493
+ if any (arg is type (None ) for arg in args ):
494
+ # Get the non-None type
495
+ non_none_args = [arg for arg in args if arg is not type (None )]
496
+ if len (non_none_args ) == 1 :
497
+ return f"Optional[{ _get_simple_type_name (non_none_args [0 ])} ]"
498
+
499
+ # Regular Union
500
+ arg_strs = [_get_simple_type_name (arg ) for arg in args ]
501
+ return f"Union[{ ', ' .join (arg_strs )} ]"
502
+
472
503
if hasattr (type_annotation , "__origin__" ):
473
504
# Handle generics like List, Dict, etc.
474
505
origin = get_origin (type_annotation )
475
506
args = get_args (type_annotation )
476
507
477
- if origin is list or origin is List :
508
+ if origin is list or str ( origin ). endswith ( "list" ) :
478
509
if args :
479
510
return f"List[{ _get_simple_type_name (args [0 ])} ]"
480
511
return "List"
481
- elif origin is dict or origin is Dict :
512
+ elif origin is dict or str ( origin ). endswith ( "dict" ) :
482
513
if len (args ) == 2 :
483
514
return f"Dict[{ _get_simple_type_name (args [0 ])} , { _get_simple_type_name (args [1 ])} ]"
484
515
return "Dict"
485
- elif origin is Union :
516
+ elif origin is Union or str (origin ).endswith ("Union" ):
517
+ # Check if this is equivalent to Optional[T]
518
+ if len (args ) == 2 and args [1 ] is type (None ): # noqa
519
+ return f"Optional[{ _get_simple_type_name (args [0 ])} ]"
520
+
521
+ # Regular Union
486
522
arg_strs = [_get_simple_type_name (arg ) for arg in args ]
487
523
return f"Union[{ ', ' .join (arg_strs )} ]"
488
524
else :
489
525
# Other generic types
490
- return str (type_annotation ).replace ("typing." , "" )
526
+ try :
527
+ return str (type_annotation ).replace ("typing." , "" )
528
+ except :
529
+ return "Any"
491
530
elif hasattr (type_annotation , "__name__" ):
492
531
# Regular classes
493
532
return type_annotation .__name__
494
533
else :
495
534
# Fallback
496
- return str (type_annotation ).replace ("typing." , "" )
535
+ try :
536
+ return str (type_annotation ).replace ("typing." , "" )
537
+ except :
538
+ return "Any"
497
539
498
540
499
541
def get_model_fields (model_class ):
@@ -531,19 +573,40 @@ def get_model_fields(model_class):
531
573
except :
532
574
pass
533
575
534
- # Check for Optional type
576
+ # Check for Optional type (PEP 604 union types)
535
577
is_optional = False
578
+ clean_type = field_type # Store the cleaned type
579
+
580
+ # Handle PEP 604 union types (X | Y) in Python 3.10+
581
+ if PY310_OR_HIGHER :
582
+ if (hasattr (field_type , "__or__" ) or
583
+ (hasattr (field_type , "__origin__" ) and str (field_type .__origin__ ) == "types.UnionType" )):
584
+ args = getattr (field_type , "__args__" , [])
585
+ if any (arg is type (None ) for arg in args ):
586
+ is_optional = True
587
+ required = False
588
+ # Extract the non-None type
589
+ non_none_args = [arg for arg in args if arg is not type (None )]
590
+ if len (non_none_args ) == 1 :
591
+ clean_type = non_none_args [0 ]
592
+
593
+ # Check for traditional Union with None type
536
594
if hasattr (field_type , "__origin__" ) and field_type .__origin__ is Union :
537
595
args = field_type .__args__
538
- if len ( args ) == 2 and args [ 1 ] is type (None ):
596
+ if any ( arg is type (None ) for arg in args ):
539
597
is_optional = True
540
598
required = False
599
+ # Extract the non-None type
600
+ non_none_args = [arg for arg in args if arg is not type (None )]
601
+ if len (non_none_args ) == 1 :
602
+ clean_type = non_none_args [0 ]
541
603
542
604
# Add the field info
543
605
fields [field_name ] = {
544
- "type" : field_type ,
606
+ "type" : clean_type ,
545
607
"required" : required and not is_optional ,
546
- "default" : default
608
+ "default" : default ,
609
+ "optional" : is_optional
547
610
}
548
611
549
612
return fields
0 commit comments