diff --git a/internal/cmd/status.go b/internal/cmd/status.go index 64178ae9..b688bb34 100644 --- a/internal/cmd/status.go +++ b/internal/cmd/status.go @@ -4,6 +4,7 @@ package cmd import ( + "context" "errors" "fmt" "strings" @@ -146,11 +147,22 @@ sourcetool status myorg/myrepo@mybranch } fmt.Println() + policyNeedsUpdate := false if policyControlStatus != nil { fmt.Printf("%-35s ", "Repo policy found:") switch policyControlStatus.State { case slsa.StateActive: - fmt.Println("✅") + fmt.Print("✅") + // Check if the policy needs updating + pcy, err := srctool.GetRepositoryPolicy(context.Background(), opts.GetRepository()) + if err == nil { + pb := pcy.GetBranchPolicy(opts.GetBranch().Name) + if pb != nil && pb.GetTargetSlsaSourceLevel() != string(toplevel) { + fmt.Print(w2(fmt.Sprintf(" (needs update to %s)", toplevel))) + policyNeedsUpdate = true + } + } + fmt.Println() case slsa.StateNotEnabled: fmt.Println("🚫") case slsa.StateInProgress: @@ -188,6 +200,15 @@ sourcetool status myorg/myrepo@mybranch fmt.Println() } + if policyNeedsUpdate { + if !titled { + fmt.Println(w2("✨ Recommended actions:")) + } + fmt.Println(" - Update the repository source policy") + fmt.Printf(" > sourcetool policy create --update %s\n", opts.GetRepository().Path) + fmt.Println() + } + return nil }, } diff --git a/pkg/policy/policy.go b/pkg/policy/policy.go index dd336412..3189fc98 100644 --- a/pkg/policy/policy.go +++ b/pkg/policy/policy.go @@ -38,7 +38,8 @@ const ( ) // Returns the policy for the branch or nil if the branch doesn't have one. -func (rp *RepoPolicy) getBranchPolicy(branch string) *ProtectedBranch { +func (rp *RepoPolicy) GetBranchPolicy(branch string) *ProtectedBranch { + branch = strings.TrimPrefix(branch, "refs/heads/") for _, pb := range rp.GetProtectedBranches() { if pb.GetName() == branch { return pb @@ -525,7 +526,7 @@ func (pe *PolicyEvaluator) EvaluateControl(ctx context.Context, repo *models.Rep return slsa.SourceVerifiedLevels{}, "", err } - branchPolicy := rp.getBranchPolicy(branch.Name) + branchPolicy := rp.GetBranchPolicy(branch.Name) if branchPolicy == nil { branchPolicy = createDefaultBranchPolicy(branch) policyPath = "DEFAULT" @@ -555,7 +556,7 @@ func (pe *PolicyEvaluator) EvaluateSourceProv(ctx context.Context, repo *models. return slsa.SourceVerifiedLevels{}, "", err } - branchPolicy := rp.getBranchPolicy(branch.Name) + branchPolicy := rp.GetBranchPolicy(branch.Name) if branchPolicy == nil { branchPolicy = createDefaultBranchPolicy(branch) policyPath = "DEFAULT" diff --git a/pkg/policy/policy_test.go b/pkg/policy/policy_test.go index 16de6213..704348dd 100644 --- a/pkg/policy/policy_test.go +++ b/pkg/policy/policy_test.go @@ -1443,7 +1443,7 @@ func assertPolicyResultEquals(t *testing.T, ctx context.Context, ghConn *ghcontr // TODO: check the rest of the contents of expectedPolicy? - gotPb := rp.getBranchPolicy(ghcontrol.GetBranchFromRef(ghConn.GetFullRef())) + gotPb := rp.GetBranchPolicy(ghcontrol.GetBranchFromRef(ghConn.GetFullRef())) if expectedBranchPolicy == nil { if gotPb != nil {