Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FileContents,
HandshakeType,
ListTypeExpr,
LiteralType,
LiteralTypeExpr,
ModuleName,
NoneTypeExpr,
Expand All @@ -33,6 +34,7 @@
TypeName,
UnionTypeExpr,
extract_inner_type,
normalize_special_chars,
render_literal_type,
render_type_expr,
)
Expand Down Expand Up @@ -396,9 +398,12 @@ def {_field_name}(
case NoneTypeExpr():
typeddict_encoder.append("None")
case other:
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
other
)
_o2: (
DictTypeExpr
| OpenUnionTypeExpr
| UnionTypeExpr
| LiteralType
) = other
raise ValueError(f"What does it mean to have {_o2} here?")
if permit_unknown_members:
union = _make_open_union_type_expr(any_of)
Expand Down Expand Up @@ -491,7 +496,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
return (NoneTypeExpr(), [], [], set())
elif type.type == "Date":
typeddict_encoder.append("TODO: dstewart")
return (TypeName("datetime.datetime"), [], [], set())
return (LiteralType("datetime.datetime"), [], [], set())
elif type.type == "array" and type.items:
type_name, module_info, type_chunks, encoder_names = encode_type(
type.items,
Expand Down Expand Up @@ -692,8 +697,21 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
)
)
else:
specialized_name = normalize_special_chars(name)
effective_name = name
extras = ""
if name != specialized_name:
if base_model != "BaseModel":
# TODO: alias support for TypedDict
raise ValueError(
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
)
# Pydantic doesn't allow leading underscores in field names
effective_name = specialized_name.lstrip("_")
extras = f" = Field(serialization_alias={repr(name)})"

current_chunks.append(
f" {name}: {render_type_expr(type_name)}"
f" {effective_name}: {render_type_expr(type_name)}{extras}"
)
typeddict_encoder.append(",")
typeddict_encoder.append("}")
Expand Down
31 changes: 30 additions & 1 deletion src/replit_river/codegen/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import NewType, assert_never, cast

SPECIAL_CHARS = [".", "-", ":", "/", "@", " ", "$", "!", "?", "=", "&", "|", "~", "`"]

ModuleName = NewType("ModuleName", str)
ClassName = NewType("ClassName", str)
FileContents = NewType("FileContents", str)
Expand All @@ -23,6 +25,20 @@ def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class LiteralType:
value: str

def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, LiteralType) and other.value == self.value

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class NoneTypeExpr:
def __str__(self) -> str:
Expand Down Expand Up @@ -111,6 +127,7 @@ def __lt__(self, other: object) -> bool:

TypeExpression = (
TypeName
| LiteralType
| NoneTypeExpr
| DictTypeExpr
| ListTypeExpr
Expand Down Expand Up @@ -145,6 +162,12 @@ def work(
raise ValueError("Incoherent state when trying to flatten unions")


def normalize_special_chars(value: str) -> str:
for char in SPECIAL_CHARS:
value = value.replace(char, "_")
return value


def render_type_expr(value: TypeExpression) -> str:
match _flatten_nested_unions(value):
case DictTypeExpr(nested):
Expand Down Expand Up @@ -192,7 +215,9 @@ def render_type_expr(value: TypeExpression) -> str:
"]"
)
case TypeName(name):
return name
return normalize_special_chars(name)
case LiteralType(literal_value):
return literal_value
case NoneTypeExpr():
return "None"
case other:
Expand Down Expand Up @@ -223,6 +248,10 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
)
case TypeName(name):
return TypeName(name)
case LiteralType(name):
raise ValueError(
f"Attempting to extract from a literal type: {repr(value)}"
)
case NoneTypeExpr():
raise ValueError(
f"Attempting to extract from a literal 'None': {repr(value)}",
Expand Down
Loading