diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index ea4dec85..d4e69ddf 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -646,42 +646,24 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: case LiteralTypeExpr(literal_value): field_value = repr(literal_value) if name not in type.required: + type_name = UnionTypeExpr( + [ + type_name, + NoneTypeExpr(), + ] + ) value = "" if base_model != "TypedDict": - value = dedent( - f"""\ - = Field( - default=None, - alias={repr(name)}, # type: ignore - ) - """ - ) - current_chunks.append( - f" kind: { - render_type_expr( - UnionTypeExpr( - [ - type_name, - NoneTypeExpr(), - ] - ) - ) - }{value}" - ) + value = f"= {repr(None)}" else: value = "" if base_model != "TypedDict": - value = dedent( - f"""\ - = Field( - {field_value}, - alias={repr(name)}, # type: ignore - ) - """ - ) - current_chunks.append( - f" kind: {render_type_expr(type_name)}{value}" - ) + value = f"= {field_value}" + current_chunks.append( + f" kind: Annotated[{render_type_expr(type_name)}, Field(alias={ + repr(name) + })]{value}" + ) else: if name not in type.required: if base_model == "TypedDict": diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 9d74699a..4e1243a3 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -38,7 +38,7 @@ def encode_NeedsenumobjectInputOneOf_in_first( class NeedsenumobjectInputOneOf_in_first(TypedDict): - kind: Literal["in_first"] + kind: Annotated[Literal["in_first"], Field(alias="$kind")] value: str @@ -58,7 +58,7 @@ def encode_NeedsenumobjectInputOneOf_in_second( class NeedsenumobjectInputOneOf_in_second(TypedDict): - kind: Literal["in_second"] + kind: Annotated[Literal["in_second"], Field(alias="$kind")] bleep: int @@ -83,20 +83,12 @@ def encode_NeedsenumobjectInput( class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): - kind: Literal["out_first"] = Field( - "out_first", - alias="$kind", # type: ignore - ) - + kind: Annotated[Literal["out_first"], Field(alias="$kind")] = "out_first" foo: int class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): - kind: Literal["out_second"] = Field( - "out_second", - alias="$kind", # type: ignore - ) - + kind: Annotated[Literal["out_second"], Field(alias="$kind")] = "out_second" bar: int diff --git a/tests/codegen/snapshot/test_enum.py b/tests/codegen/snapshot/test_enum.py index d9953dee..e45509d2 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/codegen/snapshot/test_enum.py @@ -1,4 +1,5 @@ import importlib +import json from io import StringIO from pytest_snapshot.plugin import Snapshot @@ -203,3 +204,46 @@ def test_unknown_enum(snapshot: Snapshot) -> None: x = NeedsenumErrorsTypeAdapter.validate_python(error) assert x.code == error["code"] assert x.message == error["message"] + + +def test_unknown_enum_field_aliases(snapshot: Snapshot) -> None: + validate_codegen( + snapshot=snapshot, + read_schema=lambda: StringIO(test_unknown_enum_schema), + target_path="test_unknown_enum", + client_name="foo", + ) + + import tests.codegen.snapshot.snapshots.test_unknown_enum + + importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) + from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnumObject import ( # noqa + NeedsenumobjectOutputTypeAdapter, + NeedsenumobjectOutput, + NeedsenumobjectOutputFooOneOf_out_first, + ) + + initial = NeedsenumobjectOutput(foo=NeedsenumobjectOutputFooOneOf_out_first(foo=5)) + result = NeedsenumobjectOutputTypeAdapter.dump_json( + initial, + by_alias=True, + ) + + obj = json.loads(result) + + # Make sure we are testing what we think we are testing + assert "foo" in obj + + # We must not include the un-aliased field name + assert "kind" not in obj["foo"] + + # We must include the aliased field name + assert "$kind" in obj["foo"] + + # ... and finally that the values are what we think they should be + assert obj["foo"]["$kind"] == "out_first" + assert obj["foo"]["foo"] == 5 + + # And one more sanity check for the decoder + decoded = NeedsenumobjectOutputTypeAdapter.validate_json(result) + assert decoded == initial