|
1 | 1 | import json |
2 | 2 | import re |
3 | | -from dataclasses import dataclass |
4 | 3 | from pathlib import Path |
5 | | -from textwrap import dedent, indent |
| 4 | +from textwrap import dedent |
6 | 5 | from typing import ( |
7 | 6 | Any, |
8 | 7 | Dict, |
9 | 8 | List, |
10 | 9 | Literal, |
11 | | - NewType, |
12 | 10 | Optional, |
13 | 11 | OrderedDict, |
14 | 12 | Sequence, |
|
21 | 19 | import black |
22 | 20 | from pydantic import BaseModel, Field, RootModel |
23 | 21 |
|
24 | | -TypeName = NewType("TypeName", str) |
25 | | -ModuleName = NewType("ModuleName", str) |
26 | | -ClassName = NewType("ClassName", str) |
27 | | -FileContents = NewType("FileContents", str) |
28 | | -HandshakeType = NewType("HandshakeType", str) |
29 | | - |
30 | | -RenderedPath = NewType("RenderedPath", str) |
31 | | - |
32 | | - |
33 | | -@dataclass |
34 | | -class DictTypeExpr: |
35 | | - nested: "TypeExpression" |
36 | | - |
37 | | - |
38 | | -@dataclass |
39 | | -class ListTypeExpr: |
40 | | - nested: "TypeExpression" |
41 | | - |
42 | | - |
43 | | -@dataclass |
44 | | -class LiteralTypeExpr: |
45 | | - nested: int | str |
46 | | - |
47 | | - |
48 | | -@dataclass |
49 | | -class UnionTypeExpr: |
50 | | - nested: list["TypeExpression"] |
51 | | - |
52 | | - |
53 | | -TypeExpression = ( |
54 | | - TypeName | DictTypeExpr | ListTypeExpr | LiteralTypeExpr | UnionTypeExpr |
| 22 | +from replit_river.codegen.format import reindent |
| 23 | +from replit_river.codegen.typing import ( |
| 24 | + ClassName, |
| 25 | + DictTypeExpr, |
| 26 | + FileContents, |
| 27 | + HandshakeType, |
| 28 | + ListTypeExpr, |
| 29 | + LiteralTypeExpr, |
| 30 | + ModuleName, |
| 31 | + RenderedPath, |
| 32 | + TypeExpression, |
| 33 | + TypeName, |
| 34 | + UnionTypeExpr, |
| 35 | + ensure_literal_type, |
| 36 | + extract_inner_type, |
| 37 | + render_type_expr, |
55 | 38 | ) |
56 | 39 |
|
57 | | - |
58 | | -def render_type_expr(value: TypeExpression) -> str: |
59 | | - match value: |
60 | | - case DictTypeExpr(nested): |
61 | | - return f"dict[str, {render_type_expr(nested)}]" |
62 | | - case ListTypeExpr(nested): |
63 | | - return f"list[{render_type_expr(nested)}]" |
64 | | - case LiteralTypeExpr(inner): |
65 | | - return f"Literal[{repr(inner)}]" |
66 | | - case UnionTypeExpr(inner): |
67 | | - return " | ".join(render_type_expr(x) for x in inner) |
68 | | - case other: |
69 | | - return other |
70 | | - |
71 | | - |
72 | | -def extract_inner_type(value: TypeExpression) -> TypeName: |
73 | | - match value: |
74 | | - case DictTypeExpr(nested): |
75 | | - return extract_inner_type(nested) |
76 | | - case ListTypeExpr(nested): |
77 | | - return extract_inner_type(nested) |
78 | | - case LiteralTypeExpr(_): |
79 | | - raise ValueError(f"Unexpected literal type: {value}") |
80 | | - case UnionTypeExpr(_): |
81 | | - raise ValueError( |
82 | | - f"Attempting to extract from a union, currently not possible: {value}" |
83 | | - ) |
84 | | - case other: |
85 | | - return other |
86 | | - |
87 | | - |
88 | | -def ensure_literal_type(value: TypeExpression) -> TypeName: |
89 | | - match value: |
90 | | - case DictTypeExpr(_): |
91 | | - raise ValueError( |
92 | | - f"Unexpected expression when expecting a type name: {value}" |
93 | | - ) |
94 | | - case ListTypeExpr(_): |
95 | | - raise ValueError( |
96 | | - f"Unexpected expression when expecting a type name: {value}" |
97 | | - ) |
98 | | - case LiteralTypeExpr(_): |
99 | | - raise ValueError( |
100 | | - f"Unexpected expression when expecting a type name: {value}" |
101 | | - ) |
102 | | - case UnionTypeExpr(_): |
103 | | - raise ValueError( |
104 | | - f"Unexpected expression when expecting a type name: {value}" |
105 | | - ) |
106 | | - case other: |
107 | | - return other |
108 | | - |
109 | | - |
110 | 40 | _NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+") |
111 | 41 |
|
112 | 42 | # Literal is here because HandshakeType can be Literal[None] |
@@ -214,14 +144,6 @@ class RiverSchema(BaseModel): |
214 | 144 | RiverSchemaFile = RootModel[RiverSchema] |
215 | 145 |
|
216 | 146 |
|
217 | | -def reindent(prefix: str, code: str) -> str: |
218 | | - """ |
219 | | - Take an arbitrarily indented code block, dedent to the lowest common |
220 | | - indent level and then reindent based on the supplied prefix |
221 | | - """ |
222 | | - return indent(dedent(code.rstrip()), prefix) |
223 | | - |
224 | | - |
225 | 147 | def is_literal(tpe: RiverType) -> bool: |
226 | 148 | if isinstance(tpe, RiverUnionType): |
227 | 149 | return all(is_literal(t) for t in tpe.anyOf) |
|
0 commit comments