Skip to content

Commit 138ec81

Browse files
authored
feat: support cog init with pipeline template (#2437)
* feat: support cog init with pipeline template Signed-off-by: Mark Phelps <[email protected]> * chore: simplify embeds Signed-off-by: Mark Phelps <[email protected]> * chore: rm if/else logic Signed-off-by: Mark Phelps <[email protected]> * chore: rm requirements.txt in template Signed-off-by: Mark Phelps <[email protected]> * chore: rm requirements.txt in cog.yaml Signed-off-by: Mark Phelps <[email protected]> --------- Signed-off-by: Mark Phelps <[email protected]>
1 parent b16c1b0 commit 138ec81

File tree

10 files changed

+166
-43
lines changed

10 files changed

+166
-43
lines changed
File renamed without changes.

pkg/cli/init-templates/.github/workflows/push.yaml renamed to pkg/cli/init-templates/base/.github/workflows/push.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ jobs:
1616
push_to_replicate:
1717
name: Push to Replicate
1818

19-
# If your model is large, the default GitHub Actions runner may not
20-
# have enough disk space. If you need more space you can set up a
19+
# If your model is large, the default GitHub Actions runner may not
20+
# have enough disk space. If you need more space you can set up a
2121
# bigger runner on GitHub.
2222
runs-on: ubuntu-latest
2323

@@ -37,16 +37,16 @@ jobs:
3737
- name: Setup Cog
3838
uses: replicate/setup-cog@v2
3939
with:
40-
# If you add a CI auth token to your GitHub repository secrets,
40+
# If you add a CI auth token to your GitHub repository secrets,
4141
# the action will authenticate with Replicate automatically so you
4242
# can push your model without needing to pass in a token.
43-
#
43+
#
4444
# To genereate a CLI auth token, run `cog login` or visit this page
4545
# in your browser: https://replicate.com/account/api-token
4646
token: ${{ secrets.REPLICATE_CLI_AUTH_TOKEN }}
4747

4848
# If you trigger the workflow manually, you can specify the model name.
49-
# If you leave it blank (or if the workflow is triggered by a push), the
49+
# If you leave it blank (or if the workflow is triggered by a push), the
5050
# model name will be derived from the `image` value in cog.yaml.
5151
- name: Push to Replicate
5252
run: |
File renamed without changes.
File renamed without changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# The .dockerignore file excludes files from the container build process.
2+
#
3+
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
4+
5+
# Exclude Git files
6+
**/.git
7+
**/.github
8+
**/.gitignore
9+
10+
# Exclude Python tooling
11+
.python-version
12+
13+
# Exclude Python cache files
14+
__pycache__
15+
.mypy_cache
16+
.pytest_cache
17+
.ruff_cache
18+
19+
# Exclude Python virtual environment
20+
/venv
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Configuration for Cog ⚙️
2+
# Reference: https://cog.run/yaml
3+
4+
build:
5+
# a list of ubuntu apt packages to install
6+
# system_packages:
7+
# - "libgl1-mesa-glx"
8+
# - "libglib2.0-0"
9+
10+
# commands run after the environment is setup
11+
# run:
12+
# - "echo env is ready!"
13+
# - "echo another command if needed"
14+
15+
# main.py defines the pipeline
16+
predict: "main.py:run"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Prediction interface for Cog ⚙️
2+
# https://cog.run/python
3+
4+
from cog import Path, Input
5+
import replicate
6+
7+
flux_schnell = replicate.use("black-forest-labs/flux-schnell")
8+
claude = replicate.use("anthropic/claude-3.5-haiku")
9+
10+
def run(
11+
prompt: str = Input(description="Describe the image to generate"),
12+
seed: int = Input(description="A seed", default=0)
13+
) -> Path:
14+
detailed_prompt = claude(prompt=f"""
15+
Generate a detailed prompt for a generative image model that will
16+
generate a high quality dynamic image based on the following
17+
theme: {prompt}
18+
""")
19+
output_paths = flux_schnell(prompt=detailed_prompt, seed=seed)
20+
21+
return Path(output_paths[0])

pkg/cli/init.go

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package cli
22

33
import (
4-
// blank import for embeds
5-
_ "embed"
4+
"embed"
65
"fmt"
76
"os"
87
"path"
@@ -13,20 +12,11 @@ import (
1312
"github.com/replicate/cog/pkg/util/files"
1413
)
1514

16-
//go:embed init-templates/.dockerignore
17-
var dockerignoreContent []byte
18-
19-
//go:embed init-templates/cog.yaml
20-
var cogYamlContent []byte
21-
22-
//go:embed init-templates/predict.py
23-
var predictPyContent []byte
24-
25-
//go:embed init-templates/.github/workflows/push.yaml
26-
var actionsWorkflowContent []byte
27-
28-
//go:embed init-templates/requirements.txt
29-
var requirementsTxtContent []byte
15+
var (
16+
//go:embed init-templates/**/*
17+
initTemplates embed.FS
18+
pipelineTemplate bool
19+
)
3020

3121
func newInitCommand() *cobra.Command {
3222
var cmd = &cobra.Command{
@@ -37,6 +27,7 @@ func newInitCommand() *cobra.Command {
3727
Args: cobra.MaximumNArgs(0),
3828
}
3929

30+
addPipelineInit(cmd)
4031
return cmd
4132
}
4233

@@ -48,40 +39,96 @@ func initCommand(cmd *cobra.Command, args []string) error {
4839
return err
4940
}
5041

51-
fileContentMap := map[string][]byte{
52-
"cog.yaml": cogYamlContent,
53-
"predict.py": predictPyContent,
54-
".dockerignore": dockerignoreContent,
55-
".github/workflows/push.yaml": actionsWorkflowContent,
56-
"requirements.txt": requirementsTxtContent,
42+
initTemplate := "base"
43+
if pipelineTemplate {
44+
initTemplate = "pipeline"
5745
}
5846

59-
for filename, content := range fileContentMap {
60-
filePath := path.Join(cwd, filename)
61-
fileExists, err := files.Exists(filePath)
62-
if err != nil {
47+
// Discover all files in the embedded template directory
48+
templateDir := path.Join("init-templates", initTemplate)
49+
entries, err := initTemplates.ReadDir(templateDir)
50+
if err != nil {
51+
return fmt.Errorf("Error reading template directory: %w", err)
52+
}
53+
54+
for _, entry := range entries {
55+
if entry.IsDir() {
56+
// Recursively process subdirectories
57+
if err := processTemplateDirectory(initTemplates, templateDir, entry.Name(), cwd); err != nil {
58+
return err
59+
}
60+
continue
61+
}
62+
63+
// Process individual files
64+
if err := processTemplateFile(initTemplates, templateDir, entry.Name(), cwd); err != nil {
6365
return err
6466
}
67+
}
68+
69+
console.Infof("\nDone! For next steps, check out the docs at https://cog.run/getting-started")
70+
71+
return nil
72+
}
73+
74+
func processTemplateDirectory(fs embed.FS, templateDir, subDir, cwd string) error {
75+
subDirPath := path.Join(templateDir, subDir)
76+
entries, err := fs.ReadDir(subDirPath)
77+
if err != nil {
78+
return fmt.Errorf("Error reading subdirectory %s: %w", subDirPath, err)
79+
}
6580

66-
if fileExists {
67-
console.Infof("Skipped existing %s", filename)
81+
for _, entry := range entries {
82+
if entry.IsDir() {
83+
// Recursively process nested subdirectories
84+
if err := processTemplateDirectory(fs, subDirPath, entry.Name(), cwd); err != nil {
85+
return err
86+
}
6887
continue
6988
}
7089

71-
dirPath := path.Dir(filePath)
72-
err = os.MkdirAll(dirPath, os.ModePerm)
73-
if err != nil {
74-
return fmt.Errorf("Error creating directory %s: %w", dirPath, err)
90+
// Process files in subdirectories
91+
relativePath := path.Join(subDir, entry.Name())
92+
if err := processTemplateFile(fs, templateDir, relativePath, cwd); err != nil {
93+
return err
7594
}
95+
}
7696

77-
err = os.WriteFile(filePath, content, 0o644)
78-
if err != nil {
79-
return fmt.Errorf("Error writing %s: %w", filePath, err)
80-
}
81-
console.Infof("✅ Created %s", filePath)
97+
return nil
98+
}
99+
100+
func processTemplateFile(fs embed.FS, templateDir, filename, cwd string) error {
101+
filePath := path.Join(cwd, filename)
102+
fileExists, err := files.Exists(filePath)
103+
if err != nil {
104+
return fmt.Errorf("Error checking if %s exists: %w", filePath, err)
82105
}
83106

84-
console.Infof("\nDone! For next steps, check out the docs at https://cog.run/getting-started")
107+
if fileExists {
108+
console.Infof("Skipped existing %s", filename)
109+
return nil
110+
}
85111

112+
dirPath := path.Dir(filePath)
113+
if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
114+
return fmt.Errorf("Error creating directory %s: %w", dirPath, err)
115+
}
116+
117+
content, err := fs.ReadFile(path.Join(templateDir, filename))
118+
if err != nil {
119+
return fmt.Errorf("Error reading %s: %w", filename, err)
120+
}
121+
122+
if err := os.WriteFile(filePath, content, 0o644); err != nil {
123+
return fmt.Errorf("Error writing %s: %w", filePath, err)
124+
}
125+
126+
console.Infof("✅ Created %s", filePath)
86127
return nil
87128
}
129+
130+
func addPipelineInit(cmd *cobra.Command) {
131+
const pipeline = "x-pipeline"
132+
cmd.Flags().BoolVar(&pipelineTemplate, pipeline, false, "Initialize a pipeline template")
133+
_ = cmd.Flags().MarkHidden(pipeline)
134+
}

pkg/cli/init_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@ func TestInit(t *testing.T) {
1919
require.FileExists(t, path.Join(dir, ".dockerignore"))
2020
require.FileExists(t, path.Join(dir, "cog.yaml"))
2121
require.FileExists(t, path.Join(dir, "predict.py"))
22+
require.FileExists(t, path.Join(dir, "requirements.txt"))
23+
}
24+
25+
func TestInitPipeline(t *testing.T) {
26+
dir := t.TempDir()
27+
28+
require.NoError(t, os.Chdir(dir))
29+
30+
pipelineTemplate = true
31+
t.Cleanup(func() {
32+
pipelineTemplate = false
33+
})
34+
35+
err := initCommand(nil, []string{"--x-pipeline"})
36+
require.NoError(t, err)
37+
38+
require.FileExists(t, path.Join(dir, ".dockerignore"))
39+
require.FileExists(t, path.Join(dir, "cog.yaml"))
40+
require.FileExists(t, path.Join(dir, "main.py"))
2241
}
2342

2443
func TestInitSkipExisting(t *testing.T) {

0 commit comments

Comments
 (0)