Skip to content

Commit b9e7214

Browse files
authored
Add integration test for float with cog-runtime (#2486)
* Add integration test for float with cog-runtime * Check if cog-runtime and cog combined can receive a float input from the CLI * Handle number in openapi schema explicitly * Parse a number in the JSON to the inputs * This fixes the CLI issue with cog-runtime * Add support for integers * Explicitly support integer * Add more test cases * Remove test for int with float
1 parent 23b93cc commit b9e7214

File tree

6 files changed

+109
-2
lines changed

6 files changed

+109
-2
lines changed

pkg/predict/input.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"path/filepath"
88
"reflect"
9+
"strconv"
910
"strings"
1011

1112
"github.com/getkin/kin-openapi/openapi3"
@@ -20,6 +21,8 @@ type Input struct {
2021
File *string
2122
Array *[]any
2223
Json *json.RawMessage
24+
Float *float32
25+
Int *int32
2326
}
2427

2528
type Inputs map[string]Input
@@ -51,11 +54,12 @@ func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error)
5154
property, err := propertiesSchemas.JSONLookup(key)
5255
if err == nil {
5356
propertySchema := property.(*openapi3.Schema)
54-
if propertySchema.Type.Is("object") {
57+
switch {
58+
case propertySchema.Type.Is("object"):
5559
encodedVal := json.RawMessage(val)
5660
input[key] = Input{Json: &encodedVal}
5761
continue
58-
} else if propertySchema.Type.Is("array") {
62+
case propertySchema.Type.Is("array"):
5963
var parsed any
6064
err := json.Unmarshal([]byte(val), &parsed)
6165
if err == nil {
@@ -69,6 +73,29 @@ func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error)
6973
var arr = []any{val}
7074
input[key] = Input{Array: &arr}
7175
continue
76+
case propertySchema.Type.Is("number"):
77+
value, err := strconv.ParseInt(val, 10, 32)
78+
if err == nil {
79+
valueInt := int32(value)
80+
input[key] = Input{Int: &valueInt}
81+
continue
82+
} else {
83+
value, err := strconv.ParseFloat(val, 32)
84+
if err != nil {
85+
return input, err
86+
}
87+
float := float32(value)
88+
input[key] = Input{Float: &float}
89+
continue
90+
}
91+
case propertySchema.Type.Is("integer"):
92+
value, err := strconv.ParseInt(val, 10, 32)
93+
if err != nil {
94+
return input, err
95+
}
96+
valueInt := int32(value)
97+
input[key] = Input{Int: &valueInt}
98+
continue
7299
}
73100
}
74101
}
@@ -131,6 +158,10 @@ func (inputs *Inputs) toMap() (map[string]any, error) {
131158
keyVals[key] = dataURLs
132159
case input.Json != nil:
133160
keyVals[key] = *input.Json
161+
case input.Float != nil:
162+
keyVals[key] = *input.Float
163+
case input.Int != nil:
164+
keyVals[key] = *input.Int
134165
}
135166
}
136167
return keyVals, nil
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
build:
2+
python_version: "3.11"
3+
cog_runtime: true
4+
predict: "predict.py:Predictor"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from cog import BasePredictor, Input
2+
3+
4+
class Predictor(BasePredictor):
5+
def predict(
6+
self, num: float = Input(description="Number of things")
7+
) -> float:
8+
return num * 2.0
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
build:
2+
python_version: "3.11"
3+
cog_runtime: true
4+
predict: "predict.py:Predictor"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from cog import BasePredictor, Input
2+
3+
4+
class Predictor(BasePredictor):
5+
def predict(
6+
self, num: int = Input(description="Number of things")
7+
) -> int:
8+
return num * 2

test-integration/test_integration/test_predict.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,3 +814,55 @@ def test_predict_pipeline(cog_binary):
814814
)
815815
assert result.returncode == 0
816816
assert result.stdout == "HELLO TEST\n"
817+
818+
819+
def test_predict_cog_runtime_float(cog_binary):
820+
project_dir = Path(__file__).parent / "fixtures/cog-runtime-float"
821+
result = subprocess.run(
822+
[cog_binary, "predict", "--debug", "-i", "num=10"],
823+
cwd=project_dir,
824+
capture_output=True,
825+
text=True,
826+
timeout=120.0,
827+
)
828+
assert result.returncode == 0
829+
assert result.stdout == "20\n"
830+
831+
832+
def test_predict_cog_runtime_float_negative(cog_binary):
833+
project_dir = Path(__file__).parent / "fixtures/cog-runtime-float"
834+
result = subprocess.run(
835+
[cog_binary, "predict", "--debug", "-i", "num=-10"],
836+
cwd=project_dir,
837+
capture_output=True,
838+
text=True,
839+
timeout=120.0,
840+
)
841+
assert result.returncode == 0
842+
assert result.stdout == "-20\n"
843+
844+
845+
def test_predict_cog_runtime_int(cog_binary):
846+
project_dir = Path(__file__).parent / "fixtures/cog-runtime-int"
847+
result = subprocess.run(
848+
[cog_binary, "predict", "--debug", "-i", "num=10"],
849+
cwd=project_dir,
850+
capture_output=True,
851+
text=True,
852+
timeout=120.0,
853+
)
854+
assert result.returncode == 0
855+
assert result.stdout == "20\n"
856+
857+
858+
def test_predict_cog_runtime_int_negative(cog_binary):
859+
project_dir = Path(__file__).parent / "fixtures/cog-runtime-int"
860+
result = subprocess.run(
861+
[cog_binary, "predict", "--debug", "-i", "num=-10"],
862+
cwd=project_dir,
863+
capture_output=True,
864+
text=True,
865+
timeout=120.0,
866+
)
867+
assert result.returncode == 0
868+
assert result.stdout == "-20\n"

0 commit comments

Comments
 (0)