diff --git a/internal/cmd/policy.go b/internal/cmd/policy.go index 73a34031..c12fa6bc 100644 --- a/internal/cmd/policy.go +++ b/internal/cmd/policy.go @@ -6,6 +6,7 @@ package cmd import ( "context" "encoding/json" + "errors" "fmt" "os" @@ -191,6 +192,10 @@ just print the generated policy. } if opts.openPullRequest && opts.interactive { + if err := ensureOrCreatePolicyFork(srctool); err != nil { + return err + } + fmt.Printf(` sourcetool is about to perform the following actions on your behalf: @@ -230,7 +235,7 @@ open the pull request from there. if opts.openPullRequest && pr != nil { fmt.Fprintf(os.Stderr, "\n") - fmt.Fprintf(os.Stderr, "Opened pull request: https://github.com/%s/pulls/%d\n\n", pr.Repo.Path, pr.Number) + fmt.Fprintf(os.Stderr, "pull request open: https://github.com/%s/pull/%d\n\n", pr.Repo.Path, pr.Number) } return nil @@ -260,3 +265,37 @@ func displayPolicy(opts repoOptions, pcy *policy.RepoPolicy) error { fmt.Println() return nil } + +// ensureOrCreatePolicyFork checks the user has a fork of the policy repo. +// In case they do not, asks to create a fork in their GitHub account. +func ensureOrCreatePolicyFork(srctool *sourcetool.Tool) error { + found, err := srctool.CheckPolicyRepoFork() + if err != nil { + return fmt.Errorf("checking for policy repo fork: %w", err) + } + if found { + return nil + } + + fmt.Println() + fmt.Println() + fmt.Printf("%s sourcetool could not find a fork of the community source\n", w("Note:")) + fmt.Printf("policy repository (%s). We need it to\n", srctool.Options.PolicyRepo) + fmt.Println("create pull requests for your policies.") + fmt.Println() + fmt.Println("Would you like to create the fork in your GitHub account now?") + fmt.Println() + + _, s, err := util.Ask("Type 'yes' if you want to continue?", "yes|no|no", 3) + if err != nil { + return err + } + if !s { + return errors.New("no policy repo found and creation declined") + } + + if err := srctool.CreatePolicyRepoFork(context.Background()); err != nil { + return err + } + return nil +} diff --git a/internal/cmd/setup.go b/internal/cmd/setup.go index 465b26bb..4e269643 100644 --- a/internal/cmd/setup.go +++ b/internal/cmd/setup.go @@ -307,7 +307,8 @@ a fork of the repository you want to protect. // Create a new sourcetool object srctool, err := sourcetool.New( sourcetool.WithAuthenticator(authenticator), - sourcetool.WithPolicyRepo(opts.policyRepo), + // Uncomment when we support other policy repo + // sourcetool.WithPolicyRepo(opts.policyRepo), sourcetool.WithUserForkOrg(opts.userForkOrg), sourcetool.WithEnforce(opts.enforce), ) @@ -316,8 +317,16 @@ a fork of the repository you want to protect. } cs := []models.ControlConfiguration{} if opts.interactive { - fmt.Println("\nsourcetool is about to perform the following actions on your behalf:") - fmt.Println("") + // Check if we need the policy fork + if slices.Contains(opts.configs, string(models.CONFIG_POLICY)) { + if err := ensureOrCreatePolicyFork(srctool); err != nil { + return err + } + } + + fmt.Println() + fmt.Println("sourcetool is about to perform the following actions on your behalf:") + fmt.Println() for _, c := range opts.configs { cs = append(cs, models.ControlConfiguration(c)) diff --git a/pkg/sourcetool/implementation.go b/pkg/sourcetool/implementation.go index f7dee73c..507a614d 100644 --- a/pkg/sourcetool/implementation.go +++ b/pkg/sourcetool/implementation.go @@ -39,6 +39,7 @@ type toolImplementation interface { GetBranchControls(context.Context, models.VcsBackend, *models.Repository, *models.Branch) (*slsa.ControlSetStatus, error) ConfigureControls(models.VcsBackend, *models.Repository, []*models.Branch, []models.ControlConfiguration) error GetPolicyStatus(context.Context, *auth.Authenticator, *options.Options, *models.Repository) (*slsa.ControlStatus, error) + CreateRepositoryFork(context.Context, *auth.Authenticator, *models.Repository, string) error } type defaultToolImplementation struct{} @@ -280,3 +281,37 @@ func (impl *defaultToolImplementation) GetPolicyStatus( }, }, nil } + +// CreateRepositoryFork creates a fork of a repo into the logged-in user's org. +// Optionally the fork can have a different name than the original. +func (impl *defaultToolImplementation) CreateRepositoryFork( + ctx context.Context, a *auth.Authenticator, src *models.Repository, forkName string, +) error { + client, err := a.GetGitHubClient() + if err != nil { + return fmt.Errorf("creating GitHub client: %w", err) + } + + srcOrg, srcName, err := src.PathAsGitHubOwnerName() + if err != nil { + return err + } + + if forkName == "" { + forkName = srcName + } + + // Create the fork + _, resp, err := client.Repositories.CreateFork( + ctx, srcOrg, srcName, &github.RepositoryCreateForkOptions{ + Name: forkName, + }, + ) + + // GitHub will return 202 for larger repos that are cloned async + if err != nil && resp.StatusCode != 202 { + return fmt.Errorf("creating repository fork: %w", err) + } + + return nil +} diff --git a/pkg/sourcetool/sourcetoolfakes/fake_tool_implementation.go b/pkg/sourcetool/sourcetoolfakes/fake_tool_implementation.go index 2b480468..c6e46bb5 100644 --- a/pkg/sourcetool/sourcetoolfakes/fake_tool_implementation.go +++ b/pkg/sourcetool/sourcetoolfakes/fake_tool_implementation.go @@ -65,6 +65,20 @@ type FakeToolImplementation struct { result1 *models.PullRequest result2 error } + CreateRepositoryForkStub func(context.Context, *auth.Authenticator, *models.Repository, string) error + createRepositoryForkMutex sync.RWMutex + createRepositoryForkArgsForCall []struct { + arg1 context.Context + arg2 *auth.Authenticator + arg3 *models.Repository + arg4 string + } + createRepositoryForkReturns struct { + result1 error + } + createRepositoryForkReturnsOnCall map[int]struct { + result1 error + } GetAttestationReaderStub func(*models.Repository) (models.AttestationStorageReader, error) getAttestationReaderMutex sync.RWMutex getAttestationReaderArgsForCall []struct { @@ -418,6 +432,70 @@ func (fake *FakeToolImplementation) CreatePolicyPRReturnsOnCall(i int, result1 * }{result1, result2} } +func (fake *FakeToolImplementation) CreateRepositoryFork(arg1 context.Context, arg2 *auth.Authenticator, arg3 *models.Repository, arg4 string) error { + fake.createRepositoryForkMutex.Lock() + ret, specificReturn := fake.createRepositoryForkReturnsOnCall[len(fake.createRepositoryForkArgsForCall)] + fake.createRepositoryForkArgsForCall = append(fake.createRepositoryForkArgsForCall, struct { + arg1 context.Context + arg2 *auth.Authenticator + arg3 *models.Repository + arg4 string + }{arg1, arg2, arg3, arg4}) + stub := fake.CreateRepositoryForkStub + fakeReturns := fake.createRepositoryForkReturns + fake.recordInvocation("CreateRepositoryFork", []interface{}{arg1, arg2, arg3, arg4}) + fake.createRepositoryForkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeToolImplementation) CreateRepositoryForkCallCount() int { + fake.createRepositoryForkMutex.RLock() + defer fake.createRepositoryForkMutex.RUnlock() + return len(fake.createRepositoryForkArgsForCall) +} + +func (fake *FakeToolImplementation) CreateRepositoryForkCalls(stub func(context.Context, *auth.Authenticator, *models.Repository, string) error) { + fake.createRepositoryForkMutex.Lock() + defer fake.createRepositoryForkMutex.Unlock() + fake.CreateRepositoryForkStub = stub +} + +func (fake *FakeToolImplementation) CreateRepositoryForkArgsForCall(i int) (context.Context, *auth.Authenticator, *models.Repository, string) { + fake.createRepositoryForkMutex.RLock() + defer fake.createRepositoryForkMutex.RUnlock() + argsForCall := fake.createRepositoryForkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeToolImplementation) CreateRepositoryForkReturns(result1 error) { + fake.createRepositoryForkMutex.Lock() + defer fake.createRepositoryForkMutex.Unlock() + fake.CreateRepositoryForkStub = nil + fake.createRepositoryForkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeToolImplementation) CreateRepositoryForkReturnsOnCall(i int, result1 error) { + fake.createRepositoryForkMutex.Lock() + defer fake.createRepositoryForkMutex.Unlock() + fake.CreateRepositoryForkStub = nil + if fake.createRepositoryForkReturnsOnCall == nil { + fake.createRepositoryForkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.createRepositoryForkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeToolImplementation) GetAttestationReader(arg1 *models.Repository) (models.AttestationStorageReader, error) { fake.getAttestationReaderMutex.Lock() ret, specificReturn := fake.getAttestationReaderReturnsOnCall[len(fake.getAttestationReaderArgsForCall)] diff --git a/pkg/sourcetool/tool.go b/pkg/sourcetool/tool.go index d4e37387..d9f9cad7 100644 --- a/pkg/sourcetool/tool.go +++ b/pkg/sourcetool/tool.go @@ -87,7 +87,7 @@ func (t *Tool) OnboardRepository(repo *models.Repository, branches []*models.Bra return fmt.Errorf("verifying options: %w", err) } - if err = backend.ConfigureControls( + if err := backend.ConfigureControls( repo, branches, []models.ControlConfiguration{ models.CONFIG_BRANCH_RULES, models.CONFIG_GEN_PROVENANCE, models.CONFIG_TAG_RULES, }, @@ -155,15 +155,19 @@ func (t *Tool) FindPolicyPR(repo *models.Repository) (*models.PullRequest, error return pr, nil } -func (t *Tool) CheckPolicyRepoFork(repo *models.Repository) (bool, error) { +// CheckPolicyRepoFork checks that the logged in user has a fork +// of the configured policy repo. +func (t *Tool) CheckPolicyRepoFork() (bool, error) { if err := t.impl.CheckPolicyFork(&t.Options); err != nil { if strings.Contains(err.Error(), "404 Not Found") { return false, nil } + if strings.Contains(err.Error(), "oes not have a fork of") { + return false, nil + } return false, err - } else { - return true, nil } + return true, nil } // CreateBranchPolicy creates a repository policy @@ -248,3 +252,14 @@ func (t *Tool) CreateRepositoryPolicy(ctx context.Context, r *models.Repository, } return pcy, pr, nil } + +// CreatePolicyRepoFork creates a fork of the policy repository in the user's GitHub org +func (t *Tool) CreatePolicyRepoFork(ctx context.Context) error { + err := t.impl.CreateRepositoryFork(ctx, t.Authenticator, &models.Repository{ + Path: t.Options.PolicyRepo, + }, "") + if err != nil { + return fmt.Errorf("creating policy repo fork: %w", err) + } + return nil +} diff --git a/pkg/sourcetool/tool_test.go b/pkg/sourcetool/tool_test.go index 2eb91f69..19870b93 100644 --- a/pkg/sourcetool/tool_test.go +++ b/pkg/sourcetool/tool_test.go @@ -241,7 +241,7 @@ func TestCheckPolicyFork(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tool := Tool{impl: tc.prepare(t)} - found, err := tool.CheckPolicyRepoFork(&models.Repository{}) + found, err := tool.CheckPolicyRepoFork() if tc.mustErr { require.Error(t, err) return @@ -251,3 +251,96 @@ func TestCheckPolicyFork(t *testing.T) { }) } } + +func TestCreatePolicyRepoFork(t *testing.T) { + t.Parallel() + for _, tt := range []struct { + name string + getSut func(t *testing.T) toolImplementation + mustErr bool + }{ + {"normal", func(t *testing.T) toolImplementation { + t.Helper() + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.CreateRepositoryForkReturns(nil) + return timp + }, false}, + {"create-fork-fails", func(t *testing.T) toolImplementation { + t.Helper() + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.CreateRepositoryForkReturns(errors.New("error")) + return timp + }, true}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + impl := tt.getSut(t) + tool, err := New() + require.NoError(t, err) + tool.impl = impl + err = tool.CreatePolicyRepoFork(t.Context()) + if tt.mustErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestOnboardRepository(t *testing.T) { + t.Parallel() + for _, tt := range []struct { + name string + getSut func(t *testing.T) toolImplementation + mustErr bool + }{ + {"normal", func(t *testing.T) toolImplementation { + t.Helper() + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.GetVcsBackendReturns(&modelsfakes.FakeVcsBackend{}, nil) + timp.VerifyOptionsForFullOnboardReturns(nil) + timp.ConfigureControlsReturns(nil) + return timp + }, false}, + {"get-vcs-fails", func(t *testing.T) toolImplementation { + t.Helper() + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.GetVcsBackendReturns(nil, errors.New("vcsborked")) + timp.VerifyOptionsForFullOnboardReturns(nil) + timp.ConfigureControlsReturns(nil) + return timp + }, true}, + {"get-verifyoptions-fails", func(t *testing.T) toolImplementation { + t.Helper() + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.GetVcsBackendReturns(&modelsfakes.FakeVcsBackend{}, nil) + timp.VerifyOptionsForFullOnboardReturns(errors.New("onboarderr")) + return timp + }, true}, + {"backend-configure-controls-fails", func(t *testing.T) toolImplementation { + t.Helper() + bend := modelsfakes.FakeVcsBackend{} + bend.ConfigureControlsReturns(errors.New("configure-error")) + + timp := &sourcetoolfakes.FakeToolImplementation{} + timp.GetVcsBackendReturns(&bend, nil) + timp.VerifyOptionsForFullOnboardReturns(nil) + return timp + }, true}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + impl := tt.getSut(t) + tool, err := New() + require.NoError(t, err) + tool.impl = impl + err = tool.OnboardRepository(&models.Repository{Path: "example/repo"}, []*models.Branch{{Name: "main"}}) + if tt.mustErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +}