Skip to content

Commit 1384119

Browse files
authored
Allow complex output types (#218)
* Allow complex output types - Removes an assertion during schema introspection that prevented list types as a field. - Adds a test verifying that excessively complex output types work correctly * I think moving out of the bad_predictors dir stops evaluating ERROR? * add schema * go e2e test * fix encoding * regenerate type stubs, fix command in CI check error
1 parent 1e3ff7f commit 1384119

File tree

10 files changed

+689
-24
lines changed

10 files changed

+689
-24
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
echo "Generated stubs differ from committed versions:"
6363
git diff python/
6464
echo ""
65-
echo "Run 'script/check.sh stubs' locally and commit the updated .pyi files"
65+
echo "Run 'script/check.sh python' locally and commit the updated .pyi files"
6666
exit 1
6767
fi
6868
echo "All stub files are up to date"

internal/tests/harness_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,17 @@ func writeCogConfig(t *testing.T, tempDir, predictorClass string, concurrencyMax
506506
// FIXME: this is a hack to provide compatibility with the `cog_test` test harness while we migrate to in-process testing.
507507
func linkPythonModule(t *testing.T, basePath, tempDir, module string) {
508508
t.Helper()
509+
510+
// Try runners directory first (for backward compatibility)
509511
runnersPath := path.Join(basePath, "python", "tests", "runners")
510512
srcPath := path.Join(runnersPath, fmt.Sprintf("%s.py", module))
513+
514+
// If not found in runners, try schemas directory
515+
if _, err := os.Stat(srcPath); os.IsNotExist(err) {
516+
schemasPath := path.Join(basePath, "python", "tests", "schemas")
517+
srcPath = path.Join(schemasPath, fmt.Sprintf("%s.py", module))
518+
}
519+
511520
dstPath := path.Join(tempDir, "predict.py")
512521

513522
// Debug logging

internal/tests/output_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,102 @@ func TestPredictionOutputSucceeded(t *testing.T) {
5151
}
5252
assert.Equal(t, expectedOutput, predictionResponse.Output)
5353
}
54+
55+
func TestComplexOutputTypes(t *testing.T) {
56+
t.Parallel()
57+
if *legacyCog {
58+
t.Skip("legacy Cog does not support complex output types")
59+
}
60+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
61+
procedureMode: false,
62+
explicitShutdown: true,
63+
uploadURL: "",
64+
module: "output_complex_types",
65+
predictorClass: "Predictor",
66+
})
67+
waitForSetupComplete(t, runtimeServer, runner.StatusReady, runner.SetupSucceeded)
68+
69+
input := map[string]any{"s": "test"}
70+
req := httpPredictionRequest(t, runtimeServer, runner.PredictionRequest{Input: input})
71+
resp, err := http.DefaultClient.Do(req)
72+
require.NoError(t, err)
73+
defer resp.Body.Close()
74+
assert.Equal(t, http.StatusOK, resp.StatusCode)
75+
body, err := io.ReadAll(resp.Body)
76+
require.NoError(t, err)
77+
var predictionResponse server.PredictionResponse
78+
err = json.Unmarshal(body, &predictionResponse)
79+
require.NoError(t, err)
80+
81+
// Create expected output using JSON round-trip to match server serialization
82+
expectedOutputs := []map[string]any{
83+
{
84+
"strings": []string{"hello", "world"},
85+
"numbers": []int{1, 2, 3},
86+
"single_item": map[string]any{
87+
"name": "item1",
88+
"value": 42,
89+
},
90+
"items": []map[string]any{
91+
{"name": "item1", "value": 42},
92+
{"name": "item2", "value": 84},
93+
},
94+
"container": map[string]any{
95+
"items": []map[string]any{
96+
{"name": "item1", "value": 42},
97+
{"name": "item2", "value": 84},
98+
},
99+
"tags": []string{"tag1", "tag2"},
100+
"nested": map[string]any{
101+
"item": map[string]any{"name": "item1", "value": 42},
102+
"description": "nested description",
103+
},
104+
"optional_list": []string{"opt1", "opt2"},
105+
"count": 2,
106+
},
107+
"nested_items": []map[string]any{
108+
{
109+
"item": map[string]any{"name": "item1", "value": 42},
110+
"description": "nested description",
111+
},
112+
},
113+
},
114+
{
115+
"strings": []string{"foo", "bar"},
116+
"numbers": []int{4, 5, 6},
117+
"single_item": map[string]any{
118+
"name": "item2",
119+
"value": 84,
120+
},
121+
"items": []map[string]any{
122+
{"name": "item2", "value": 84},
123+
},
124+
"container": map[string]any{
125+
"items": []map[string]any{
126+
{"name": "item1", "value": 42},
127+
{"name": "item2", "value": 84},
128+
},
129+
"tags": []string{"tag1", "tag2"},
130+
"nested": map[string]any{
131+
"item": map[string]any{"name": "item1", "value": 42},
132+
"description": "nested description",
133+
},
134+
"optional_list": []string{"opt1", "opt2"},
135+
"count": 2,
136+
},
137+
"nested_items": []map[string]any{
138+
{
139+
"item": map[string]any{"name": "item1", "value": 42},
140+
"description": "nested description",
141+
},
142+
},
143+
},
144+
}
145+
expectedJSON, err := json.Marshal(expectedOutputs)
146+
require.NoError(t, err)
147+
var expectedOutput []any
148+
err = json.Unmarshal(expectedJSON, &expectedOutput)
149+
require.NoError(t, err)
150+
assert.Equal(t, expectedOutput, predictionResponse.Output)
151+
assert.Equal(t, runner.PredictionSucceeded, predictionResponse.Status)
152+
}

python/coglet/inspector.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,6 @@ def _output_adt(tpe: type) -> adt.Output:
187187
fields = {}
188188
for name, t in tpe.__annotations__.items():
189189
ft = adt.FieldType.from_type(t)
190-
assert ft.repetition is not adt.Repetition.REPEATED, (
191-
f'output field must not be list: {name}: {ft.python_type()}'
192-
)
193190
fields[name] = ft
194191
return adt.Output(kind=adt.Kind.OBJECT, fields=fields)
195192

python/coglet/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def output_json(obj):
1919
elif tpe is api.Secret:
2020
# Encode Secret('foobar') as '**********'
2121
return str(obj)
22+
elif hasattr(obj, '__dict__') and hasattr(obj, '__dataclass_fields__'):
23+
# Handle dataclass objects (including BaseModel)
24+
return {field: getattr(obj, field) for field in obj.__dataclass_fields__}
2225
else:
2326
raise TypeError(f'Object of type {tpe} is not JSON serializable')
2427

python/coglet/util.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ This type stub file was generated by pyright.
55
def now_iso() -> str:
66
...
77

8-
def output_json(obj): # -> str:
8+
def output_json(obj): # -> str | dict[Any, Any]:
99
...
1010

1111
def schema_json(obj): # -> str:

python/tests/bad_predictors/output_list_field.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)