Skip to content

Commit 5418c4d

Browse files
fix: support Pydantic BaseModel as prediction output type in schema gen (#2785)
The tree-sitter schema parser only recognized cog.BaseModel subclasses. Pydantic BaseModel subclasses failed with 'unsupported type' because:

1. IsBaseModel() only checked module=='cog', not 'pydantic'
2. inheritsFromBaseModel() only handled 'identifier' AST nodes, missing 'attribute' nodes for dotted access (pydantic.BaseModel)
3. parseImportFrom() only handled aliased_import inside import_list, missing top-level aliased imports (from X import Y as Z)

Now supports all three import styles:

- from pydantic import BaseModel
- from pydantic import BaseModel as PydanticBaseModel
- import pydantic; class Foo(pydantic.BaseModel)

Includes 3 unit tests and 1 integration test.
1 parent 838c013 commit 5418c4d

File tree

4 files changed

+132
-3
lines changed

4 files changed

+132
-3
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Test that Pydantic v2 BaseModel works as prediction output type.
2+
# Coglet's make_encodeable() must call model_dump() to serialize.
3+
4+
# Build the image
5+
cog build -t $TEST_IMAGE
6+
7+
# Predict returns structured Pydantic output
8+
cog predict $TEST_IMAGE -i name=alice -i score=0.95
9+
stdout '"name": "alice"'
10+
stdout '"score": 0.95'
11+
stdout '"tags"'
12+
stdout 'default'
13+
14+
-- cog.yaml --
15+
build:
16+
python_version: "3.12"
17+
python_packages:
18+
- "pydantic>2"
19+
predict: "predict.py:Predictor"
20+
21+
-- predict.py --
22+
from typing import List
23+
24+
from pydantic import BaseModel as PydanticBaseModel
25+
26+
from cog import BasePredictor
27+
28+
29+
class Result(PydanticBaseModel):
30+
name: str
31+
score: float
32+
tags: List[str]
33+
34+
35+
class Predictor(BasePredictor):
36+
def predict(self, name: str, score: float = 0.5) -> Result:
37+
return Result(name=name, score=score, tags=["default"])

pkg/schema/python/parser.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext
161161
name := content(child, source)
162162
ctx.Names.Set(name, schema.ImportEntry{Module: module, Original: name})
163163
}
164+
case "aliased_import":
165+
// Single aliased import: `from X import name as alias`
166+
origNode := child.ChildByFieldName("name")
167+
aliasNode := child.ChildByFieldName("alias")
168+
orig := ""
169+
if origNode != nil {
170+
orig = content(origNode, source)
171+
}
172+
alias := orig
173+
if aliasNode != nil {
174+
alias = content(aliasNode, source)
175+
}
176+
ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig})
164177
case "import_list":
165178
for _, importChild := range allChildren(child) {
166179
switch importChild.Type() {
@@ -404,11 +417,18 @@ func inheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schem
404417
return false
405418
}
406419
for _, child := range allChildren(supers) {
407-
if child.Type() == "identifier" {
420+
switch child.Type() {
421+
case "identifier":
408422
name := content(child, source)
409423
if imports.IsBaseModel(name) || name == "BaseModel" {
410424
return true
411425
}
426+
case "attribute":
427+
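// Handle dotted access such as pydantic.BaseModel or cog.BaseModel. The module prefix is not resolved: any attribute whose text ends in ".BaseModel" is accepted.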
428+
text := content(child, source)
429+
if strings.HasSuffix(text, ".BaseModel") {
430+
return true
431+
}
412432
}
413433
}
414434
return false

pkg/schema/python/parser_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,78 @@ class Predictor(BasePredictor):
10021002
require.Equal(t, 0.0, conf.Default.Float)
10031003
}
10041004

1005+
// ---------------------------------------------------------------------------
1006+
// Pydantic BaseModel output
1007+
// ---------------------------------------------------------------------------
1008+
1009+
func TestPydanticBaseModelOutput(t *testing.T) {
1010+
source := `
1011+
from pydantic import BaseModel as PydanticBaseModel
1012+
from cog import BasePredictor
1013+
1014+
class Result(PydanticBaseModel):
1015+
name: str
1016+
score: float
1017+
tags: list[str]
1018+
1019+
class Predictor(BasePredictor):
1020+
def predict(self, name: str) -> Result:
1021+
pass
1022+
`
1023+
info := parse(t, source, "Predictor")
1024+
require.Equal(t, schema.OutputObject, info.Output.Kind)
1025+
require.NotNil(t, info.Output.Fields)
1026+
require.Equal(t, 3, info.Output.Fields.Len())
1027+
1028+
name, ok := info.Output.Fields.Get("name")
1029+
require.True(t, ok)
1030+
require.Equal(t, schema.TypeString, name.FieldType.Primitive)
1031+
1032+
score, ok := info.Output.Fields.Get("score")
1033+
require.True(t, ok)
1034+
require.Equal(t, schema.TypeFloat, score.FieldType.Primitive)
1035+
}
1036+
1037+
func TestPydanticBaseModelDottedOutput(t *testing.T) {
1038+
source := `
1039+
import pydantic
1040+
from cog import BasePredictor
1041+
1042+
class Result(pydantic.BaseModel):
1043+
text: str
1044+
1045+
class Predictor(BasePredictor):
1046+
def predict(self, s: str) -> Result:
1047+
pass
1048+
`
1049+
info := parse(t, source, "Predictor")
1050+
require.Equal(t, schema.OutputObject, info.Output.Kind)
1051+
1052+
text, ok := info.Output.Fields.Get("text")
1053+
require.True(t, ok)
1054+
require.Equal(t, schema.TypeString, text.FieldType.Primitive)
1055+
}
1056+
1057+
func TestPydanticBaseModelDirectImport(t *testing.T) {
1058+
source := `
1059+
from pydantic import BaseModel
1060+
from cog import BasePredictor
1061+
1062+
class Output(BaseModel):
1063+
value: int
1064+
1065+
class Predictor(BasePredictor):
1066+
def predict(self, x: int) -> Output:
1067+
pass
1068+
`
1069+
info := parse(t, source, "Predictor")
1070+
require.Equal(t, schema.OutputObject, info.Output.Kind)
1071+
1072+
val, ok := info.Output.Fields.Get("value")
1073+
require.True(t, ok)
1074+
require.Equal(t, schema.TypeInteger, val.FieldType.Primitive)
1075+
}
1076+
10051077
// ---------------------------------------------------------------------------
10061078
// No-input predictor (only self)
10071079
// ---------------------------------------------------------------------------

pkg/schema/types.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,10 @@ func (ctx *ImportContext) IsTypingType(name string) bool {
340340
return false
341341
}
342342

343-
// IsBaseModel returns true if name resolves to cog.BaseModel.
343+
// IsBaseModel returns true if name resolves to cog.BaseModel, pydantic.BaseModel, or pydantic.v1.BaseModel.
344344
func (ctx *ImportContext) IsBaseModel(name string) bool {
345345
if e, ok := ctx.Names.Get(name); ok {
346-
return e.Module == "cog" && e.Original == "BaseModel"
346+
return (e.Module == "cog" || e.Module == "pydantic" || e.Module == "pydantic.v1") && e.Original == "BaseModel"
347347
}
348348
return false
349349
}

0 commit comments

Comments
 (0)