Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion internal/cmd/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"

Expand Down Expand Up @@ -191,6 +192,10 @@ just print the generated policy.
}

if opts.openPullRequest && opts.interactive {
if err := ensureOrCreatePolicyFork(srctool); err != nil {
return err
}

fmt.Printf(`

sourcetool is about to perform the following actions on your behalf:
Expand Down Expand Up @@ -230,7 +235,7 @@ open the pull request from there.

if opts.openPullRequest && pr != nil {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "Opened pull request: https://github.com/%s/pulls/%d\n\n", pr.Repo.Path, pr.Number)
fmt.Fprintf(os.Stderr, "pull request open: https://github.com/%s/pull/%d\n\n", pr.Repo.Path, pr.Number)
}

return nil
Expand Down Expand Up @@ -260,3 +265,37 @@ func displayPolicy(opts repoOptions, pcy *policy.RepoPolicy) error {
fmt.Println()
return nil
}

// ensureOrCreatePolicyFork checks the user has a fork of the policy repo.
// In case they do not, asks to create a fork in their GitHub account.
func ensureOrCreatePolicyFork(srctool *sourcetool.Tool) error {
found, err := srctool.CheckPolicyRepoFork()
if err != nil {
return fmt.Errorf("checking for policy repo fork: %w", err)
}
if found {
return nil
}

fmt.Println()
fmt.Println()
fmt.Printf("%s sourcetool could not find a fork of the community source\n", w("Note:"))
fmt.Printf("policy repository (%s). We need it to\n", srctool.Options.PolicyRepo)
fmt.Println("create pull requests for your policies.")
fmt.Println()
fmt.Println("Would you like to create the fork in your GitHub account now?")
fmt.Println()

_, s, err := util.Ask("Type 'yes' if you want to continue?", "yes|no|no", 3)
if err != nil {
return err
}
if !s {
return errors.New("no policy repo found and creation declined")
}

if err := srctool.CreatePolicyRepoFork(context.Background()); err != nil {
return err
}
return nil
}
15 changes: 12 additions & 3 deletions internal/cmd/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ a fork of the repository you want to protect.
// Create a new sourcetool object
srctool, err := sourcetool.New(
sourcetool.WithAuthenticator(authenticator),
sourcetool.WithPolicyRepo(opts.policyRepo),
// Uncomment when we support other policy repo
// sourcetool.WithPolicyRepo(opts.policyRepo),
sourcetool.WithUserForkOrg(opts.userForkOrg),
sourcetool.WithEnforce(opts.enforce),
)
Expand All @@ -316,8 +317,16 @@ a fork of the repository you want to protect.
}
cs := []models.ControlConfiguration{}
if opts.interactive {
fmt.Println("\nsourcetool is about to perform the following actions on your behalf:")
fmt.Println("")
// Check if we need the policy fork
if slices.Contains(opts.configs, string(models.CONFIG_POLICY)) {
if err := ensureOrCreatePolicyFork(srctool); err != nil {
return err
}
}

fmt.Println()
fmt.Println("sourcetool is about to perform the following actions on your behalf:")
fmt.Println()

for _, c := range opts.configs {
cs = append(cs, models.ControlConfiguration(c))
Expand Down
35 changes: 35 additions & 0 deletions pkg/sourcetool/implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type toolImplementation interface {
GetBranchControls(context.Context, models.VcsBackend, *models.Repository, *models.Branch) (*slsa.ControlSetStatus, error)
ConfigureControls(models.VcsBackend, *models.Repository, []*models.Branch, []models.ControlConfiguration) error
GetPolicyStatus(context.Context, *auth.Authenticator, *options.Options, *models.Repository) (*slsa.ControlStatus, error)
CreateRepositoryFork(context.Context, *auth.Authenticator, *models.Repository, string) error
}

type defaultToolImplementation struct{}
Expand Down Expand Up @@ -280,3 +281,37 @@ func (impl *defaultToolImplementation) GetPolicyStatus(
},
}, nil
}

// CreateRepositoryFork creates a fork of a repo into the logged-in user's org.
// Optionally the fork can have a different name than the original.
func (impl *defaultToolImplementation) CreateRepositoryFork(
ctx context.Context, a *auth.Authenticator, src *models.Repository, forkName string,
) error {
client, err := a.GetGitHubClient()
if err != nil {
return fmt.Errorf("creating GitHub client: %w", err)
}

srcOrg, srcName, err := src.PathAsGitHubOwnerName()
if err != nil {
return err
}

if forkName == "" {
forkName = srcName
}

// Create the fork
_, resp, err := client.Repositories.CreateFork(
ctx, srcOrg, srcName, &github.RepositoryCreateForkOptions{
Name: forkName,
},
)

// GitHub will return 202 for larger repos that are cloned async
if err != nil && resp.StatusCode != 202 {
return fmt.Errorf("creating repository fork: %w", err)
}

return nil
}
78 changes: 78 additions & 0 deletions pkg/sourcetool/sourcetoolfakes/fake_tool_implementation.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 19 additions & 4 deletions pkg/sourcetool/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (t *Tool) OnboardRepository(repo *models.Repository, branches []*models.Bra
return fmt.Errorf("verifying options: %w", err)
}

if err = backend.ConfigureControls(
if err := backend.ConfigureControls(
repo, branches, []models.ControlConfiguration{
models.CONFIG_BRANCH_RULES, models.CONFIG_GEN_PROVENANCE, models.CONFIG_TAG_RULES,
},
Expand Down Expand Up @@ -155,15 +155,19 @@ func (t *Tool) FindPolicyPR(repo *models.Repository) (*models.PullRequest, error
return pr, nil
}

func (t *Tool) CheckPolicyRepoFork(repo *models.Repository) (bool, error) {
// CheckPolicyRepoFork checks that the logged in user has a fork
// of the configured policy repo.
func (t *Tool) CheckPolicyRepoFork() (bool, error) {
if err := t.impl.CheckPolicyFork(&t.Options); err != nil {
if strings.Contains(err.Error(), "404 Not Found") {
return false, nil
}
if strings.Contains(err.Error(), "oes not have a fork of") {
return false, nil
}
return false, err
} else {
return true, nil
}
return true, nil
}

// CreateBranchPolicy creates a repository policy
Expand Down Expand Up @@ -248,3 +252,14 @@ func (t *Tool) CreateRepositoryPolicy(ctx context.Context, r *models.Repository,
}
return pcy, pr, nil
}

// CreatePolicyRepoFork creates a fork of the policy repository in the user's GitHub org
func (t *Tool) CreatePolicyRepoFork(ctx context.Context) error {
err := t.impl.CreateRepositoryFork(ctx, t.Authenticator, &models.Repository{
Path: t.Options.PolicyRepo,
}, "")
if err != nil {
return fmt.Errorf("creating policy repo fork: %w", err)
}
return nil
}
95 changes: 94 additions & 1 deletion pkg/sourcetool/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func TestCheckPolicyFork(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tool := Tool{impl: tc.prepare(t)}
found, err := tool.CheckPolicyRepoFork(&models.Repository{})
found, err := tool.CheckPolicyRepoFork()
if tc.mustErr {
require.Error(t, err)
return
Expand All @@ -251,3 +251,96 @@ func TestCheckPolicyFork(t *testing.T) {
})
}
}

func TestCreatePolicyRepoFork(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
getSut func(t *testing.T) toolImplementation
mustErr bool
}{
{"normal", func(t *testing.T) toolImplementation {
t.Helper()
timp := &sourcetoolfakes.FakeToolImplementation{}
timp.CreateRepositoryForkReturns(nil)
return timp
}, false},
{"create-fork-fails", func(t *testing.T) toolImplementation {
t.Helper()
timp := &sourcetoolfakes.FakeToolImplementation{}
timp.CreateRepositoryForkReturns(errors.New("error"))
return timp
}, true},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
impl := tt.getSut(t)
tool, err := New()
require.NoError(t, err)
tool.impl = impl
err = tool.CreatePolicyRepoFork(t.Context())
if tt.mustErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}

func TestOnboardRepository(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
getSut func(t *testing.T) toolImplementation
mustErr bool
}{
{"normal", func(t *testing.T) toolImplementation {
t.Helper()
timp := &sourcetoolfakes.FakeToolImplementation{}
timp.GetVcsBackendReturns(&modelsfakes.FakeVcsBackend{}, nil)
timp.VerifyOptionsForFullOnboardReturns(nil)
timp.ConfigureControlsReturns(nil)
return timp
}, false},
{"get-vcs-fails", func(t *testing.T) toolImplementation {
t.Helper()
timp := &sourcetoolfakes.FakeToolImplementation{}
timp.GetVcsBackendReturns(nil, errors.New("vcsborked"))
timp.VerifyOptionsForFullOnboardReturns(nil)
timp.ConfigureControlsReturns(nil)
return timp
}, true},
{"get-verifyoptions-fails", func(t *testing.T) toolImplementation {
t.Helper()
timp := &sourcetoolfakes.FakeToolImplementation{}
timp.GetVcsBackendReturns(&modelsfakes.FakeVcsBackend{}, nil)
timp.VerifyOptionsForFullOnboardReturns(errors.New("onboarderr"))
return timp
}, true},
{"backend-configure-controls-fails", func(t *testing.T) toolImplementation {
t.Helper()
bend := modelsfakes.FakeVcsBackend{}
bend.ConfigureControlsReturns(errors.New("configure-error"))

timp := &sourcetoolfakes.FakeToolImplementation{}
timp.GetVcsBackendReturns(&bend, nil)
timp.VerifyOptionsForFullOnboardReturns(nil)
return timp
}, true},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
impl := tt.getSut(t)
tool, err := New()
require.NoError(t, err)
tool.impl = impl
err = tool.OnboardRepository(&models.Repository{Path: "example/repo"}, []*models.Branch{{Name: "main"}})
if tt.mustErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
Loading