11package cli
22
33import (
4- // blank import for embeds
5- _ "embed"
4+ "embed"
65 "fmt"
76 "os"
87 "path"
@@ -13,20 +12,11 @@ import (
1312 "github.com/replicate/cog/pkg/util/files"
1413)
1514
16- //go:embed init-templates/.dockerignore
17- var dockerignoreContent []byte
18-
19- //go:embed init-templates/cog.yaml
20- var cogYamlContent []byte
21-
22- //go:embed init-templates/predict.py
23- var predictPyContent []byte
24-
25- //go:embed init-templates/.github/workflows/push.yaml
26- var actionsWorkflowContent []byte
27-
28- //go:embed init-templates/requirements.txt
29- var requirementsTxtContent []byte
15+ var (
16+ //go:embed init-templates/**/*
17+ initTemplates embed.FS
18+ pipelineTemplate bool
19+ )
3020
3121func newInitCommand () * cobra.Command {
3222 var cmd = & cobra.Command {
@@ -37,6 +27,7 @@ func newInitCommand() *cobra.Command {
3727 Args : cobra .MaximumNArgs (0 ),
3828 }
3929
30+ addPipelineInit (cmd )
4031 return cmd
4132}
4233
@@ -48,40 +39,96 @@ func initCommand(cmd *cobra.Command, args []string) error {
4839 return err
4940 }
5041
51- fileContentMap := map [string ][]byte {
52- "cog.yaml" : cogYamlContent ,
53- "predict.py" : predictPyContent ,
54- ".dockerignore" : dockerignoreContent ,
55- ".github/workflows/push.yaml" : actionsWorkflowContent ,
56- "requirements.txt" : requirementsTxtContent ,
42+ initTemplate := "base"
43+ if pipelineTemplate {
44+ initTemplate = "pipeline"
5745 }
5846
59- for filename , content := range fileContentMap {
60- filePath := path .Join (cwd , filename )
61- fileExists , err := files .Exists (filePath )
62- if err != nil {
47+ // Discover all files in the embedded template directory
48+ templateDir := path .Join ("init-templates" , initTemplate )
49+ entries , err := initTemplates .ReadDir (templateDir )
50+ if err != nil {
51+ return fmt .Errorf ("Error reading template directory: %w" , err )
52+ }
53+
54+ for _ , entry := range entries {
55+ if entry .IsDir () {
56+ // Recursively process subdirectories
57+ if err := processTemplateDirectory (initTemplates , templateDir , entry .Name (), cwd ); err != nil {
58+ return err
59+ }
60+ continue
61+ }
62+
63+ // Process individual files
64+ if err := processTemplateFile (initTemplates , templateDir , entry .Name (), cwd ); err != nil {
6365 return err
6466 }
67+ }
68+
69+ console .Infof ("\n Done! For next steps, check out the docs at https://cog.run/getting-started" )
70+
71+ return nil
72+ }
73+
74+ func processTemplateDirectory (fs embed.FS , templateDir , subDir , cwd string ) error {
75+ subDirPath := path .Join (templateDir , subDir )
76+ entries , err := fs .ReadDir (subDirPath )
77+ if err != nil {
78+ return fmt .Errorf ("Error reading subdirectory %s: %w" , subDirPath , err )
79+ }
6580
66- if fileExists {
67- console .Infof ("Skipped existing %s" , filename )
81+ for _ , entry := range entries {
82+ if entry .IsDir () {
83+ // Recursively process nested subdirectories
84+ if err := processTemplateDirectory (fs , subDirPath , entry .Name (), cwd ); err != nil {
85+ return err
86+ }
6887 continue
6988 }
7089
71- dirPath := path . Dir ( filePath )
72- err = os . MkdirAll ( dirPath , os . ModePerm )
73- if err != nil {
74- return fmt . Errorf ( "Error creating directory %s: %w" , dirPath , err )
90+ // Process files in subdirectories
91+ relativePath := path . Join ( subDir , entry . Name () )
92+ if err := processTemplateFile ( fs , templateDir , relativePath , cwd ); err != nil {
93+ return err
7594 }
95+ }
7696
77- err = os .WriteFile (filePath , content , 0o644 )
78- if err != nil {
79- return fmt .Errorf ("Error writing %s: %w" , filePath , err )
80- }
81- console .Infof ("✅ Created %s" , filePath )
97+ return nil
98+ }
99+
100+ func processTemplateFile (fs embed.FS , templateDir , filename , cwd string ) error {
101+ filePath := path .Join (cwd , filename )
102+ fileExists , err := files .Exists (filePath )
103+ if err != nil {
104+ return fmt .Errorf ("Error checking if %s exists: %w" , filePath , err )
82105 }
83106
84- console .Infof ("\n Done! For next steps, check out the docs at https://cog.run/getting-started" )
107+ if fileExists {
108+ console .Infof ("Skipped existing %s" , filename )
109+ return nil
110+ }
85111
112+ dirPath := path .Dir (filePath )
113+ if err := os .MkdirAll (dirPath , os .ModePerm ); err != nil {
114+ return fmt .Errorf ("Error creating directory %s: %w" , dirPath , err )
115+ }
116+
117+ content , err := fs .ReadFile (path .Join (templateDir , filename ))
118+ if err != nil {
119+ return fmt .Errorf ("Error reading %s: %w" , filename , err )
120+ }
121+
122+ if err := os .WriteFile (filePath , content , 0o644 ); err != nil {
123+ return fmt .Errorf ("Error writing %s: %w" , filePath , err )
124+ }
125+
126+ console .Infof ("✅ Created %s" , filePath )
86127 return nil
87128}
129+
130+ func addPipelineInit (cmd * cobra.Command ) {
131+ const pipeline = "x-pipeline"
132+ cmd .Flags ().BoolVar (& pipelineTemplate , pipeline , false , "Initialize a pipeline template" )
133+ _ = cmd .Flags ().MarkHidden (pipeline )
134+ }
0 commit comments