Skip to content

Commit cafcc18

Browse files
bug/minor oddities (#116)
Why === None of these changes impact production codepaths, so we've presumably been skating by without issue up until now, but let's make sure we don't run into trouble in the future. What changed ============ - `f"x: '{foo}'` -> `f"x: {repr(foo)}"`. If it's a string, it'll be a no-op. If it's not, it'll be an error. - `RiverNotType` has never worked. Rip it out. - Address incorrect `init` vs `input` bindings - The result type of transport messages will be `None` in the case where the whole function also should return `None`, so remove some confusing metaprogramming about optional `return` statements. Test plan ========= _Describe what you did to test this change to a level of detail that allows your reviewer to test it_
1 parent b3afe94 commit cafcc18

File tree

2 files changed

+71
-85
lines changed

2 files changed

+71
-85
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2023 Repl.it
3+
Copyright (c) 2024 Replit
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

replit_river/codegen/client.py

Lines changed: 70 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pathlib import Path
55
from textwrap import dedent, indent
66
from typing import (
7-
Any,
87
Dict,
98
List,
109
Literal,
@@ -180,15 +179,7 @@ class RiverIntersectionType(BaseModel):
180179
allOf: List["RiverType"]
181180

182181

183-
class RiverNotType(BaseModel):
184-
"""This is used to represent void / never."""
185-
186-
not_: Any = Field(..., alias="not")
187-
188-
189-
RiverType = Union[
190-
RiverConcreteType, RiverUnionType, RiverNotType, RiverIntersectionType
191-
]
182+
RiverType = Union[RiverConcreteType, RiverUnionType, RiverIntersectionType]
192183

193184

194185
class RiverProcedure(BaseModel):
@@ -239,8 +230,6 @@ def encode_type(
239230
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
240231
encoder_name: Optional[str] = None # defining this up here to placate mypy
241232
chunks: List[FileContents] = []
242-
if isinstance(type, RiverNotType):
243-
return (TypeName("None"), [], [], set())
244233
if isinstance(type, RiverUnionType):
245234
typeddict_encoder = list[str]()
246235
encoder_names: set[TypeName] = set()
@@ -352,7 +341,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
352341
else:
353342
local_discriminator = "FIXME: Ambiguous discriminators"
354343
typeddict_encoder.append(
355-
f" if '{local_discriminator}' in x else "
344+
f" if {repr(local_discriminator)} in x else "
356345
)
357346
typeddict_encoder.pop() # Drop the last ternary
358347
typeddict_encoder.append(")")
@@ -372,8 +361,8 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
372361
typeddict_encoder.append(f"{encoder_name}(x)")
373362
typeddict_encoder.append(
374363
f"""
375-
if x['{discriminator_name}']
376-
== '{discriminator_value}'
364+
if x[{repr(discriminator_name)}]
365+
== {repr(discriminator_value)}
377366
else
378367
""",
379368
)
@@ -393,7 +382,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
393382
[
394383
dedent(
395384
f"""\
396-
{encoder_name}: Callable[['{prefix}'], Any] = (
385+
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
397386
lambda x:
398387
""".rstrip()
399388
)
@@ -450,14 +439,17 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
450439
chunks.append(
451440
FileContents(
452441
"\n".join(
453-
[f"{encoder_name}: Callable[['{prefix}'], Any] = (lambda x: "]
442+
[
443+
f"{encoder_name}: Callable[[{repr(prefix)}], Any] = ("
444+
"lambda x: "
445+
]
454446
+ typeddict_encoder
455447
+ [")"]
456448
)
457449
)
458450
)
459451
return (prefix, in_module, chunks, encoder_names)
460-
if isinstance(type, RiverIntersectionType):
452+
elif isinstance(type, RiverIntersectionType):
461453

462454
def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
463455
if isinstance(tpe, RiverUnionType):
@@ -478,15 +470,17 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
478470
base_model,
479471
in_module,
480472
)
481-
if isinstance(type, RiverConcreteType):
473+
elif isinstance(type, RiverConcreteType):
482474
typeddict_encoder = list[str]()
483475
if type.type is None:
484476
# Handle the case where type is not specified
485477
typeddict_encoder.append("x")
486478
return (TypeName("Any"), [], [], set())
479+
elif type.type == "not":
480+
return (TypeName("None"), [], [], set())
487481
elif type.type == "string":
488482
if type.const:
489-
typeddict_encoder.append(f"'{type.const}'")
483+
typeddict_encoder.append(repr(type.const))
490484
return (LiteralTypeExpr(type.const), [], [], set())
491485
else:
492486
typeddict_encoder.append("x")
@@ -565,48 +559,48 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
565559
name,
566560
prop,
567561
) in sorted(list(type.properties.items()), key=lambda xs: xs[0]):
568-
typeddict_encoder.append(f"'{name}':")
562+
typeddict_encoder.append(f"{repr(name)}:")
569563
type_name, _, contents, _ = encode_type(
570564
prop, TypeName(prefix + name.title()), base_model, in_module
571565
)
572566
encoder_name = None
573567
chunks.extend(contents)
574568
if base_model == "TypedDict":
575-
if isinstance(prop, RiverNotType):
576-
typeddict_encoder.append("'not implemented'")
577-
elif isinstance(prop, RiverUnionType):
569+
if isinstance(prop, RiverUnionType):
578570
encoder_name = TypeName(
579571
f"encode_{ensure_literal_type(type_name)}"
580572
)
581573
encoder_names.add(encoder_name)
582-
typeddict_encoder.append(f"{encoder_name}(x['{name}'])")
574+
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
583575
if name not in type.required:
584-
typeddict_encoder.append(f"if x['{name}'] else None")
576+
typeddict_encoder.append(f"if x[{repr(name)}] else None")
585577
elif isinstance(prop, RiverIntersectionType):
586578
encoder_name = TypeName(
587579
f"encode_{ensure_literal_type(type_name)}"
588580
)
589581
encoder_names.add(encoder_name)
590-
typeddict_encoder.append(f"{encoder_name}(x['{name}'])")
582+
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
591583
elif isinstance(prop, RiverConcreteType):
592584
if name == "$kind":
593585
safe_name = "kind"
594586
else:
595587
safe_name = name
596-
if prop.type == "object" and not prop.patternProperties:
588+
if prop.type == "not":
589+
typeddict_encoder.append("'not implemented'")
590+
elif prop.type == "object" and not prop.patternProperties:
597591
encoder_name = TypeName(
598592
f"encode_{ensure_literal_type(type_name)}"
599593
)
600594
encoder_names.add(encoder_name)
601595
typeddict_encoder.append(
602-
f"{encoder_name}(x['{safe_name}'])"
596+
f"{encoder_name}(x[{repr(safe_name)}])"
603597
)
604598
if name not in prop.required:
605599
typeddict_encoder.append(
606600
dedent(
607601
f"""
608-
if '{safe_name}' in x
609-
and x['{safe_name}'] is not None
602+
if {repr(safe_name)} in x
603+
and x[{repr(safe_name)}] is not None
610604
else None
611605
"""
612606
)
@@ -615,7 +609,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
615609
items = cast(RiverConcreteType, prop).items
616610
assert items, "Somehow items was none"
617611
if is_literal(cast(RiverType, items)):
618-
typeddict_encoder.append(f"x['{name}']")
612+
typeddict_encoder.append(f"x[{repr(name)}]")
619613
else:
620614
match type_name:
621615
case ListTypeExpr(inner_type_name):
@@ -628,16 +622,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
628622
f"""\
629623
[
630624
{encoder_name}(y)
631-
for y in x['{name}']
625+
for y in x[{repr(name)}]
632626
]
633627
""".rstrip()
634628
)
635629
)
636630
else:
637631
if name in prop.required:
638-
typeddict_encoder.append(f"x['{safe_name}']")
632+
typeddict_encoder.append(f"x[{repr(safe_name)}]")
639633
else:
640-
typeddict_encoder.append(f"x.get('{safe_name}')")
634+
typeddict_encoder.append(f"x.get({repr(safe_name)})")
641635

642636
if name == "$kind":
643637
# If the field is a literal, the Python type-checker will complain
@@ -657,7 +651,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
657651
f"""\
658652
= Field(
659653
default=None,
660-
alias='{name}', # type: ignore
654+
alias={repr(name)}, # type: ignore
661655
)
662656
"""
663657
)
@@ -671,7 +665,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
671665
f"""\
672666
= Field(
673667
{field_value},
674-
alias='{name}', # type: ignore
668+
alias={repr(name)}, # type: ignore
675669
)
676670
"""
677671
)
@@ -714,7 +708,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
714708
[
715709
dedent(
716710
f"""\
717-
{encoder_name}: Callable[['{prefix}'], Any] = (
711+
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
718712
lambda {binding}:
719713
"""
720714
)
@@ -781,7 +775,7 @@ def __init__(self, client: river.Client[Any]):
781775
module_names = [ModuleName(name)]
782776
init_type: Optional[TypeExpression] = None
783777
if procedure.init:
784-
init_type, module_info, input_chunks, encoder_names = encode_type(
778+
init_type, module_info, init_chunks, encoder_names = encode_type(
785779
procedure.init,
786780
TypeName(f"{name.title()}Init"),
787781
input_base_class,
@@ -791,7 +785,7 @@ def __init__(self, client: river.Client[Any]):
791785
(
792786
[extract_inner_type(init_type), *encoder_names],
793787
module_info,
794-
input_chunks,
788+
init_chunks,
795789
)
796790
)
797791
input_type, module_info, input_chunks, encoder_names = encode_type(
@@ -856,31 +850,28 @@ def __init__(self, client: river.Client[Any]):
856850

857851
# Init renderer
858852
render_init_method: Optional[str] = None
859-
if input_base_class == "TypedDict" and init_type:
860-
if is_literal(procedure.input):
861-
render_init_method = "lambda x: x"
862-
elif isinstance(
863-
procedure.input, RiverConcreteType
864-
) and procedure.input.type in ["array"]:
865-
match init_type:
866-
case ListTypeExpr(init_type_name):
867-
render_init_method = (
868-
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
869-
)
853+
if init_type and procedure.init is not None:
854+
if input_base_class == "TypedDict":
855+
if is_literal(procedure.init):
856+
render_init_method = "lambda x: x"
857+
elif isinstance(
858+
procedure.init, RiverConcreteType
859+
) and procedure.init.type in ["array"]:
860+
match init_type:
861+
case ListTypeExpr(init_type_name):
862+
render_init_method = (
863+
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
864+
)
865+
else:
866+
render_init_method = f"encode_{ensure_literal_type(init_type)}"
870867
else:
871-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
872-
else:
873-
render_init_method = f"""\
874-
lambda x: TypeAdapter({render_type_expr(input_type)})
875-
.validate_python
876-
"""
877-
if isinstance(
878-
procedure.init, RiverConcreteType
879-
) and procedure.init.type not in ["object", "array"]:
880-
render_init_method = "lambda x: x"
868+
render_init_method = f"""\
869+
lambda x: TypeAdapter({render_type_expr(init_type)})
870+
.validate_python
871+
"""
881872

882873
assert (
883-
render_init_method
874+
init_type is None or render_init_method
884875
), f"Unable to derive the init encoder from: {input_type}"
885876

886877
# Input renderer
@@ -922,10 +913,6 @@ def __init__(self, client: river.Client[Any]):
922913
parse_output_method = "lambda x: None"
923914

924915
if procedure.type == "rpc":
925-
control_flow_keyword = "return "
926-
if output_type == "None":
927-
control_flow_keyword = ""
928-
929916
current_chunks.extend(
930917
[
931918
reindent(
@@ -935,9 +922,9 @@ async def {name}(
935922
self,
936923
input: {render_type_expr(input_type)},
937924
) -> {render_type_expr(output_type)}:
938-
{control_flow_keyword}await self.client.send_rpc(
939-
'{schema_name}',
940-
'{name}',
925+
return await self.client.send_rpc(
926+
{repr(schema_name)},
927+
{repr(name)},
941928
input,
942929
{reindent(" ", render_input_method)},
943930
{reindent(" ", parse_output_method)},
@@ -958,8 +945,8 @@ async def {name}(
958945
input: {render_type_expr(input_type)},
959946
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
960947
return await self.client.send_subscription(
961-
'{schema_name}',
962-
'{name}',
948+
{repr(schema_name)},
949+
{repr(name)},
963950
input,
964951
{reindent(" ", render_input_method)},
965952
{reindent(" ", parse_output_method)},
@@ -970,10 +957,8 @@ async def {name}(
970957
]
971958
)
972959
elif procedure.type == "upload":
973-
control_flow_keyword = "return "
974-
if output_type == "None":
975-
control_flow_keyword = ""
976960
if init_type:
961+
assert render_init_method, "Expected an init renderer!"
977962
current_chunks.extend(
978963
[
979964
reindent(
@@ -984,9 +969,9 @@ async def {name}(
984969
init: {init_type},
985970
inputStream: AsyncIterable[{render_type_expr(input_type)}],
986971
) -> {output_type}:
987-
{control_flow_keyword}await self.client.send_upload(
988-
'{schema_name}',
989-
'{name}',
972+
return await self.client.send_upload(
973+
{repr(schema_name)},
974+
{repr(name)},
990975
init,
991976
inputStream,
992977
{reindent(" ", render_init_method)},
@@ -1008,9 +993,9 @@ async def {name}(
1008993
self,
1009994
inputStream: AsyncIterable[{render_type_expr(input_type)}],
1010995
) -> {render_type_expr(output_or_error_type)}:
1011-
{control_flow_keyword}await self.client.send_upload(
1012-
'{schema_name}',
1013-
'{name}',
996+
return await self.client.send_upload(
997+
{repr(schema_name)},
998+
{repr(name)},
1014999
None,
10151000
inputStream,
10161001
None,
@@ -1024,6 +1009,7 @@ async def {name}(
10241009
)
10251010
elif procedure.type == "stream":
10261011
if init_type:
1012+
assert render_init_method, "Expected an init renderer!"
10271013
current_chunks.extend(
10281014
[
10291015
reindent(
@@ -1035,8 +1021,8 @@ async def {name}(
10351021
inputStream: AsyncIterable[{render_type_expr(input_type)}],
10361022
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
10371023
return await self.client.send_stream(
1038-
'{schema_name}',
1039-
'{name}',
1024+
{repr(schema_name)},
1025+
{repr(name)},
10401026
init,
10411027
inputStream,
10421028
{reindent(" ", render_init_method)},
@@ -1059,8 +1045,8 @@ async def {name}(
10591045
inputStream: AsyncIterable[{render_type_expr(input_type)}],
10601046
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
10611047
return await self.client.send_stream(
1062-
'{schema_name}',
1063-
'{name}',
1048+
{repr(schema_name)},
1049+
{repr(name)},
10641050
None,
10651051
inputStream,
10661052
None,

0 commit comments

Comments
 (0)