Skip to content

Commit 71ad986

Browse files
authored
Handle badly formatted data URIs in cog predict (#2409)
* Currently cog returns data URIs of the form data:None;base64,... * The data URI parser does not recognise None as a media type and requires data:;base64 instead * Handle these URIs by removing the None media type before parsing
1 parent e7b086c commit 71ad986

File tree

8 files changed

+111
-57
lines changed

8 files changed

+111
-57
lines changed

pkg/cli/predict.go

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"github.com/getkin/kin-openapi/openapi3"
2020
"github.com/mitchellh/go-homedir"
2121
"github.com/spf13/cobra"
22-
"github.com/vincent-petithory/dataurl"
2322
"golang.org/x/sys/unix"
2423

2524
"github.com/replicate/cog/pkg/config"
@@ -30,6 +29,7 @@ import (
3029
"github.com/replicate/cog/pkg/predict"
3130
"github.com/replicate/cog/pkg/registry"
3231
"github.com/replicate/cog/pkg/util/console"
32+
"github.com/replicate/cog/pkg/util/files"
3333
"github.com/replicate/cog/pkg/util/mime"
3434
)
3535

@@ -443,7 +443,7 @@ func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPat
443443
}
444444

445445
if writeOutputToDisk {
446-
path, err := writeFile(indentedJSON.Bytes(), outputPath)
446+
path, err := files.WriteFile(indentedJSON.Bytes(), outputPath)
447447
if err != nil {
448448
return fmt.Errorf("Failed to write output: %w", err)
449449
}
@@ -485,7 +485,7 @@ func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPat
485485
}
486486

487487
if writeOutputToDisk {
488-
path, err := writeFile([]byte(s), outputPath)
488+
path, err := files.WriteFile([]byte(s), outputPath)
489489
if err != nil {
490490
return fmt.Errorf("Failed to write output: %w", err)
491491
}
@@ -505,7 +505,7 @@ func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPat
505505

506506
// No special handling for needsJSON here.
507507
if writeOutputToDisk {
508-
path, err := writeFile(output, outputPath)
508+
path, err := files.WriteFile(output, outputPath)
509509
if err != nil {
510510
return fmt.Errorf("Failed to write output: %w", err)
511511
}
@@ -585,7 +585,7 @@ func processFileOutputs(output any, schema *openapi3.Schema, destination string)
585585
return nil, fmt.Errorf("Failed to convert prediction output to string: %v", output)
586586
}
587587

588-
path, err := writeDataURLToFile(outputStr, destination)
588+
path, err := files.WriteDataURLToFile(outputStr, destination)
589589
if err != nil {
590590
return nil, fmt.Errorf("Failed to write output: %w", err)
591591
}
@@ -615,58 +615,6 @@ func processFileOutputs(output any, schema *openapi3.Schema, destination string)
615615
return output, nil
616616
}
617617

618-
func writeFile(output []byte, outputPath string) (string, error) {
619-
outputPath, err := homedir.Expand(outputPath)
620-
if err != nil {
621-
return "", err
622-
}
623-
624-
// Write to file
625-
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
626-
if err != nil {
627-
return "", err
628-
}
629-
630-
if _, err := outFile.Write(output); err != nil {
631-
return "", err
632-
}
633-
if err := outFile.Close(); err != nil {
634-
return "", err
635-
}
636-
return outputPath, nil
637-
}
638-
639-
// Writes a data URL to the destination. If no file extension is provided then it
640-
// will be inferred from the data URL mime type and appended.
641-
func writeDataURLToFile(url string, destination string) (string, error) {
642-
dataurlObj, err := dataurl.DecodeString(url)
643-
if err != nil {
644-
return "", fmt.Errorf("Failed to decode data URL: %w", err)
645-
}
646-
output := dataurlObj.Data
647-
648-
ext := path.Ext(destination)
649-
dir := path.Dir(destination)
650-
name := r8_path.TrimExt(path.Base(destination))
651-
652-
// Check if ext is an integer, in which case ignore it...
653-
if r8_path.IsExtInteger(ext) {
654-
ext = ""
655-
name = path.Base(destination)
656-
}
657-
658-
if ext == "" {
659-
ext = mime.ExtensionByType(dataurlObj.ContentType())
660-
}
661-
662-
path, err := writeFile(output, path.Join(dir, name+ext))
663-
if err != nil {
664-
return "", err
665-
}
666-
667-
return path, nil
668-
}
669-
670618
func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error) {
671619
keyVals := map[string][]string{}
672620
for _, input := range inputs {

pkg/util/files/files.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,15 @@ import (
55
"fmt"
66
"io"
77
"os"
8+
"path"
9+
"strings"
810

11+
"github.com/mitchellh/go-homedir"
12+
"github.com/vincent-petithory/dataurl"
913
"golang.org/x/sys/unix"
14+
15+
r8_path "github.com/replicate/cog/pkg/path"
16+
"github.com/replicate/cog/pkg/util/mime"
1017
)
1118

1219
func Exists(path string) (bool, error) {
@@ -82,3 +89,56 @@ func WriteIfDifferent(file, content string) error {
8289
}
8390
return nil
8491
}
92+
93+
func WriteDataURLToFile(url string, destination string) (string, error) {
94+
if strings.HasPrefix(url, "data:None;base64") {
95+
url = strings.Replace(url, "data:None;base64", "data:;base64", 1)
96+
}
97+
dataurlObj, err := dataurl.DecodeString(url)
98+
if err != nil {
99+
return "", fmt.Errorf("Failed to decode data URL: %w", err)
100+
}
101+
output := dataurlObj.Data
102+
103+
ext := path.Ext(destination)
104+
dir := path.Dir(destination)
105+
name := r8_path.TrimExt(path.Base(destination))
106+
107+
// Check if ext is an integer, in which case ignore it...
108+
if r8_path.IsExtInteger(ext) {
109+
ext = ""
110+
name = path.Base(destination)
111+
}
112+
113+
if ext == "" {
114+
ext = mime.ExtensionByType(dataurlObj.ContentType())
115+
}
116+
117+
path, err := WriteFile(output, path.Join(dir, name+ext))
118+
if err != nil {
119+
return "", err
120+
}
121+
122+
return path, nil
123+
}
124+
125+
func WriteFile(output []byte, outputPath string) (string, error) {
126+
outputPath, err := homedir.Expand(outputPath)
127+
if err != nil {
128+
return "", err
129+
}
130+
131+
// Write to file
132+
outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
133+
if err != nil {
134+
return "", err
135+
}
136+
137+
if _, err := outFile.Write(output); err != nil {
138+
return "", err
139+
}
140+
if err := outFile.Close(); err != nil {
141+
return "", err
142+
}
143+
return outputPath, nil
144+
}

pkg/util/files/files_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ func TestIsExecutable(t *testing.T) {
1818
require.NoError(t, os.Chmod(path, 0o744))
1919
require.True(t, IsExecutable(path))
2020
}
21+
22+
func TestWriteBadlyFormattedBase64DataURI(t *testing.T) {
23+
dir := t.TempDir()
24+
path := filepath.Join(dir, "test-file")
25+
_, err := WriteDataURLToFile("data:None;base64,SGVsbG8gVGhlcmU=", path)
26+
require.NoError(t, err)
27+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
build:
2+
python_version: "3.13"
3+
4+
predict: "predict.py:Predictor"
2.04 MB
Binary file not shown.
Binary file not shown.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Prediction interface for Cog ⚙️
2+
# https://github.com/replicate/cog/blob/main/docs/python.md
3+
4+
from cog import BasePredictor, Path
5+
6+
7+
class Predictor(BasePredictor):
    """Fixture predictor whose output is the relative path ``mesh.glb``."""

    def setup(self) -> None:
        """Raise ValueError unless ``mesh.glb`` exists at the relative path."""
        mesh = Path("mesh.glb")
        if mesh.exists():
            return
        raise ValueError("Example file mesh.glb does not exist")

    def predict(self) -> Path:
        """Return the path to ``mesh.glb``; the prediction takes no inputs."""
        return Path("mesh.glb")

test-integration/test_integration/test_predict.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,3 +769,21 @@ def test_predict_json_input_stdin_dash(cog_binary):
769769
}
770770
"""
771771
)
772+
773+
774+
def test_predict_glb_file(cog_binary):
    """Run ``cog predict --debug`` in the glb-project fixture and expect exit status 0."""
    fixture_dir = Path(__file__).parent / "fixtures/glb-project"
    command = [cog_binary, "predict", "--debug"]

    result = subprocess.run(
        command,
        cwd=fixture_dir,
        check=True,
        capture_output=True,
        text=True,
        timeout=120.0,
    )

    # check=True already raises on a non-zero exit; the assert states the
    # expectation explicitly.
    assert result.returncode == 0

0 commit comments

Comments
 (0)