diff --git a/.github/workflows/local_attest.yml b/.github/workflows/local_attest.yml index 00f9109b..524d364b 100644 --- a/.github/workflows/local_attest.yml +++ b/.github/workflows/local_attest.yml @@ -2,6 +2,7 @@ name: SLSA Source on: push: branches: [ "main" ] + tags: ['**'] jobs: # Whenever new source is pushed recompute the slsa source information. diff --git a/DESIGN.md b/DESIGN.md index 04774ca4..8944610a 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -233,7 +233,8 @@ This tool can also check to see if the GitHub repo/ref is configured to require ### IMMUTABLE_TAGS This tool can also check to see if the GitHub repo is configured to require -immutable tags. To do so it checks that the repo: +immutable tags. To do so it checks that the repo enables the follow rules +to ~ALL tags: 1. Doesn't allow tag updates 2. Doesn't allow tag deletions @@ -243,6 +244,11 @@ Importing [rulesets/tag_immutability.json](rulesets/tag_immutability.json) to a repos rulesets will enable the repo controls. The `immutable_tags` field in the policy then needs to be enabled too. +TODO: In the future this tool could be updated to allow some subset of tags +to be updated (e.g. `latest`, `nightly`), but that feature is not yet +supported. Tracked +[here](https://github.com/slsa-framework/slsa-source-poc/issues/129). + ## Open Issues ### Dealing with reliability diff --git a/actions/slsa_with_provenance/action.yml b/actions/slsa_with_provenance/action.yml index 7fee3810..4783015c 100644 --- a/actions/slsa_with_provenance/action.yml +++ b/actions/slsa_with_provenance/action.yml @@ -22,11 +22,18 @@ runs: - id: setup run: mkdir -p metadata shell: bash - - id: determine_level + - id: handle_branch_push + if: ${{ startsWith(github.ref, 'refs/heads/') }} run: | - echo "## SLSA Source Properties" >> $GITHUB_STEP_SUMMARY + echo "## SLSA Source Properties Branch Push" >> $GITHUB_STEP_SUMMARY go run github.com/slsa-framework/slsa-source-poc/sourcetool@8de659f119d933d4cfaed300e7d8bd78528a48c7 --github_token ${{ github.token }} checklevelprov --commit ${{ github.sha }} --owner ${{ github.repository_owner }} --repo ${{ github.event.repository.name }} --branch ${{ github.ref_name }} --output_signed_bundle ${{ github.workspace }}/metadata/signed_bundle.intoto.jsonl >> $GITHUB_STEP_SUMMARY shell: bash + - id: handle_tag_push + if: ${{ startsWith(github.ref, 'refs/tags/') }} + run: | + echo "## SLSA Source Properties Tag Push" >> $GITHUB_STEP_SUMMARY + echo "TODO" + shell: bash - id: summary run: | echo "## Signed Bundle" >> $GITHUB_STEP_SUMMARY diff --git a/go.work.sum b/go.work.sum index eae8fc82..f9da05fc 100644 --- a/go.work.sum +++ b/go.work.sum @@ -73,6 +73,7 @@ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kB github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/bufbuild/protocompile v0.10.0/go.mod h1:G9qQIQo0xZ6Uyj6CMNz0saGmx2so+KONo8/KrELABiY= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/buildkite/agent/v3 v3.81.0/go.mod h1:edJeyycODRxaFvpT22rDGwaQ5oa4eB8GjtbjgX5VpFw= github.com/buildkite/go-pipeline v0.13.1/go.mod h1:2HHqlSFTYgHFhzedJu0LhLs9n5c9XkYnHiQFVN5HE4U= github.com/buildkite/interpolate v0.1.3/go.mod h1:UNVe6A+UfiBNKbhAySrBbZFZFxQ+DXr9nWen6WVt/A8= @@ -180,7 +181,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3 github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= diff --git a/policy/github.com/TomHennen/Concordance/source-policy.json b/policy/github.com/TomHennen/Concordance/source-policy.json index b9879231..e91060d4 100644 --- a/policy/github.com/TomHennen/Concordance/source-policy.json +++ b/policy/github.com/TomHennen/Concordance/source-policy.json @@ -5,8 +5,11 @@ "Name": "master", "Since": "2025-03-23T18:08:43.25099739Z", "target_slsa_source_level": "SLSA_SOURCE_LEVEL_3", - "require_review": false, - "immutable_tags": true + "require_review": false } - ] + ], + "protected_tag": { + "Since": "2025-03-23T18:08:43.25099739Z", + "immutable_tags": true + } } diff --git a/sourcetool/cmd/checklevel.go b/sourcetool/cmd/checklevel.go index 8d0c83d5..7399ead1 100644 --- a/sourcetool/cmd/checklevel.go +++ b/sourcetool/cmd/checklevel.go @@ -41,22 +41,22 @@ func doCheckLevel(commit, owner, repo, branch, outputVsa, outputUnsignedVsa stri log.Fatal("Must set commit, owner, repo, and branch flags.") } - gh_connection := gh_control.NewGhConnection(owner, repo, branch).WithAuthToken(githubToken) + gh_connection := gh_control.NewGhConnection(owner, repo, gh_control.BranchToFullRef(branch)).WithAuthToken(githubToken) ctx := context.Background() - controlStatus, err := gh_connection.GetControls(ctx, commit) + controlStatus, err := gh_connection.GetBranchControls(ctx, commit, gh_connection.GetFullRef()) if err != nil { log.Fatal(err) } - pol := policy.NewPolicy() - pol.UseLocalPolicy = checkLevelProvArgs.useLocalPolicy - verifiedLevels, policyPath, err := pol.EvaluateControl(ctx, gh_connection, controlStatus) + pe := policy.NewPolicyEvaluator() + pe.UseLocalPolicy = checkLevelProvArgs.useLocalPolicy + verifiedLevels, policyPath, err := pe.EvaluateControl(ctx, gh_connection, controlStatus) if err != nil { log.Fatal(err) } fmt.Print(verifiedLevels) - unsignedVsa, err := attest.CreateUnsignedSourceVsa(gh_connection, commit, verifiedLevels, policyPath) + unsignedVsa, err := attest.CreateUnsignedSourceVsa(gh_connection.GetRepoUri(), gh_connection.GetFullRef(), commit, verifiedLevels, policyPath) if err != nil { log.Fatal(err) } diff --git a/sourcetool/cmd/checklevelprov.go b/sourcetool/cmd/checklevelprov.go index e2264b69..33a0e699 100644 --- a/sourcetool/cmd/checklevelprov.go +++ b/sourcetool/cmd/checklevelprov.go @@ -46,7 +46,7 @@ var ( func doCheckLevelProv(checkLevelProvArgs CheckLevelProvArgs) { gh_connection := - gh_control.NewGhConnection(checkLevelProvArgs.owner, checkLevelProvArgs.repo, checkLevelProvArgs.branch).WithAuthToken(githubToken) + gh_control.NewGhConnection(checkLevelProvArgs.owner, checkLevelProvArgs.repo, gh_control.BranchToFullRef(checkLevelProvArgs.branch)).WithAuthToken(githubToken) ctx := context.Background() prevCommit := checkLevelProvArgs.prevCommit @@ -58,22 +58,22 @@ func doCheckLevelProv(checkLevelProvArgs CheckLevelProvArgs) { } } - pa := attest.NewProvenanceAttestor(gh_connection, getVerificationOptions()) - prov, err := pa.CreateSourceProvenance(ctx, checkLevelProvArgs.prevBundlePath, checkLevelProvArgs.commit, prevCommit) + pa := attest.NewProvenanceAttestor(gh_connection, getVerifier()) + prov, err := pa.CreateSourceProvenance(ctx, checkLevelProvArgs.prevBundlePath, checkLevelProvArgs.commit, prevCommit, gh_connection.GetFullRef()) if err != nil { log.Fatal(err) } // check p against policy - pol := policy.NewPolicy() - pol.UseLocalPolicy = checkLevelProvArgs.useLocalPolicy - verifiedLevels, policyPath, err := pol.EvaluateProv(ctx, gh_connection, prov) + pe := policy.NewPolicyEvaluator() + pe.UseLocalPolicy = checkLevelProvArgs.useLocalPolicy + verifiedLevels, policyPath, err := pe.EvaluateSourceProv(ctx, gh_connection, prov) if err != nil { log.Fatal(err) } // create vsa - unsignedVsa, err := attest.CreateUnsignedSourceVsa(gh_connection, checkLevelProvArgs.commit, verifiedLevels, policyPath) + unsignedVsa, err := attest.CreateUnsignedSourceVsa(gh_connection.GetRepoUri(), gh_connection.GetFullRef(), checkLevelProvArgs.commit, verifiedLevels, policyPath) if err != nil { log.Fatal(err) } diff --git a/sourcetool/cmd/checktag.go b/sourcetool/cmd/checktag.go new file mode 100644 index 00000000..9f51bf18 --- /dev/null +++ b/sourcetool/cmd/checktag.go @@ -0,0 +1,108 @@ +/* +Copyright © 2025 NAME HERE +*/ +package cmd + +import ( + "context" + "log" + "os" + + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/attest" + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/gh_control" + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/policy" + "github.com/spf13/cobra" + "google.golang.org/protobuf/encoding/protojson" +) + +type CheckTagArgs struct { + commit string + owner string + repo string + tagName string + outputSignedBundle string + useLocalPolicy string +} + +var ( + checkTagArgs CheckTagArgs + // checktagCmd represents the checktag command + checktagCmd = &cobra.Command{ + Use: "checktag", + Short: "Checks to see if the tag operation should be allowed and issues a VSA", + Run: func(cmd *cobra.Command, args []string) { + doCheckTag(checkTagArgs) + }, + } +) + +func doCheckTag(args CheckTagArgs) { + gh_connection := + gh_control.NewGhConnection(args.owner, args.repo, gh_control.TagToFullRef(args.tagName)).WithAuthToken(githubToken) + ctx := context.Background() + verifier := getVerifier() + + // Create tag provenance. + pa := attest.NewProvenanceAttestor(gh_connection, verifier) + prov, err := pa.CreateTagProvenance(ctx, args.commit, gh_control.TagToFullRef(args.tagName)) + if err != nil { + log.Fatal(err) + } + + // check p against policy + pe := policy.NewPolicyEvaluator() + pe.UseLocalPolicy = args.useLocalPolicy + verifiedLevels, policyPath, err := pe.EvaluateTagProv(ctx, gh_connection, prov) + if err != nil { + log.Fatal(err) + } + + // create vsa + unsignedVsa, err := attest.CreateUnsignedSourceVsa(gh_connection.GetRepoUri(), gh_connection.GetFullRef(), args.commit, verifiedLevels, policyPath) + if err != nil { + log.Fatal(err) + } + + unsignedProv, err := protojson.Marshal(prov) + if err != nil { + log.Fatal(err) + } + + if args.outputSignedBundle != "" { + f, err := os.OpenFile(args.outputSignedBundle, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + log.Fatal(err) + } + defer f.Close() + + signedProv, err := attest.Sign(string(unsignedProv)) + if err != nil { + log.Fatal(err) + } + + signedVsa, err := attest.Sign(unsignedVsa) + if err != nil { + log.Fatal(err) + } + + f.WriteString(signedProv) + f.WriteString("\n") + f.WriteString(signedVsa) + f.WriteString("\n") + } else { + log.Printf("unsigned prov: %s\n", unsignedProv) + log.Printf("unsigned vsa: %s\n", unsignedVsa) + } +} + +func init() { + rootCmd.AddCommand(checktagCmd) + + checktagCmd.Flags().StringVar(&checkTagArgs.commit, "commit", "", "The commit to check - required.") + checktagCmd.Flags().StringVar(&checkTagArgs.owner, "owner", "", "The GitHub repository owner - required.") + checktagCmd.Flags().StringVar(&checkTagArgs.repo, "repo", "", "The GitHub repository name - required.") + checktagCmd.Flags().StringVar(&checkTagArgs.tagName, "tag_name", "", "The name of the new tag - required.") + checktagCmd.Flags().StringVar(&checkTagArgs.outputSignedBundle, "output_signed_bundle", "", "The path to write a bundle of signed attestations.") + checktagCmd.Flags().StringVar(&checkTagArgs.useLocalPolicy, "use_local_policy", "", "UNSAFE: Use the policy at this local path instead of the official one.") + +} diff --git a/sourcetool/cmd/createpolicy.go b/sourcetool/cmd/createpolicy.go index cc6e7c9a..2ed6e6cc 100644 --- a/sourcetool/cmd/createpolicy.go +++ b/sourcetool/cmd/createpolicy.go @@ -35,7 +35,7 @@ var ( ) func doCreatePolicy(policyRepoPath, owner, repo, branch string) { - gh_connection := gh_control.NewGhConnection(owner, repo, branch).WithAuthToken(githubToken) + gh_connection := gh_control.NewGhConnection(owner, repo, gh_control.BranchToFullRef(branch)).WithAuthToken(githubToken) ctx := context.Background() outpath, err := policy.CreateLocalPolicy(ctx, gh_connection, policyRepoPath) if err != nil { diff --git a/sourcetool/cmd/prov.go b/sourcetool/cmd/prov.go index bbafabc9..26ab98bf 100644 --- a/sourcetool/cmd/prov.go +++ b/sourcetool/cmd/prov.go @@ -32,10 +32,10 @@ var ( ) func doProv(prevAttPath, commit, prevCommit, owner, repo, branch string) { - gh_connection := gh_control.NewGhConnection(owner, repo, branch).WithAuthToken(githubToken) + gh_connection := gh_control.NewGhConnection(owner, repo, gh_control.BranchToFullRef(branch)).WithAuthToken(githubToken) ctx := context.Background() - pa := attest.NewProvenanceAttestor(gh_connection, attest.DefaultVerifierOptions) - newProv, err := pa.CreateSourceProvenance(ctx, prevAttPath, commit, prevCommit) + pa := attest.NewProvenanceAttestor(gh_connection, getVerifier()) + newProv, err := pa.CreateSourceProvenance(ctx, prevAttPath, commit, prevCommit, gh_connection.GetFullRef()) if err != nil { log.Fatal(err) } diff --git a/sourcetool/cmd/root.go b/sourcetool/cmd/root.go index f301db25..b31da7c0 100644 --- a/sourcetool/cmd/root.go +++ b/sourcetool/cmd/root.go @@ -31,7 +31,7 @@ to quickly create a Cobra application.`, } ) -func getVerificationOptions() attest.VerificationOptions { +func getVerifier() attest.Verifier { options := attest.DefaultVerifierOptions if checkLevelProvArgs.expectedIssuer != "" { options.ExpectedIssuer = checkLevelProvArgs.expectedIssuer @@ -39,7 +39,7 @@ func getVerificationOptions() attest.VerificationOptions { if checkLevelProvArgs.expectedSan != "" { options.ExpectedSan = checkLevelProvArgs.expectedSan } - return options + return attest.NewBndVerifier(options) } // Execute adds all child commands to the root command and sets flags appropriately. diff --git a/sourcetool/cmd/verifycommit.go b/sourcetool/cmd/verifycommit.go index c3f2f82d..0c142ac0 100644 --- a/sourcetool/cmd/verifycommit.go +++ b/sourcetool/cmd/verifycommit.go @@ -34,12 +34,10 @@ func doVerifyCommit(commit, owner, repo, branch string) { log.Fatal("Must set commit, owner, repo, and branch flags.") } - gh_connection := gh_control.NewGhConnection(owner, repo, branch).WithAuthToken(githubToken) + gh_connection := gh_control.NewGhConnection(owner, repo, gh_control.BranchToFullRef(branch)).WithAuthToken(githubToken) ctx := context.Background() - pa := attest.NewProvenanceAttestor(gh_connection, getVerificationOptions()) - - _, vsaPred, err := pa.GetVsa(ctx, commit) + _, vsaPred, err := attest.GetVsa(ctx, gh_connection, getVerifier(), commit, gh_connection.GetFullRef()) if err != nil { log.Fatal(err) } diff --git a/sourcetool/go.mod b/sourcetool/go.mod index 70ba6318..301ad6c3 100644 --- a/sourcetool/go.mod +++ b/sourcetool/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-git/go-git/v5 v5.13.2 github.com/google/go-github/v69 v69.2.0 github.com/in-toto/attestation v1.1.1 + github.com/migueleliasweb/go-github-mock v1.3.0 github.com/sigstore/sigstore-go v0.7.0 github.com/spf13/cobra v1.9.1 google.golang.org/protobuf v1.36.5 @@ -47,8 +48,10 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/certificate-transparency-go v1.3.1 // indirect github.com/google/go-containerregistry v0.20.3 // indirect + github.com/google/go-github/v71 v71.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/mux v1.8.1 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect github.com/hashicorp/hcl v1.0.1-vault-5 // indirect @@ -109,6 +112,7 @@ require ( golang.org/x/sys v0.30.0 // indirect golang.org/x/term v0.29.0 // indirect golang.org/x/text v0.22.0 // indirect + golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241219192143-6b3ec007d9bb // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect diff --git a/sourcetool/go.sum b/sourcetool/go.sum index 8dad229b..e90259e1 100644 --- a/sourcetool/go.sum +++ b/sourcetool/go.sum @@ -177,12 +177,14 @@ github.com/google/certificate-transparency-go v1.3.1 h1:akbcTfQg0iZlANZLn0L9xOeW github.com/google/certificate-transparency-go v1.3.1/go.mod h1:gg+UQlx6caKEDQ9EElFOujyxEQEfOiQzAt6782Bvi8k= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= github.com/google/go-github/v69 v69.2.0 h1:wR+Wi/fN2zdUx9YxSmYE0ktiX9IAR/BeePzeaUUbEHE= github.com/google/go-github/v69 v69.2.0/go.mod h1:xne4jymxLR6Uj9b7J7PyTpkMYstEMMwGZa0Aehh1azM= +github.com/google/go-github/v71 v71.0.0 h1:Zi16OymGKZZMm8ZliffVVJ/Q9YZreDKONCr+WUd0Z30= +github.com/google/go-github/v71 v71.0.0/go.mod h1:URZXObp2BLlMjwu0O8g4y6VBneUj2bCHgnI8FfgZ51M= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -199,6 +201,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gT github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= @@ -275,6 +279,8 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= +github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= diff --git a/sourcetool/pkg/attest/provenance.go b/sourcetool/pkg/attest/provenance.go index b1ceea04..525e67eb 100644 --- a/sourcetool/pkg/attest/provenance.go +++ b/sourcetool/pkg/attest/provenance.go @@ -20,6 +20,7 @@ import ( ) const SourceProvPredicateType = "https://github.com/slsa-framework/slsa-source-poc/source-provenance/v1-draft" +const TagProvPredicateType = "https://github.com/slsa-framework/slsa-source-poc/tag-provenance/v1-draft" // The predicate that encodes source provenance data. // The git commit this corresponds to is encoded in the surrounding statement. @@ -37,16 +38,33 @@ type SourceProvenancePred struct { Controls slsa_types.Controls `json:"controls"` } +// Summary of a summary +type VsaSummary struct { + SourceRefs []string `json:"source_refs"` + VerifiedLevels []string `json:"verifiedLevels"` +} + +type TagProvenancePred struct { + RepoUri string `json:"repo_uri"` + ActivityType string `json:"activity_type"` + Actor string `json:"actor"` + Tag string `json:"tag"` + CreatedOn time.Time `json:"created_on"` + // The tag related controls enabled at the time this tag was created/updated. + Controls slsa_types.Controls `json:"controls"` + VsaSummaries []VsaSummary `json:"vsa_summaries"` +} + type ProvenanceAttestor struct { - verification_options VerificationOptions - gh_connection *gh_control.GitHubConnection + verifier Verifier + gh_connection *gh_control.GitHubConnection } -func NewProvenanceAttestor(gh_connection *gh_control.GitHubConnection, verification_options VerificationOptions) *ProvenanceAttestor { - return &ProvenanceAttestor{verification_options: verification_options, gh_connection: gh_connection} +func NewProvenanceAttestor(gh_connection *gh_control.GitHubConnection, verifier Verifier) *ProvenanceAttestor { + return &ProvenanceAttestor{verifier: verifier, gh_connection: gh_connection} } -func GetProvPred(statement *spb.Statement) (*SourceProvenancePred, error) { +func GetSourceProvPred(statement *spb.Statement) (*SourceProvenancePred, error) { if statement == nil { return nil, errors.New("nil statement") } @@ -65,18 +83,42 @@ func GetProvPred(statement *spb.Statement) (*SourceProvenancePred, error) { // Using regular json.Unmarshal because this is just a regular struct. err = json.Unmarshal(predJson, &predStruct) if err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("unmarshaling predicate: %w", err) } // It's valid for Controls to be empty if no controls are reported. // The policy evaluation logic will determine if this is acceptable. // For example, a policy might only require SLSA Level 1, which has no specific control requirements from this predicate. - // if len(predStruct.Controls) == 0 { - // return nil, fmt.Errorf("expected %v to have non-zero properties", predStruct) - // } return &predStruct, nil } -func addPredToStatement(provPred *SourceProvenancePred, commit string) (*spb.Statement, error) { +func GetTagProvPred(statement *spb.Statement) (*TagProvenancePred, error) { + if statement == nil { + return nil, errors.New("nil statement") + } + if statement.PredicateType != TagProvPredicateType { + return nil, fmt.Errorf("unsupported predicate type: %s", statement.PredicateType) + } + if statement.Predicate == nil { + return nil, errors.New("nil predicate in statement") + } + predJson, err := protojson.Marshal(statement.Predicate) + if err != nil { + return nil, fmt.Errorf("cannot marshal predicate to JSON: %w", err) + } + + var predStruct TagProvenancePred + // Using regular json.Unmarshal because this is just a regular struct. + err = json.Unmarshal(predJson, &predStruct) + if err != nil { + return nil, fmt.Errorf("unmarshaling predicate: %w", err) + } + // It's valid for Controls to be empty if no controls are reported. + // The policy evaluation logic will determine if this is acceptable. + // For example, a policy might only require SLSA Level 1, which has no specific control requirements from this predicate. + return &predStruct, nil +} + +func addPredToStatement(provPred any, predicateType, commit string) (*spb.Statement, error) { // Using regular json.Marshal because this is just a regular struct and not from a proto. predJson, err := json.Marshal(provPred) if err != nil { @@ -96,7 +138,7 @@ func addPredToStatement(provPred *SourceProvenancePred, commit string) (*spb.Sta statementPb := spb.Statement{ Type: spb.StatementTypeUri, Subject: sub, - PredicateType: SourceProvPredicateType, + PredicateType: predicateType, Predicate: &predPb, } @@ -104,8 +146,8 @@ func addPredToStatement(provPred *SourceProvenancePred, commit string) (*spb.Sta } // Create provenance for the current commit without any context from the previous provenance (if any). -func (pa ProvenanceAttestor) createCurrentProvenance(ctx context.Context, commit, prevCommit string) (*spb.Statement, error) { - controlStatus, err := pa.gh_connection.GetControls(ctx, commit) +func (pa ProvenanceAttestor) createCurrentProvenance(ctx context.Context, commit, prevCommit, ref string) (*spb.Statement, error) { + controlStatus, err := pa.gh_connection.GetBranchControls(ctx, commit, ref) if err != nil { return nil, err } @@ -117,18 +159,18 @@ func (pa ProvenanceAttestor) createCurrentProvenance(ctx context.Context, commit curProvPred.RepoUri = pa.gh_connection.GetRepoUri() curProvPred.Actor = controlStatus.ActorLogin curProvPred.ActivityType = controlStatus.ActivityType - curProvPred.Branch = pa.gh_connection.GetFullBranch() + curProvPred.Branch = ref curProvPred.CreatedOn = curTime curProvPred.Controls = controlStatus.Controls // At the very least provenance is available starting now. :) curProvPred.Controls.AddControl(&slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: curTime}) - return addPredToStatement(&curProvPred, commit) + return addPredToStatement(&curProvPred, SourceProvPredicateType, commit) } // Gets provenance for the commit from git notes. -func (pa ProvenanceAttestor) GetProvenance(ctx context.Context, commit string) (*spb.Statement, *SourceProvenancePred, error) { +func (pa ProvenanceAttestor) GetProvenance(ctx context.Context, commit, ref string) (*spb.Statement, *SourceProvenancePred, error) { notes, err := pa.gh_connection.GetNotesForCommit(ctx, commit) if notes == "" { log.Printf("didn't find notes for commit %s", commit) @@ -139,12 +181,12 @@ func (pa ProvenanceAttestor) GetProvenance(ctx context.Context, commit string) ( log.Fatal(err) } - bundleReader := NewBundleReader(bufio.NewReader(strings.NewReader(notes)), pa.verification_options) + bundleReader := NewBundleReader(bufio.NewReader(strings.NewReader(notes)), pa.verifier) - return pa.getProvFromReader(bundleReader, commit) + return pa.getProvFromReader(bundleReader, commit, ref) } -func (pa ProvenanceAttestor) getProvFromReader(reader *BundleReader, commit string) (*spb.Statement, *SourceProvenancePred, error) { +func (pa ProvenanceAttestor) getProvFromReader(reader *BundleReader, commit, ref string) (*spb.Statement, *SourceProvenancePred, error) { for { stmt, err := reader.ReadStatement(MatchesTypeAndCommit(SourceProvPredicateType, commit)) if err != nil { @@ -161,46 +203,46 @@ func (pa ProvenanceAttestor) getProvFromReader(reader *BundleReader, commit stri break } - prevProdPred, err := GetProvPred(stmt) + prevProdPred, err := GetSourceProvPred(stmt) if err != nil { return nil, nil, err } - if prevProdPred.Branch == pa.gh_connection.GetFullBranch() { + if ref == gh_control.AnyReference || prevProdPred.Branch == ref { // Should be good! return stmt, prevProdPred, nil } else { - log.Printf("prov '%v' does not reference commit '%s' for branch '%s', skipping", stmt, commit, pa.gh_connection.GetFullBranch()) + log.Printf("prov '%v' does not reference commit '%s' for branch '%s', skipping", stmt, commit, ref) } } - log.Printf("didn't find commit %s for branch %s", commit, pa.gh_connection.Branch) + log.Printf("didn't find commit %s for ref %s", commit, ref) return nil, nil, nil } -func (pa ProvenanceAttestor) getPrevProvenance(ctx context.Context, prevAttPath, prevCommit string) (*spb.Statement, *SourceProvenancePred, error) { +func (pa ProvenanceAttestor) getPrevProvenance(ctx context.Context, prevAttPath, prevCommit, ref string) (*spb.Statement, *SourceProvenancePred, error) { if prevAttPath != "" { f, err := os.Open(prevAttPath) if err != nil { return nil, nil, err } - return pa.getProvFromReader(NewBundleReader(bufio.NewReader(f), pa.verification_options), prevCommit) + return pa.getProvFromReader(NewBundleReader(bufio.NewReader(f), pa.verifier), prevCommit, ref) } // Try to get the previous bundle ourselves... - return pa.GetProvenance(ctx, prevCommit) + return pa.GetProvenance(ctx, prevCommit, ref) } -func (pa ProvenanceAttestor) CreateSourceProvenance(ctx context.Context, prevAttPath, commit, prevCommit string) (*spb.Statement, error) { +func (pa ProvenanceAttestor) CreateSourceProvenance(ctx context.Context, prevAttPath, commit, prevCommit, ref string) (*spb.Statement, error) { // Source provenance is based on // 1. The current control situation (we assume 'commit' has _just_ occurred). // 2. How long the properties have been enforced according to the previous provenance. - curProv, err := pa.createCurrentProvenance(ctx, commit, prevCommit) + curProv, err := pa.createCurrentProvenance(ctx, commit, prevCommit, ref) if err != nil { return nil, err } - prevProvStmt, prevProvPred, err := pa.getPrevProvenance(ctx, prevAttPath, prevCommit) + prevProvStmt, prevProvPred, err := pa.getPrevProvenance(ctx, prevAttPath, prevCommit, ref) if err != nil { return nil, err } @@ -211,7 +253,7 @@ func (pa ProvenanceAttestor) CreateSourceProvenance(ctx context.Context, prevAtt return curProv, nil } - curProvPred, err := GetProvPred(curProv) + curProvPred, err := GetSourceProvPred(curProv) if err != nil { return nil, err } @@ -229,5 +271,52 @@ func (pa ProvenanceAttestor) CreateSourceProvenance(ctx context.Context, prevAtt curProvPred.Controls[i] = curControl } - return addPredToStatement(curProvPred, commit) + return addPredToStatement(curProvPred, SourceProvPredicateType, commit) +} + +func (pa ProvenanceAttestor) CreateTagProvenance(ctx context.Context, commit, ref string) (*spb.Statement, error) { + // 1. Check that the immutable tags control is still enabled and how long it's been enabled, store it in the prov. + // 2. Get a VSA associated with this commit, if any. + // 3. Record the levels and branches covered by that VSA in the provenance. + + controlStatus, err := pa.gh_connection.GetTagControls(ctx, commit, ref) + if err != nil { + return nil, err + } + + // Find the most recent VSA for this commit. Any reference is OK. + // TODO: in the future get all of them. + // TODO: we should actually verify this vsa: https://github.com/slsa-framework/slsa-source-poc/issues/148 + vsaStatement, vsaPred, err := GetVsa(ctx, pa.gh_connection, pa.verifier, commit, gh_control.AnyReference) + if err != nil { + return nil, fmt.Errorf("error fetching VSA when creating tag provenance %w", err) + } + if vsaPred == nil { + // TODO: If there's not a VSA should we still issue provenance? + return nil, nil + } + + curTime := time.Now() + + vsaRefs, err := GetSourceRefsForCommit(vsaStatement, commit) + if err != nil { + return nil, fmt.Errorf("error getting source refs from vsa %w", err) + } + + curProvPred := TagProvenancePred{ + RepoUri: pa.gh_connection.GetRepoUri(), + Actor: controlStatus.ActorLogin, + ActivityType: controlStatus.ActivityType, + Tag: ref, + CreatedOn: curTime, + Controls: controlStatus.Controls, + VsaSummaries: []VsaSummary{ + { + SourceRefs: vsaRefs, + VerifiedLevels: vsaPred.VerifiedLevels, + }, + }, + } + + return addPredToStatement(&curProvPred, TagProvPredicateType, commit) } diff --git a/sourcetool/pkg/attest/provenance_test.go b/sourcetool/pkg/attest/provenance_test.go new file mode 100644 index 00000000..1d09b8ae --- /dev/null +++ b/sourcetool/pkg/attest/provenance_test.go @@ -0,0 +1,181 @@ +package attest + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/gh_control" + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/slsa_types" + "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/testsupport" +) + +var rulesetOldTime = time.Now().Add(-time.Hour) + +func rulesForTagImmutability() *github.RepositoryRulesetRules { + return &github.RepositoryRulesetRules{ + Update: &github.UpdateRuleParameters{}, + Deletion: &github.EmptyRuleParameters{}, + NonFastForward: &github.EmptyRuleParameters{}, + } +} + +func conditionsForTagImmutability() *github.RepositoryRulesetConditions { + return &github.RepositoryRulesetConditions{ + RefName: &github.RepositoryRulesetRefConditionParameters{ + Include: []string{"~ALL"}, + }, + } +} + +func createTestVsa(t *testing.T, repoUri, ref, commit string, verifiedLevels slsa_types.SourceVerifiedLevels) string { + vsa, err := CreateUnsignedSourceVsa(repoUri, ref, commit, verifiedLevels, "test-policy") + if err != nil { + t.Fatalf("failure creating test vsa: %v", err) + } + return vsa +} + +func newNotesContent(content string) *github.RepositoryContent { + return &github.RepositoryContent{ + Content: github.Ptr(content), + } +} + +func newImmutableTagsRulesetsResponse(id int64, target github.RulesetTarget, enforcement github.RulesetEnforcement, + updatedAt time.Time) *github.RepositoryRuleset { + return &github.RepositoryRuleset{ + ID: github.Ptr(id), + Target: github.Ptr(target), + Enforcement: enforcement, + UpdatedAt: github.Ptr(github.Timestamp{Time: updatedAt}), + Rules: rulesForTagImmutability(), + Conditions: conditionsForTagImmutability(), + } +} + +func newMockedGitHubClient(rulesetResponse *github.RepositoryRuleset, notesContent *github.RepositoryContent) *github.Client { + return github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposRulesetsByOwnerByRepo, + []*github.RepositoryRuleset{ + rulesetResponse, + }, + ), + mock.WithRequestMatch( + mock.GetReposRulesetsByOwnerByRepoByRulesetId, + *rulesetResponse, + ), + mock.WithRequestMatch( + mock.GetReposContentsByOwnerByRepoByPath, + *notesContent, + ), + )) +} + +// Helper to create a test GH Branch connection with no client. +func newTestGhConnection(owner, repo, branch string, rulesetResponse *github.RepositoryRuleset, notesContent *github.RepositoryContent) *gh_control.GitHubConnection { + return gh_control.NewGhConnectionWithClient( + owner, repo, gh_control.BranchToFullRef(branch), + newMockedGitHubClient(rulesetResponse, notesContent)) +} + +func timesEqualWithinMargin(t1, t2 time.Time, margin time.Duration) bool { + diff := t1.Sub(t2).Abs() + return diff <= margin +} + +func assertTagProvPredsEqual(t *testing.T, actual, expected TagProvenancePred) { + if actual.ActivityType != expected.ActivityType { + t.Errorf("ActivityType %v does not match expected value %v", actual.ActivityType, expected.ActivityType) + } + + if actual.Actor != expected.Actor { + t.Errorf("Actor %v does not match expected value %v", actual.Actor, expected.Actor) + } + + if actual.RepoUri != expected.RepoUri { + t.Errorf("RepoUri %v does not match expected value %v", actual.RepoUri, expected.RepoUri) + } + + if actual.Tag != expected.Tag { + t.Errorf("Tag %v does not match expected value %v", actual.Tag, expected.Tag) + } + + if timesEqualWithinMargin(actual.CreatedOn, expected.CreatedOn, 5*time.Second) { + t.Errorf("CreatedOn %v does not match expected value %v", actual.CreatedOn, expected.CreatedOn) + } + + if len(actual.Controls) != len(expected.Controls) { + t.Errorf("Control %v does not match expected value %v", actual.Controls, expected.Controls) + } else { + for ci, _ := range actual.Controls { + if !timesEqualWithinMargin(actual.Controls[ci].Since, expected.Controls[ci].Since, time.Second) { + t.Errorf("control at [%d]'s time %v does not match expected time %v", ci, + actual.Controls[ci].Since, expected.Controls[ci].Since) + } + } + } + if !reflect.DeepEqual(actual.VsaSummaries, expected.VsaSummaries) { + t.Errorf("VsaSummaries %v does not match expected value %v", actual.VsaSummaries, expected.VsaSummaries) + } +} + +func TestCreateTagProvenance(t *testing.T) { + testVsa := createTestVsa(t, "http://repo", "refs/some/ref", "abc123", slsa_types.SourceVerifiedLevels{"TEST_LEVEL"}) + + ghc := newTestGhConnection("owner", "repo", "branch", + newImmutableTagsRulesetsResponse(123, github.RulesetTargetTag, + github.RulesetEnforcementActive, rulesetOldTime), + newNotesContent(testVsa)) + verifier := testsupport.NewMockVerifier() + + pa := NewProvenanceAttestor(ghc, verifier) + + stmt, err := pa.CreateTagProvenance(context.Background(), "abc123", "refs/tags/v1") + if err != nil { + t.Fatalf("error creating tag prov %v", err) + } + + if stmt == nil { + t.Fatalf("returned statement is nil") + } + + if stmt.PredicateType != TagProvPredicateType { + t.Errorf("statement pred type %v does not match expected %v", stmt.PredicateType, TagProvPredicateType) + } + + if !DoesSubjectIncludeCommit(stmt, "abc123") { + t.Errorf("statement subject %v does not match expected %v", stmt.Subject, "abc123") + } + + tagPred, err := GetTagProvPred(stmt) + if err != nil { + t.Fatalf("error getting tag prov %v", err) + } + + expectedPred := TagProvenancePred{ + RepoUri: "https://github.com/owner/repo", + Actor: "unknown actor", + ActivityType: "unknown activity type", + Tag: "refs/tags/v1", + CreatedOn: rulesetOldTime, + Controls: []slsa_types.Control{ + { + Name: "IMMUTABLE_TAGS", + Since: rulesetOldTime, + }, + }, + VsaSummaries: []VsaSummary{ + { + SourceRefs: []string{"refs/some/ref"}, + VerifiedLevels: []string{"TEST_LEVEL"}, + }, + }, + } + + assertTagProvPredsEqual(t, *tagPred, expectedPred) +} diff --git a/sourcetool/pkg/attest/statement.go b/sourcetool/pkg/attest/statement.go index 6e1b42d4..0190cde9 100644 --- a/sourcetool/pkg/attest/statement.go +++ b/sourcetool/pkg/attest/statement.go @@ -12,17 +12,17 @@ import ( ) type BundleReader struct { - reader *bufio.Reader - verification_options VerificationOptions + reader *bufio.Reader + verifier Verifier } -func NewBundleReader(reader *bufio.Reader, verification_options VerificationOptions) *BundleReader { - return &BundleReader{reader: reader, verification_options: verification_options} +func NewBundleReader(reader *bufio.Reader, verifier Verifier) *BundleReader { + return &BundleReader{reader: reader, verifier: verifier} } func (br BundleReader) convertLineToStatement(line string) (*spb.Statement, error) { // Is this a sigstore bundle with a statement? - vr, err := Verify(line, br.verification_options) + vr, err := br.verifier.Verify(line) if err == nil { // This is it. return vr.Statement, nil diff --git a/sourcetool/pkg/attest/verify.go b/sourcetool/pkg/attest/verify.go index ab0bf36a..34bcf533 100644 --- a/sourcetool/pkg/attest/verify.go +++ b/sourcetool/pkg/attest/verify.go @@ -17,15 +17,31 @@ var DefaultVerifierOptions = VerificationOptions{ ExpectedSan: "https://github.com/slsa-framework/slsa-source-poc/.github/workflows/compute_slsa_source.yml@refs/heads/main", } -func Verify(data string, options VerificationOptions) (*verify.VerificationResult, error) { +type Verifier interface { + Verify(data string) (*verify.VerificationResult, error) +} + +type BndVerifier struct { + Options VerificationOptions +} + +func (bv *BndVerifier) Verify(data string) (*verify.VerificationResult, error) { // TODO: There's more for us to do here... but what? // Maybe check to make sure it's from the identity we expect (the workflow?) verifier := bnd.NewVerifier() - verifier.Options.ExpectedIssuer = options.ExpectedIssuer - verifier.Options.ExpectedSan = options.ExpectedSan + verifier.Options.ExpectedIssuer = bv.Options.ExpectedIssuer + verifier.Options.ExpectedSan = bv.Options.ExpectedSan vr, err := verifier.VerifyInlineBundle([]byte(data)) if err != nil { return nil, err } return vr, nil } + +func NewBndVerifier(options VerificationOptions) *BndVerifier { + return &BndVerifier{Options: options} +} + +func GetDefaultVerifier() Verifier { + return NewBndVerifier(DefaultVerifierOptions) +} diff --git a/sourcetool/pkg/attest/vsa.go b/sourcetool/pkg/attest/vsa.go index 2e654525..864459c3 100644 --- a/sourcetool/pkg/attest/vsa.go +++ b/sourcetool/pkg/attest/vsa.go @@ -18,8 +18,8 @@ import ( const VsaPredicateType = "https://slsa.dev/verification_summary/v1" -func CreateUnsignedSourceVsa(gh_connection *gh_control.GitHubConnection, commit string, verifiedLevels slsa_types.SourceVerifiedLevels, policy string) (string, error) { - resourceUri := fmt.Sprintf("git+%s", gh_connection.GetRepoUri()) +func CreateUnsignedSourceVsa(repoUri, ref, commit string, verifiedLevels slsa_types.SourceVerifiedLevels, policy string) (string, error) { + resourceUri := fmt.Sprintf("git+%s", repoUri) vsaPred := &vpb.VerificationSummary{ Verifier: &vpb.VerificationSummary_Verifier{ Id: "https://github.com/slsa-framework/slsa-source-poc"}, @@ -35,7 +35,8 @@ func CreateUnsignedSourceVsa(gh_connection *gh_control.GitHubConnection, commit return "", err } - branchAnnotation := map[string]any{"source_branches": []any{gh_connection.GetFullBranch()}} + // TODO: update to source_refs to match updated spec. + branchAnnotation := map[string]any{"source_branches": []any{ref}} annotationStruct, err := structpb.NewStruct(branchAnnotation) if err != nil { return "", fmt.Errorf("creating struct from map: %w", err) @@ -66,8 +67,8 @@ func CreateUnsignedSourceVsa(gh_connection *gh_control.GitHubConnection, commit } // Gets provenance for the commit from git notes. -func (pa ProvenanceAttestor) GetVsa(ctx context.Context, commit string) (*spb.Statement, *vpb.VerificationSummary, error) { - notes, err := pa.gh_connection.GetNotesForCommit(ctx, commit) +func GetVsa(ctx context.Context, ghc *gh_control.GitHubConnection, verifier Verifier, commit, ref string) (*spb.Statement, *vpb.VerificationSummary, error) { + notes, err := ghc.GetNotesForCommit(ctx, commit) if notes == "" { log.Printf("didn't find notes for commit %s", commit) return nil, nil, nil @@ -76,7 +77,20 @@ func (pa ProvenanceAttestor) GetVsa(ctx context.Context, commit string) (*spb.St if err != nil { log.Fatal(err) } - return pa.getVsaFromReader(NewBundleReader(bufio.NewReader(strings.NewReader(notes)), pa.verification_options), commit) + return getVsaFromReader(NewBundleReader(bufio.NewReader(strings.NewReader(notes)), verifier), commit, ref) +} + +func GetSourceRefsForCommit(vsaStatement *spb.Statement, commit string) ([]string, error) { + subject := GetSubjectForCommit(vsaStatement, commit) + if subject == nil { + return []string{}, fmt.Errorf("statement \n%v\n does not match commit %s", StatementToString(vsaStatement), commit) + } + protoRefs := subject.GetAnnotations().Fields["source_branches"].GetListValue() + stringRefs := []string{} + for _, ref := range protoRefs.Values { + stringRefs = append(stringRefs, ref.GetStringValue()) + } + return stringRefs, nil } func getVsaPred(statement *spb.Statement) (*vpb.VerificationSummary, error) { @@ -94,32 +108,31 @@ func getVsaPred(statement *spb.Statement) (*vpb.VerificationSummary, error) { return &predStruct, nil } -func MatchesTypeCommitAndBranch(predicateType, commit, targetBranch string) StatementMatcher { +func MatchesTypeCommitAndRef(predicateType, commit, targetRef string) StatementMatcher { return func(statement *spb.Statement) bool { if statement.PredicateType != predicateType { log.Printf("statement predicate type (%s) doesn't match %s", statement.PredicateType, predicateType) return false } - subject := GetSubjectForCommit(statement, commit) - if subject == nil { - log.Printf("statement \n%v\n does not match commit %s", StatementToString(statement), commit) + refs, err := GetSourceRefsForCommit(statement, commit) + if err != nil { + log.Printf("statement \n%v\n does not match commit %s: %v", StatementToString(statement), commit, err) return false } - branches := subject.GetAnnotations().Fields["source_branches"].GetListValue() - for _, branch := range branches.Values { - if branch.GetStringValue() == targetBranch { - log.Printf("statement \n%v\n matches commit '%s' on branch '%s'", StatementToString(statement), commit, targetBranch) + for _, ref := range refs { + if targetRef == gh_control.AnyReference || ref == targetRef { + log.Printf("statement \n%v\n matches commit '%s' on ref '%s'", StatementToString(statement), commit, targetRef) return true } } - log.Printf("source_branches (%v) in VSA does not contain %s", branches, targetBranch) + log.Printf("source_branches (%v) in VSA does not contain %s", refs, targetRef) return false } } -func (pa ProvenanceAttestor) getVsaFromReader(reader *BundleReader, commit string) (*spb.Statement, *vpb.VerificationSummary, error) { +func getVsaFromReader(reader *BundleReader, commit, ref string) (*spb.Statement, *vpb.VerificationSummary, error) { for { - stmt, err := reader.ReadStatement(MatchesTypeCommitAndBranch(VsaPredicateType, commit, pa.gh_connection.GetFullBranch())) + stmt, err := reader.ReadStatement(MatchesTypeCommitAndRef(VsaPredicateType, commit, ref)) if err != nil { // Ignore errors, we want to check all the lines. log.Printf("error while processing line: %v", err) @@ -139,6 +152,6 @@ func (pa ProvenanceAttestor) getVsaFromReader(reader *BundleReader, commit strin return stmt, vsaPred, nil } - log.Printf("didn't find commit %s for branch %s", commit, pa.gh_connection.Branch) + log.Printf("didn't find commit %s for ref %s", commit, ref) return nil, nil, nil } diff --git a/sourcetool/pkg/gh_control/checklevel.go b/sourcetool/pkg/gh_control/checklevel.go index a051aa98..3392489c 100644 --- a/sourcetool/pkg/gh_control/checklevel.go +++ b/sourcetool/pkg/gh_control/checklevel.go @@ -25,21 +25,20 @@ type activity struct { Actor actor `json:"actor"` } -func (ghc *GitHubConnection) commitActivity(ctx context.Context, commit string) (*activity, error) { +func (ghc *GitHubConnection) commitActivity(ctx context.Context, commit, targetRef string) (*activity, error) { // Unfortunately the gh_client doesn't have native support for this...' - reqUrl := fmt.Sprintf("repos/%s/%s/activity", ghc.Owner, ghc.Repo) - req, err := ghc.Client.NewRequest("GET", reqUrl, nil) + reqUrl := fmt.Sprintf("repos/%s/%s/activity", ghc.Owner(), ghc.Repo()) + req, err := ghc.Client().NewRequest("GET", reqUrl, nil) if err != nil { return nil, err } var result []*activity - _, err = ghc.Client.Do(ctx, req, &result) + _, err = ghc.Client().Do(ctx, req, &result) if err != nil { return nil, err } - targetRef := ghc.GetFullBranch() monitoredTypes := []string{"push", "force_push", "pr_merge"} for _, activity := range result { if !slices.Contains(monitoredTypes, activity.ActivityType) { @@ -51,7 +50,7 @@ func (ghc *GitHubConnection) commitActivity(ctx context.Context, commit string) } } - return nil, fmt.Errorf("could not find repo activity for commit %s", commit) + return nil, fmt.Errorf("could not find repo activity for commit %s and ref %s", commit, targetRef) } type GhControlStatus struct { @@ -109,6 +108,7 @@ func enforcesImmutableTags(ruleset *github.RepositoryRuleset) bool { ruleset.Rules.Update != nil && ruleset.Rules.Deletion != nil && ruleset.Rules.NonFastForward != nil && + ruleset.Conditions != nil && len(ruleset.Conditions.RefName.Exclude) == 0 && slices.Contains(ruleset.Conditions.RefName.Include, "~ALL") { return true @@ -116,7 +116,7 @@ func enforcesImmutableTags(ruleset *github.RepositoryRuleset) bool { return false } -func (ghc *GitHubConnection) computeImmutableTagsControl(ctx context.Context, commit string, allRulesets []*github.RepositoryRuleset, activity *activity) (*slsa_types.Control, error) { +func (ghc *GitHubConnection) computeImmutableTagsControl(ctx context.Context, commit string, allRulesets []*github.RepositoryRuleset, activityTime *time.Time) (*slsa_types.Control, error) { var validRuleset *github.RepositoryRuleset for _, ruleset := range allRulesets { if *ruleset.Target != github.RulesetTargetTag { @@ -129,9 +129,9 @@ func (ghc *GitHubConnection) computeImmutableTagsControl(ctx context.Context, co // The GitHub API only seems to return a partial ruleset when asking for 'all' the rules // So we'll ask for this specific rule here so we can get all the data. - fullRuleset, _, err := ghc.Client.Repositories.GetRuleset(ctx, ghc.Owner, ghc.Repo, ruleset.GetID(), false) + fullRuleset, _, err := ghc.Client().Repositories.GetRuleset(ctx, ghc.Owner(), ghc.Repo(), ruleset.GetID(), false) if err != nil { - return nil, fmt.Errorf("could not get full ruleset for ruleset id %d", ruleset.GetID()) + return nil, fmt.Errorf("could not get full ruleset for ruleset id %d: err: %w", ruleset.GetID(), err) } if !enforcesImmutableTags(fullRuleset) { @@ -147,7 +147,7 @@ func (ghc *GitHubConnection) computeImmutableTagsControl(ctx context.Context, co } // Check that the commit was created after this rule was enabled. - if activity.Timestamp.Before(validRuleset.UpdatedAt.Time) { + if activityTime.Before(validRuleset.UpdatedAt.Time) { return nil, nil } @@ -159,7 +159,7 @@ func (ghc *GitHubConnection) computeReviewControl(ctx context.Context, rules []* var oldestActive *github.RepositoryRuleset for _, rule := range rules { if ghc.ruleMeetsRequiresReview(rule) { - ruleset, _, err := ghc.Client.Repositories.GetRuleset(ctx, ghc.Owner, ghc.Repo, rule.RulesetID, false) + ruleset, _, err := ghc.Client().Repositories.GetRuleset(ctx, ghc.Owner(), ghc.Repo(), rule.RulesetID, false) if err != nil { return nil, err } @@ -181,7 +181,7 @@ func (ghc *GitHubConnection) computeReviewControl(ctx context.Context, rules []* func (ghc *GitHubConnection) getOldestActiveRule(ctx context.Context, rules []*github.BranchRuleMetadata) (*github.RepositoryRuleset, error) { var oldestActive *github.RepositoryRuleset for _, rule := range rules { - ruleset, _, err := ghc.Client.Repositories.GetRuleset(ctx, ghc.Owner, ghc.Repo, rule.RulesetID, false) + ruleset, _, err := ghc.Client().Repositories.GetRuleset(ctx, ghc.Owner(), ghc.Repo(), rule.RulesetID, false) if err != nil { return nil, err } @@ -194,11 +194,11 @@ func (ghc *GitHubConnection) getOldestActiveRule(ctx context.Context, rules []*g return oldestActive, nil } -// Determines the controls that are in place using GitHub's APIs. +// Determines the controls that are in place for a branch using GitHub's APIs // This is necessarily only as good as GitHub's controls and existing APIs. -func (ghc *GitHubConnection) GetControls(ctx context.Context, commit string) (*GhControlStatus, error) { +func (ghc *GitHubConnection) GetBranchControls(ctx context.Context, commit, ref string) (*GhControlStatus, error) { // We want to know when this commit was pushed to ensure the rules were active _then_. - activity, err := ghc.commitActivity(ctx, commit) + activity, err := ghc.commitActivity(ctx, commit, ref) if err != nil { return nil, err } @@ -209,16 +209,15 @@ func (ghc *GitHubConnection) GetControls(ctx context.Context, commit string) (*G ActorLogin: activity.Actor.Login, Controls: slsa_types.Controls{}} - branchRules, _, err := ghc.Client.Repositories.GetRulesForBranch(ctx, ghc.Owner, ghc.Repo, ghc.Branch) - if err != nil { - return nil, err + branch := GetBranchFromRef(ref) + if branch == "" { + return nil, fmt.Errorf("ref %s is not a branch", ref) } - - allRulesets, _, err := ghc.Client.Repositories.GetAllRulesets(ctx, ghc.Owner, ghc.Repo, true) + // Do the branch specific stuff. + branchRules, _, err := ghc.Client().Repositories.GetRulesForBranch(ctx, ghc.Owner(), ghc.Repo(), branch) if err != nil { return nil, err } - // Compute the controls enforced. continuityControl, err := ghc.computeContinuityControl(ctx, commit, branchRules, activity) if err != nil { @@ -226,17 +225,41 @@ func (ghc *GitHubConnection) GetControls(ctx context.Context, commit string) (*G } controlStatus.Controls.AddControl(continuityControl) - ImmutableTagsControl, err := ghc.computeImmutableTagsControl(ctx, commit, allRulesets, activity) + reviewControl, err := ghc.computeReviewControl(ctx, branchRules.PullRequest) + if err != nil { + return nil, fmt.Errorf("could not populate ReviewControl: %w", err) + } + controlStatus.Controls.AddControl(reviewControl) + + allRulesets, _, err := ghc.Client().Repositories.GetAllRulesets(ctx, ghc.Owner(), ghc.Repo(), true) + if err != nil { + return nil, err + } + ImmutableTagsControl, err := ghc.computeImmutableTagsControl(ctx, commit, allRulesets, &activity.Timestamp) if err != nil { return nil, fmt.Errorf("could not populate ImmutableTagsControl: %w", err) } controlStatus.Controls.AddControl(ImmutableTagsControl) - reviewControl, err := ghc.computeReviewControl(ctx, branchRules.PullRequest) + return &controlStatus, nil +} + +func (ghc *GitHubConnection) GetTagControls(ctx context.Context, commit, ref string) (*GhControlStatus, error) { + controlStatus := GhControlStatus{ + CommitPushTime: time.Now(), + ActivityType: "unknown activity type", + ActorLogin: "unknown actor", + Controls: slsa_types.Controls{}} + + allRulesets, _, err := ghc.Client().Repositories.GetAllRulesets(ctx, ghc.Owner(), ghc.Repo(), true) if err != nil { - return nil, fmt.Errorf("could not populate ReviewControl: %w", err) + return nil, err } - controlStatus.Controls.AddControl(reviewControl) + ImmutableTagsControl, err := ghc.computeImmutableTagsControl(ctx, commit, allRulesets, &controlStatus.CommitPushTime) + if err != nil { + return nil, fmt.Errorf("could not populate ImmutableTagsControl: %w", err) + } + controlStatus.Controls.AddControl(ImmutableTagsControl) return &controlStatus, nil } diff --git a/sourcetool/pkg/gh_control/connection.go b/sourcetool/pkg/gh_control/connection.go index 7c8c5933..10968ec6 100644 --- a/sourcetool/pkg/gh_control/connection.go +++ b/sourcetool/pkg/gh_control/connection.go @@ -7,51 +7,59 @@ import ( "github.com/google/go-github/v69/github" ) +// Manages a connection to a GitHub repository. type GitHubConnection struct { - Client *github.Client - Owner, Repo, Branch string + client *github.Client + owner, repo, ref string } -func NewGhConnection(owner, repo, branch string) *GitHubConnection { +func NewGhConnection(owner, repo, ref string) *GitHubConnection { + return NewGhConnectionWithClient(owner, repo, ref, github.NewClient(nil)) +} + +func NewGhConnectionWithClient(owner, repo, ref string, client *github.Client) *GitHubConnection { return &GitHubConnection{ - Client: github.NewClient(nil), - Owner: owner, - Repo: repo, - Branch: branch} + client: client, + owner: owner, + repo: repo, + ref: ref} +} + +func (ghc *GitHubConnection) Client() *github.Client { + return ghc.client +} + +func (ghc *GitHubConnection) Owner() string { + return ghc.owner +} + +func (ghc *GitHubConnection) Repo() string { + return ghc.repo +} + +func (ghc *GitHubConnection) GetFullRef() string { + return ghc.ref } // Uses the provide token for auth. // If the token is the empty string this is a no-op. func (ghc *GitHubConnection) WithAuthToken(token string) *GitHubConnection { if token != "" { - ghc.Client = ghc.Client.WithAuthToken(token) + ghc.client = ghc.client.WithAuthToken(token) } return ghc } -// Returns the fully qualified branch (e.g. 'refs/heads/main'). -func (ghc *GitHubConnection) GetFullBranch() string { - return fmt.Sprintf("refs/heads/%s", ghc.Branch) -} - // Returns the URI of the repo this connection tracks. func (ghc *GitHubConnection) GetRepoUri() string { - return fmt.Sprintf("https://github.com/%s/%s", ghc.Owner, ghc.Repo) -} - -func (ghc *GitHubConnection) GetLatestCommit(ctx context.Context) (string, error) { - branch, _, err := ghc.Client.Repositories.GetBranch(ctx, ghc.Owner, ghc.Repo, ghc.Branch, 1) - if err != nil { - return "", fmt.Errorf("could not get info on specified branch %s: %w", ghc.Branch, err) - } - return *branch.Commit.SHA, nil + return fmt.Sprintf("https://github.com/%s/%s", ghc.Owner(), ghc.Repo()) } // Gets the previous commit to 'sha' if it has one. // If there are more than one parents this fails with an error. // (This tool generally operates in an environment of linear history) func (ghc *GitHubConnection) GetPriorCommit(ctx context.Context, sha string) (string, error) { - commit, _, err := ghc.Client.Git.GetCommit(ctx, ghc.Owner, ghc.Repo, sha) + commit, _, err := ghc.Client().Git.GetCommit(ctx, ghc.Owner(), ghc.Repo(), sha) if err != nil { return "", fmt.Errorf("cannot get commit data for %s: %w", sha, err) } @@ -66,3 +74,11 @@ func (ghc *GitHubConnection) GetPriorCommit(ctx context.Context, sha string) (st return *commit.Parents[0].SHA, nil } + +func (ghc *GitHubConnection) GetLatestCommit(ctx context.Context, targetBranch string) (string, error) { + branch, _, err := ghc.Client().Repositories.GetBranch(ctx, ghc.Owner(), ghc.Repo(), targetBranch, 1) + if err != nil { + return "", fmt.Errorf("could not get info on specified branch %s: %w", targetBranch, err) + } + return *branch.Commit.SHA, nil +} diff --git a/sourcetool/pkg/gh_control/git_types.go b/sourcetool/pkg/gh_control/git_types.go new file mode 100644 index 00000000..f82561ab --- /dev/null +++ b/sourcetool/pkg/gh_control/git_types.go @@ -0,0 +1,26 @@ +package gh_control + +import ( + "fmt" + "strings" +) + +// Matches any reference type. +const AnyReference = "*" + +func BranchToFullRef(branch string) string { + return fmt.Sprintf("refs/heads/%s", branch) +} + +func TagToFullRef(tag string) string { + return fmt.Sprintf("refs/tags/%s", tag) +} + +// Returns "" if the ref isn't a branch +func GetBranchFromRef(ref string) string { + return strings.TrimPrefix(ref, "refs/heads/") +} + +func GetTagFromRef(ref string) string { + return strings.TrimPrefix(ref, "refs/tags/") +} diff --git a/sourcetool/pkg/gh_control/notes.go b/sourcetool/pkg/gh_control/notes.go index 2da46666..e5e93a32 100644 --- a/sourcetool/pkg/gh_control/notes.go +++ b/sourcetool/pkg/gh_control/notes.go @@ -12,8 +12,8 @@ func (ghc *GitHubConnection) GetNotesForCommit(ctx context.Context, commit strin // We can find the notes for a given commit fairly easily. // They'll be in the path within ref `refs/notes/commits` - contents, _, resp, err := ghc.Client.Repositories.GetContents( - ctx, ghc.Owner, ghc.Repo, commit, &github.RepositoryContentGetOptions{Ref: "refs/notes/commits"}) + contents, _, resp, err := ghc.Client().Repositories.GetContents( + ctx, ghc.Owner(), ghc.Repo(), commit, &github.RepositoryContentGetOptions{Ref: "refs/notes/commits"}) if resp.StatusCode == http.StatusNotFound { // Don't freak out if it's not there. @@ -22,6 +22,10 @@ func (ghc *GitHubConnection) GetNotesForCommit(ctx context.Context, commit strin if err != nil { return "", fmt.Errorf("cannot get note contents for commit %s: %w", commit, err) } + if contents == nil { + // No notes stored for this commit. + return "", nil + } return contents.GetContent() } diff --git a/sourcetool/pkg/policy/policy.go b/sourcetool/pkg/policy/policy.go index 0f8d33f2..044fed5f 100644 --- a/sourcetool/pkg/policy/policy.go +++ b/sourcetool/pkg/policy/policy.go @@ -32,17 +32,41 @@ type ProtectedBranch struct { Since time.Time TargetSlsaSourceLevel slsa_types.SlsaSourceLevel `json:"target_slsa_source_level"` RequireReview bool `json:"require_review"` - ImmutableTags bool `json:"immutable_tags"` +} + +// The controls required for protected tags. +type ProtectedTag struct { + Since time.Time + ImmutableTags bool `json:"immutable_tags"` } type RepoPolicy struct { - // I'm actually not sure we need this. Consider removing? + // TODO: I'm actually not sure we need this. Consider removing? CanonicalRepo string `json:"canonical_repo"` ProtectedBranches []ProtectedBranch `json:"protected_branches"` + ProtectedTag *ProtectedTag `json:"protected_tag"` +} + +// Returns the policy for the branch or nil if the branch doesn't have one. +func (rp *RepoPolicy) getBranchPolicy(branch string) *ProtectedBranch { + for _, pb := range rp.ProtectedBranches { + if pb.Name == branch { + return &pb + } + } + return nil +} + +func createDefaultBranchPolicy(branch string) *ProtectedBranch { + return &ProtectedBranch{ + Name: branch, + Since: time.Now(), + TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, + RequireReview: false} } func getPolicyPath(gh_connection *gh_control.GitHubConnection) string { - return fmt.Sprintf("policy/github.com/%s/%s/source-policy.json", gh_connection.Owner, gh_connection.Repo) + return fmt.Sprintf("policy/github.com/%s/%s/source-policy.json", gh_connection.Owner(), gh_connection.Repo()) } func getPolicyRepoPath(pathToClone string, gh_connection *gh_control.GitHubConnection) string { @@ -53,7 +77,7 @@ func getPolicyRepoPath(pathToClone string, gh_connection *gh_control.GitHubConne func getRemotePolicy(ctx context.Context, gh_connection *gh_control.GitHubConnection) (*RepoPolicy, string, error) { path := getPolicyPath(gh_connection) - policyContents, _, resp, err := gh_connection.Client.Repositories.GetContents(ctx, SourcePolicyRepoOwner, SourcePolicyRepo, path, nil) + policyContents, _, resp, err := gh_connection.Client().Repositories.GetContents(ctx, SourcePolicyRepoOwner, SourcePolicyRepo, path, nil) if resp != nil && resp.StatusCode == http.StatusNotFound { return nil, "", nil } @@ -88,35 +112,11 @@ func getLocalPolicy(path string) (*RepoPolicy, string, error) { return &p, path, nil } -func (policy Policy) getPolicy(ctx context.Context, gh_connection *gh_control.GitHubConnection) (*RepoPolicy, string, error) { - if policy.UseLocalPolicy == "" { +func (pe PolicyEvaluator) getPolicy(ctx context.Context, gh_connection *gh_control.GitHubConnection) (*RepoPolicy, string, error) { + if pe.UseLocalPolicy == "" { return getRemotePolicy(ctx, gh_connection) } - return getLocalPolicy(policy.UseLocalPolicy) -} - -// Gets the policy for the indicated branch direct from the GitHub repo. -func (policy Policy) getBranchPolicy(ctx context.Context, gh_connection *gh_control.GitHubConnection) (*ProtectedBranch, string, error) { - p, path, err := policy.getPolicy(ctx, gh_connection) - - if err != nil { - return nil, "", err - } - - if p != nil { - for _, pb := range p.ProtectedBranches { - if pb.Name == gh_connection.Branch { - return &pb, path, nil - } - } - } - - // No policy so return the default branch policy. - return &ProtectedBranch{ - Name: gh_connection.Branch, - Since: time.Now(), - TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, - RequireReview: false}, "DEFAULT", nil + return getLocalPolicy(pe.UseLocalPolicy) } // Check to see if the local directory is a clean clone or not @@ -169,14 +169,17 @@ func CreateLocalPolicy(ctx context.Context, gh_connection *gh_control.GitHubConn path := getPolicyRepoPath(pathToClone, gh_connection) // What's their latest commit (needed for checking control status) - latestCommit, err := gh_connection.GetLatestCommit(ctx) + branch := gh_control.GetBranchFromRef(gh_connection.GetFullRef()) + if branch == "" { + return "", fmt.Errorf("cannot create local policy, ref %s isn't a branch", gh_connection.GetFullRef()) + } + latestCommit, err := gh_connection.GetLatestCommit(ctx, branch) if err != nil { return "", fmt.Errorf("could not get latest commit: %w", err) } - ver_options := attest.DefaultVerifierOptions - pa := attest.NewProvenanceAttestor(gh_connection, ver_options) - _, provPred, err := pa.GetProvenance(ctx, latestCommit) + pa := attest.NewProvenanceAttestor(gh_connection, attest.GetDefaultVerifier()) + _, provPred, err := pa.GetProvenance(ctx, latestCommit, gh_connection.GetFullRef()) if err != nil { return "", fmt.Errorf("could not get provenance for latest commit: %w", err) } @@ -199,7 +202,7 @@ func CreateLocalPolicy(ctx context.Context, gh_connection *gh_control.GitHubConn CanonicalRepo: "TODO fill this in", ProtectedBranches: []ProtectedBranch{ { - Name: gh_connection.Branch, + Name: branch, Since: *eligibleSince, TargetSlsaSourceLevel: eligibleLevel, // TODO support filling in other controls too. @@ -315,8 +318,13 @@ func computeReviewEnforced(branchPolicy *ProtectedBranch, controls slsa_types.Co return true, nil } -func computeImmutableTags(branchPolicy *ProtectedBranch, controls slsa_types.Controls) (bool, error) { - if !branchPolicy.ImmutableTags { +func computeImmutableTags(tagPolicy *ProtectedTag, controls slsa_types.Controls) (bool, error) { + if tagPolicy == nil { + // There is no tag policy, so the control isn't met, but it's not an error. + return false, nil + } + + if !tagPolicy.ImmutableTags { return false, nil } @@ -325,15 +333,15 @@ func computeImmutableTags(branchPolicy *ProtectedBranch, controls slsa_types.Con return false, fmt.Errorf("policy requires immutable tags, but that control is not enabled") } - if branchPolicy.Since.Before(immutableTags.Since) { - return false, fmt.Errorf("policy requires immutable tags since %v, but that control has only been enabled since %v", branchPolicy.Since, immutableTags.Since) + if tagPolicy.Since.Before(immutableTags.Since) { + return false, fmt.Errorf("policy requires immutable tags since %v, but that control has only been enabled since %v", tagPolicy.Since, immutableTags.Since) } return true, nil } -// Returns a list of controls to include in the vsa's 'verifiedLevels' field. -func evaluateControls(branchPolicy *ProtectedBranch, controls slsa_types.Controls) (slsa_types.SourceVerifiedLevels, error) { +// Returns a list of controls to include in the vsa's 'verifiedLevels' field when creating a VSA for a branch. +func evaluateBranchControls(branchPolicy *ProtectedBranch, tagPolicy *ProtectedTag, controls slsa_types.Controls) (slsa_types.SourceVerifiedLevels, error) { slsaSourceLevel, err := computeSlsaLevel(branchPolicy, controls) if err != nil { return slsa_types.SourceVerifiedLevels{}, fmt.Errorf("error computing slsa level: %w", err) @@ -349,7 +357,7 @@ func evaluateControls(branchPolicy *ProtectedBranch, controls slsa_types.Control verifiedLevels = append(verifiedLevels, slsa_types.ReviewEnforced) } - immutableTags, err := computeImmutableTags(branchPolicy, controls) + immutableTags, err := computeImmutableTags(tagPolicy, controls) if err != nil { return slsa_types.SourceVerifiedLevels{}, fmt.Errorf("error computing tag immutability enforced: %w", err) } @@ -360,31 +368,59 @@ func evaluateControls(branchPolicy *ProtectedBranch, controls slsa_types.Control return verifiedLevels, nil } -type Policy struct { +// Returns a list of controls to include in the vsa's 'verifiedLevels' field when creating a VSA for a tag. +// Users provide a list of verifiedLevels that came from VSAs issued previously for the commit pointed to by this +// tag. +func evaluateTagProv(tagPolicy *ProtectedTag, tagProvPred *attest.TagProvenancePred) (slsa_types.SourceVerifiedLevels, error) { + // As long as all the controls for tag protection are currently in force then we'll + // include the verifiedLevels. + + // TODO: handle tag policy? + immutableTags, err := computeImmutableTags(tagPolicy, tagProvPred.Controls) + if err != nil { + return slsa_types.SourceVerifiedLevels{}, fmt.Errorf("error computing tag immutability enforced: %w", err) + } + if immutableTags { + // TODO: should we include the immutable tag field specifically? + return tagProvPred.VsaSummaries[0].VerifiedLevels, nil + } + + // If tag immutability isn't enabled then we just return level 1. + return slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel1)}, nil +} + +type PolicyEvaluator struct { // UNSAFE! // Instead of grabbing the policy from the canonical repo, use the policy at this path instead. UseLocalPolicy string } -func NewPolicy() *Policy { - return &Policy{} +func NewPolicyEvaluator() *PolicyEvaluator { + return &PolicyEvaluator{} } // Evaluates the control against the policy and returns the resulting source level and policy path. -func (policy Policy) EvaluateControl(ctx context.Context, gh_connection *gh_control.GitHubConnection, controlStatus *gh_control.GhControlStatus) (slsa_types.SourceVerifiedLevels, string, error) { +func (pe PolicyEvaluator) EvaluateControl(ctx context.Context, gh_connection *gh_control.GitHubConnection, controlStatus *gh_control.GhControlStatus) (slsa_types.SourceVerifiedLevels, string, error) { // We want to check to ensure the repo hasn't enabled/disabled the rules since // setting the 'since' field in their policy. - branchPolicy, policyPath, err := policy.getBranchPolicy(ctx, gh_connection) + rp, policyPath, err := pe.getPolicy(ctx, gh_connection) if err != nil { return slsa_types.SourceVerifiedLevels{}, "", err } + branch := gh_control.GetBranchFromRef(gh_connection.GetFullRef()) + branchPolicy := rp.getBranchPolicy(branch) + if branchPolicy == nil { + branchPolicy = createDefaultBranchPolicy(branch) + policyPath = "DEFAULT" + } + if controlStatus.CommitPushTime.Before(branchPolicy.Since) { // This commit was pushed before they had an explicit policy. return slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel1)}, policyPath, nil } - verifiedLevels, err := evaluateControls(branchPolicy, controlStatus.Controls) + verifiedLevels, err := evaluateBranchControls(branchPolicy, rp.ProtectedTag, controlStatus.Controls) if err != nil { return verifiedLevels, policyPath, fmt.Errorf("error evaluating policy %s: %w", policyPath, err) } @@ -392,18 +428,25 @@ func (policy Policy) EvaluateControl(ctx context.Context, gh_connection *gh_cont } // Evaluates the provenance against the policy and returns the resulting source level and policy path -func (policy Policy) EvaluateProv(ctx context.Context, gh_connection *gh_control.GitHubConnection, prov *spb.Statement) (slsa_types.SourceVerifiedLevels, string, error) { - branchPolicy, policyPath, err := policy.getBranchPolicy(ctx, gh_connection) +func (pe PolicyEvaluator) EvaluateSourceProv(ctx context.Context, gh_connection *gh_control.GitHubConnection, prov *spb.Statement) (slsa_types.SourceVerifiedLevels, string, error) { + rp, policyPath, err := pe.getPolicy(ctx, gh_connection) if err != nil { return slsa_types.SourceVerifiedLevels{}, "", err } - provPred, err := attest.GetProvPred(prov) + provPred, err := attest.GetSourceProvPred(prov) if err != nil { return slsa_types.SourceVerifiedLevels{}, "", err } - verifiedLevels, err := evaluateControls(branchPolicy, provPred.Controls) + branch := gh_control.GetBranchFromRef(gh_connection.GetFullRef()) + branchPolicy := rp.getBranchPolicy(branch) + if branchPolicy == nil { + branchPolicy = createDefaultBranchPolicy(branch) + policyPath = "DEFAULT" + } + + verifiedLevels, err := evaluateBranchControls(branchPolicy, rp.ProtectedTag, provPred.Controls) if err != nil { return slsa_types.SourceVerifiedLevels{}, policyPath, fmt.Errorf("error evaluating policy %s: %w", policyPath, err) } @@ -411,3 +454,25 @@ func (policy Policy) EvaluateProv(ctx context.Context, gh_connection *gh_control // Looks good! return verifiedLevels, policyPath, nil } + +// Evaluates the provenance against the policy and returns the resulting source level and policy path +func (pe PolicyEvaluator) EvaluateTagProv(ctx context.Context, gh_connection *gh_control.GitHubConnection, prov *spb.Statement) (slsa_types.SourceVerifiedLevels, string, error) { + rp, policyPath, err := pe.getPolicy(ctx, gh_connection) + if err != nil { + return slsa_types.SourceVerifiedLevels{}, "", err + } + + provPred, err := attest.GetTagProvPred(prov) + if err != nil { + return slsa_types.SourceVerifiedLevels{}, "", err + } + + // TODO: get the levels we want to use from the prov predicate... + outputVerifiedLevels, err := evaluateTagProv(rp.ProtectedTag, provPred) + if err != nil { + return slsa_types.SourceVerifiedLevels{}, policyPath, fmt.Errorf("error evaluating policy %s: %w", policyPath, err) + } + + // Looks good! + return outputVerifiedLevels, policyPath, nil +} diff --git a/sourcetool/pkg/policy/policy_test.go b/sourcetool/pkg/policy/policy_test.go index aca5ff66..480e9b23 100644 --- a/sourcetool/pkg/policy/policy_test.go +++ b/sourcetool/pkg/policy/policy_test.go @@ -24,6 +24,33 @@ import ( "github.com/slsa-framework/slsa-source-poc/sourcetool/pkg/slsa_types" ) +var fixedTime = time.Unix(1678886400, 0) // March 15, 2023 00:00:00 UTC +var earlierFixedTime = fixedTime.Add(-time.Hour) +var laterFixedTime = fixedTime.Add(time.Hour) + +func createTestBranchPolicy(branch string) ProtectedBranch { + return ProtectedBranch{ + Name: branch, + Since: fixedTime, + TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, + RequireReview: true, + } +} + +func createTestPolicy(pb ProtectedBranch) RepoPolicy { + return RepoPolicy{ + CanonicalRepo: "the-canonical-repo", + ProtectedBranches: []ProtectedBranch{ + pb, + }, + ProtectedTag: &ProtectedTag{ + Since: fixedTime, + ImmutableTags: true, + }, + } + +} + // Helper to create spb.Statement - moved here to be accessible by new test functions func createStatementForTest(t *testing.T, predicateContent interface{}, predType string) *spb.Statement { t.Helper() @@ -47,6 +74,11 @@ func createStatementForTest(t *testing.T, predicateContent interface{}, predType } } +// Helper to create a test GH Branch connection with no client. +func newTestGhBranchConnection(owner, repo, branch string) *gh_control.GitHubConnection { + return gh_control.NewGhConnectionWithClient(owner, repo, gh_control.BranchToFullRef(branch), nil) +} + // createTempPolicyFile creates a temporary file with the given policy data. // If policyData is a RepoPolicy, it's marshalled to JSON. // If policyData is a string, it's written directly. @@ -82,7 +114,7 @@ func createTempPolicyFile(t *testing.T, policyData interface{}) string { func validateMockServerRequestPath(t *testing.T, r *http.Request, expectedPolicyOwner, expectedPolicyRepo, expectedPolicyBranch string) { t.Helper() // This ghConn is only for generating the policy file path segment based on the target repo's details - tempGhConn := &gh_control.GitHubConnection{Owner: expectedPolicyOwner, Repo: expectedPolicyRepo, Branch: expectedPolicyBranch} + tempGhConn := newTestGhBranchConnection(expectedPolicyOwner, expectedPolicyRepo, expectedPolicyBranch) policyFilePathSegment := getPolicyPath(tempGhConn) // getPolicyPath is an existing function in the policy package // Construct the full expected API path suffix for the GetContents call @@ -96,15 +128,12 @@ func validateMockServerRequestPath(t *testing.T, r *http.Request, expectedPolicy } } -func TestEvaluateProv_Success(t *testing.T) { - now := time.Now() - earlier := now.Add(-time.Hour) - +func TestEvaluateSourceProv_Success(t *testing.T) { // Controls for mock provenance - continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlier} - provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlier} - reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlier} - immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlier} + continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlierFixedTime} + provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlierFixedTime} + reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlierFixedTime} + immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlierFixedTime} // Valid Provenance Predicate (attest.SourceProvenancePred) validProvPredicateL3Controls := attest.SourceProvenancePred{ @@ -113,50 +142,55 @@ func TestEvaluateProv_Success(t *testing.T) { provenanceStatement := createStatementForTest(t, validProvPredicateL3Controls, attest.SourceProvPredicateType) - expectedPolicyFilePath := createTempPolicyFile(t, RepoPolicy{ - ProtectedBranches: []ProtectedBranch{ - {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true, Since: now}, - }, - }) + pb := ProtectedBranch{ + Name: "main", + TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, + RequireReview: true, + Since: fixedTime, + } + rp := createTestPolicy(pb) + rp.ProtectedTag.Since = fixedTime + rp.ProtectedTag.ImmutableTags = true + + expectedPolicyFilePath := createTempPolicyFile(t, rp) defer os.Remove(expectedPolicyFilePath) - p := &Policy{UseLocalPolicy: expectedPolicyFilePath} + pe := &PolicyEvaluator{UseLocalPolicy: expectedPolicyFilePath} - ghConn := &gh_control.GitHubConnection{Owner: "local", Repo: "local", Branch: "main"} + ghConn := newTestGhBranchConnection("local", "local", "main") - verifiedLevels, policyPath, err := p.EvaluateProv(context.Background(), ghConn, provenanceStatement) + verifiedLevels, policyPath, err := pe.EvaluateSourceProv(context.Background(), ghConn, provenanceStatement) if err != nil { - t.Errorf("EvaluateProv() error = %v, want nil", err) + t.Errorf("EvaluateSourceProv() error = %v, want nil", err) } if policyPath != expectedPolicyFilePath { - t.Errorf("EvaluateProv() policyPath = %q, want %q", policyPath, expectedPolicyFilePath) + t.Errorf("EvaluateSourceProv() policyPath = %q, want %q", policyPath, expectedPolicyFilePath) } - expected := slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel3), slsa_types.ReviewEnforced, slsa_types.ImmutableTags} - if !reflect.DeepEqual(verifiedLevels, expected) { - t.Errorf("EvaluateProv() verifiedLevels = %v, want %v", verifiedLevels, expected) + expectedLevels := slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel3), slsa_types.ReviewEnforced, slsa_types.ImmutableTags} + if !reflect.DeepEqual(verifiedLevels, expectedLevels) { + t.Errorf("EvaluateSourceProv() verifiedLevels = %v, want %v", verifiedLevels, expectedLevels) } } -func TestEvaluateProv_Failure(t *testing.T) { - now := time.Now() - earlier := now.Add(-time.Hour) - +func TestEvaluateSourceProv_Failure(t *testing.T) { // Controls for mock provenance - continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlier} - provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlier} - reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlier} - immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlier} + continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlierFixedTime} + provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlierFixedTime} + reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlierFixedTime} + immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlierFixedTime} // Policies policyL3ReviewTagsNow := RepoPolicy{ ProtectedBranches: []ProtectedBranch{ - {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true, Since: now}, + {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, Since: fixedTime}, }, + ProtectedTag: &ProtectedTag{Since: fixedTime, ImmutableTags: true}, } policyL1NoExtrasNow := RepoPolicy{ // Policy for default/branch not found cases ProtectedBranches: []ProtectedBranch{ - {Name: "otherbranch", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, Since: now}, + {Name: "otherbranch", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, Since: fixedTime}, }, + ProtectedTag: nil, } // Valid Provenance Predicate (attest.SourceProvenancePred) @@ -186,7 +220,7 @@ func TestEvaluateProv_Failure(t *testing.T) { policyContent: "not valid policy json", provenanceStatement: createStatementForTest(t, validProvPredicateL3Controls, attest.SourceProvPredicateType), ghConnBranch: "main", - expectedErrorContains: "invalid character 'o' in literal null (expecting 'u')", // Error from getBranchPolicy via getLocalPolicy + expectedErrorContains: "invalid character 'o' in literal null (expecting 'u')", // Error from getPolicy via getLocalPolicy }, { name: "Non-existent Policy File -> Error", @@ -233,50 +267,49 @@ func TestEvaluateProv_Failure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - p := &Policy{} + pe := &PolicyEvaluator{} var ghConn *gh_control.GitHubConnection if tt.name == "Non-existent Policy File -> Error" { - p.UseLocalPolicy = "/path/to/nonexistent/test/policy.json" // Specific path for this test + pe.UseLocalPolicy = "/path/to/nonexistent/test/policy.json" // Specific path for this test } else if tt.policyContent != nil { policyFilePath := createTempPolicyFile(t, tt.policyContent) defer os.Remove(policyFilePath) - p.UseLocalPolicy = policyFilePath + pe.UseLocalPolicy = policyFilePath } - ghConn = &gh_control.GitHubConnection{Owner: "local", Repo: "local", Branch: tt.ghConnBranch} + ghConn = newTestGhBranchConnection("local", "local", tt.ghConnBranch) - _, _, err := p.EvaluateProv(ctx, ghConn, tt.provenanceStatement) + _, _, err := pe.EvaluateSourceProv(ctx, ghConn, tt.provenanceStatement) if err == nil { - t.Errorf("EvaluateProv() error = nil, want non-nil error containing %q", tt.expectedErrorContains) + t.Errorf("EvaluateSourceProv() error = nil, want non-nil error containing %q", tt.expectedErrorContains) } else if !strings.Contains(err.Error(), tt.expectedErrorContains) { - t.Errorf("EvaluateProv() error = %q, want error containing %q", err.Error(), tt.expectedErrorContains) + t.Errorf("EvaluateSourceProv() error = %q, want error containing %q", err.Error(), tt.expectedErrorContains) } - // Not checking verifiedLevels or policyPath in failure cases }) } } func TestEvaluateControl_Success(t *testing.T) { - now := time.Now() - earlier := now.Add(-time.Hour) - later := now.Add(time.Hour) - // Controls - continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlier} - provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlier} - reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlier} - immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlier} + continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlierFixedTime} + provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlierFixedTime} + reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlierFixedTime} + immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlierFixedTime} // Policies policyL3ReviewTagsNow := RepoPolicy{ ProtectedBranches: []ProtectedBranch{ - {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true, Since: now}, + {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, Since: fixedTime}, + }, + ProtectedTag: &ProtectedTag{ + Since: fixedTime, + ImmutableTags: true, }, } policyL1NoExtrasNow := RepoPolicy{ ProtectedBranches: []ProtectedBranch{ - {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, Since: now}, + {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, Since: fixedTime}, }, } @@ -292,7 +325,7 @@ func TestEvaluateControl_Success(t *testing.T) { name: "Commit time before policy Since -> SLSA Level 1", policyContent: policyL3ReviewTagsNow, controlStatus: &gh_control.GhControlStatus{ - CommitPushTime: earlier, // Commit time before policyL3ReviewTagsNow.Since (now) + CommitPushTime: earlierFixedTime, // Commit time before policyL3ReviewTagsNow.Since (now) Controls: slsa_types.Controls{continuityEnforcedEarlier, provenanceAvailableEarlier, reviewEnforcedEarlier, immutableTagsEarlier}, }, ghConnBranch: "main", @@ -303,7 +336,7 @@ func TestEvaluateControl_Success(t *testing.T) { name: "Commit time after policy Since, controls meet policy -> Expected levels", policyContent: policyL3ReviewTagsNow, controlStatus: &gh_control.GhControlStatus{ - CommitPushTime: later, // Commit time after policyL3ReviewTagsNow.Since (now) + CommitPushTime: laterFixedTime, Controls: slsa_types.Controls{continuityEnforcedEarlier, provenanceAvailableEarlier, reviewEnforcedEarlier, immutableTagsEarlier}, }, ghConnBranch: "main", @@ -314,7 +347,7 @@ func TestEvaluateControl_Success(t *testing.T) { name: "Branch not in policy, commit after default policy since -> Default policy (SLSA L1)", policyContent: policyL1NoExtrasNow, // main is in policy, but we test "develop" controlStatus: &gh_control.GhControlStatus{ - CommitPushTime: later, // After default policy's implicit Since (zero time) + CommitPushTime: laterFixedTime, Controls: slsa_types.Controls{continuityEnforcedEarlier, provenanceAvailableEarlier, reviewEnforcedEarlier, immutableTagsEarlier}, }, ghConnBranch: "develop", // Testing "develop" branch @@ -326,21 +359,21 @@ func TestEvaluateControl_Success(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - p := &Policy{} + pe := &PolicyEvaluator{} var ghConn *gh_control.GitHubConnection actualPolicyPath := tt.expectedPolicyPath // May be overridden for local temp file if tt.policyContent != nil { policyFilePath := createTempPolicyFile(t, tt.policyContent) defer os.Remove(policyFilePath) - p.UseLocalPolicy = policyFilePath + pe.UseLocalPolicy = policyFilePath if tt.expectedPolicyPath == "TEMP_POLICY_FILE_PATH" { actualPolicyPath = policyFilePath } } - ghConn = &gh_control.GitHubConnection{Owner: "local", Repo: "local", Branch: tt.ghConnBranch} + ghConn = newTestGhBranchConnection("local", "local", tt.ghConnBranch) - verifiedLevels, policyPath, err := p.EvaluateControl(ctx, ghConn, tt.controlStatus) + verifiedLevels, policyPath, err := pe.EvaluateControl(ctx, ghConn, tt.controlStatus) if err != nil { t.Errorf("EvaluateControl() error = %v, want nil", err) @@ -368,9 +401,9 @@ func TestEvaluateControl_Failure(t *testing.T) { continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlier} // Policies - policyL3ReviewTagsNow := RepoPolicy{ + policyL3Review := RepoPolicy{ ProtectedBranches: []ProtectedBranch{ - {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true, Since: now}, + {Name: "main", TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, Since: now}, }, } @@ -383,7 +416,7 @@ func TestEvaluateControl_Failure(t *testing.T) { }{ { name: "Commit time after policy Since, controls DO NOT meet policy -> Error", - policyContent: policyL3ReviewTagsNow, // Requires L3, Review, Tags + policyContent: policyL3Review, // Requires L3, Review, Tags controlStatus: &gh_control.GhControlStatus{ CommitPushTime: later, // Commit time after policy.Since Controls: slsa_types.Controls{continuityEnforcedEarlier}, // Only meets L2 @@ -406,7 +439,7 @@ func TestEvaluateControl_Failure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - p := &Policy{} + pe := &PolicyEvaluator{} var ghConn *gh_control.GitHubConnection policyFilePath := "" // Default to empty, will be set if policyContent is not nil @@ -423,12 +456,12 @@ func TestEvaluateControl_Failure(t *testing.T) { if policyFilePath != "" { // Ensure removal only if a file was created defer os.Remove(policyFilePath) } - p.UseLocalPolicy = policyFilePath + pe.UseLocalPolicy = policyFilePath } - ghConn = &gh_control.GitHubConnection{Owner: "local", Repo: "local", Branch: tt.ghConnBranch} + ghConn = newTestGhBranchConnection("local", "local", tt.ghConnBranch) - _, _, err := p.EvaluateControl(ctx, ghConn, tt.controlStatus) + _, _, err := pe.EvaluateControl(ctx, ghConn, tt.controlStatus) if err == nil { t.Errorf("EvaluateControl() error = nil, want non-nil error containing %q", tt.expectedErrorContains) @@ -457,19 +490,14 @@ func setupMockGitHubTestEnv(t *testing.T, targetOwner string, targetRepo string, } ghClient.BaseURL = baseURL - ghConn := &gh_control.GitHubConnection{ - Owner: targetOwner, - Repo: targetRepo, - Branch: targetBranch, - Client: ghClient, - } + ghConn := gh_control.NewGhConnectionWithClient(targetOwner, targetRepo, targetBranch, ghClient) return ghConn, server } // assertProtectedBranchEquals compares two ProtectedBranch structs for equality, // optionally ignoring the 'Since' field. It provides a detailed error message // if they are not equal. -func assertProtectedBranchEquals(t *testing.T, got *ProtectedBranch, expected ProtectedBranch, ignoreSince bool, customMessage string) { +func assertProtectedBranchEquals(t *testing.T, got *ProtectedBranch, expected ProtectedBranch, ignoreSince bool) { t.Helper() if got == nil { @@ -478,15 +506,7 @@ func assertProtectedBranchEquals(t *testing.T, got *ProtectedBranch, expected Pr // implying that a nil 'got' might be acceptable. However, for this helper, // we assume if 'expected' is provided, 'got' should be non-nil. if expected != (ProtectedBranch{}) { - // Note: The original Fatalf message included customMessage formatting, - // which is simplified here as customMessage is now just a string. - // Consider if this part needs more sophisticated handling if customMessage is expected to be a format string. - // For now, just appending it. - fatalMsg := fmt.Sprintf("Expected a non-nil ProtectedBranch, but got nil. Expected: %+v.", expected) - if customMessage != "" { - fatalMsg = fmt.Sprintf("%s %s", customMessage, fatalMsg) - } - t.Fatalf(fatalMsg) + t.Fatalf("Expected a non-nil ProtectedBranch, but got nil. Expected: %+v.", expected) } // If 'expected' is also a zero-value struct, then a nil 'got' is considered a match. return @@ -513,10 +533,6 @@ func assertProtectedBranchEquals(t *testing.T, got *ProtectedBranch, expected Pr if !reflect.DeepEqual(actualCopy, expectedCopy) || !sinceMatch { var errorMessage strings.Builder - if customMessage != "" { - errorMessage.WriteString(customMessage) - errorMessage.WriteString("\n") - } errorMessage.WriteString(fmt.Sprintf("ProtectedBranch structs not equal:\nExpected: %+v\nGot: %+v", expected, actual)) if !sinceMatch { errorMessage.WriteString(fmt.Sprintf("\nSpecifically, 'Since' fields were not equal (Expected.Since: %v, Got.Since: %v)", expected.Since, actual.Since)) @@ -535,7 +551,6 @@ const ( ) func TestComputeEligibleSlsaLevel(t *testing.T) { - fixedTime := time.Now() continuityEnforcedControl := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: fixedTime} provenanceAvailableControl := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: fixedTime} @@ -584,36 +599,31 @@ func TestComputeEligibleSlsaLevel(t *testing.T) { } } -func TestEvaluateControls(t *testing.T) { - now := time.Now() - earlier := now.Add(-time.Hour) - // later := now.Add(time.Hour) // Unused - +func TestEvaluateBranchControls(t *testing.T) { // Controls - continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlier} - provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlier} - reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlier} - immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlier} - - // continuityEnforcedNow := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: now} // Unused - // provenanceAvailableNow := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: now} // Unused - // reviewEnforcedNow := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: now} // Unused - immutableTagsNow := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: now} + continuityEnforcedEarlier := slsa_types.Control{Name: slsa_types.ContinuityEnforced, Since: earlierFixedTime} + provenanceAvailableEarlier := slsa_types.Control{Name: slsa_types.ProvenanceAvailable, Since: earlierFixedTime} + reviewEnforcedEarlier := slsa_types.Control{Name: slsa_types.ReviewEnforced, Since: earlierFixedTime} + immutableTagsEarlier := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: earlierFixedTime} + immutableTagsNow := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: fixedTime} // Branch Policies - // Policy Since 'now' for most cases, implying controls should be active by 'now' or 'earlier'. - policyL3ReviewTagsNow := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true, Since: now} - policyL1NoExtrasNow := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, RequireReview: false, ImmutableTags: false, Since: now} - policyL2ReviewNoTagsNow := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: true, ImmutableTags: false, Since: now} - policyL2NoReviewTagsNow := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: false, ImmutableTags: true, Since: now} - policyL3ReviewNoTagsNow := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: false, Since: now} + policyL3Review := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, Since: fixedTime} + policyL1NoExtras := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, RequireReview: false, Since: fixedTime} + policyL2Review := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: true, Since: fixedTime} + policyL2NoReview := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: false, Since: fixedTime} + + // Tag policies + immutableTagPolicy := ProtectedTag{Since: fixedTime, ImmutableTags: true} + noImmutableTagPolicy := ProtectedTag{Since: fixedTime, ImmutableTags: false} // Policy Since 'earlier' for testing control.Since > policy.Since - policyL2TagsEarlier := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: false, ImmutableTags: true, Since: earlier} + policyL2TagsEarlier := ProtectedBranch{TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: false, Since: earlierFixedTime} tests := []struct { name string branchPolicy *ProtectedBranch + tagPolicy *ProtectedTag controls slsa_types.Controls expectedLevels slsa_types.SourceVerifiedLevels expectError bool @@ -621,43 +631,49 @@ func TestEvaluateControls(t *testing.T) { }{ { name: "Success - All Met (L3, Review, Tags)", - branchPolicy: &policyL3ReviewTagsNow, + branchPolicy: &policyL3Review, + tagPolicy: &immutableTagPolicy, controls: slsa_types.Controls{continuityEnforcedEarlier, provenanceAvailableEarlier, reviewEnforcedEarlier, immutableTagsEarlier}, expectedLevels: slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel3), slsa_types.ReviewEnforced, slsa_types.ImmutableTags}, expectError: false, }, { name: "Success - Only SLSA Level (L1)", - branchPolicy: &policyL1NoExtrasNow, + branchPolicy: &policyL1NoExtras, + tagPolicy: &noImmutableTagPolicy, controls: slsa_types.Controls{}, // L1 is met by default if policy targets L1 and other conditions pass expectedLevels: slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel1)}, expectError: false, }, { name: "Success - SLSA & Review (L2, Review)", - branchPolicy: &policyL2ReviewNoTagsNow, + branchPolicy: &policyL2Review, + tagPolicy: &noImmutableTagPolicy, controls: slsa_types.Controls{continuityEnforcedEarlier, reviewEnforcedEarlier}, // Provenance not needed for L2 expectedLevels: slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel2), slsa_types.ReviewEnforced}, expectError: false, }, { name: "Success - SLSA & Tags (L2, Tags)", - branchPolicy: &policyL2NoReviewTagsNow, + branchPolicy: &policyL2NoReview, + tagPolicy: &immutableTagPolicy, controls: slsa_types.Controls{continuityEnforcedEarlier, immutableTagsEarlier}, // Provenance not needed for L2 expectedLevels: slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel2), slsa_types.ImmutableTags}, expectError: false, }, { name: "Error - computeSlsaLevel Fails (Policy L3, Controls L1)", - branchPolicy: &policyL3ReviewTagsNow, // Wants L3 - controls: slsa_types.Controls{}, // Only eligible for L1 + branchPolicy: &policyL3Review, // Wants L3 + tagPolicy: &noImmutableTagPolicy, + controls: slsa_types.Controls{}, // Only eligible for L1 expectedLevels: slsa_types.SourceVerifiedLevels{}, expectError: true, expectedErrorContains: "error computing slsa level: policy sets target level SLSA_SOURCE_LEVEL_3, but branch is only eligible for SLSA_SOURCE_LEVEL_1", }, { name: "Error - computeReviewEnforced Fails (Policy L2+Review, Review control missing)", - branchPolicy: &policyL2ReviewNoTagsNow, // Wants L2 & Review + branchPolicy: &policyL2Review, // Wants L2 & Review + tagPolicy: &noImmutableTagPolicy, controls: slsa_types.Controls{continuityEnforcedEarlier}, // Eligible for L2, but Review control missing expectedLevels: slsa_types.SourceVerifiedLevels{}, expectError: true, @@ -665,15 +681,18 @@ func TestEvaluateControls(t *testing.T) { }, { name: "Error - computeImmutableTags Fails (Policy L2+Tags, Tag control Since later than Policy Since)", - branchPolicy: &policyL2TagsEarlier, // Wants L2 & Tags, Policy.Since = earlier + branchPolicy: &policyL2TagsEarlier, // Wants L2 & Tags, Policy.Since = earlier + tagPolicy: &ProtectedTag{Since: earlierFixedTime, ImmutableTags: true}, controls: slsa_types.Controls{continuityEnforcedEarlier, immutableTagsNow}, // Eligible L2, Tag.Since = now expectedLevels: slsa_types.SourceVerifiedLevels{}, expectError: true, expectedErrorContains: "error computing tag immutability enforced: policy requires immutable tags since", // ... but that control has only been enabled since ... }, { - name: "Success - Mixed Requirements (L3, Review, No Tags)", - branchPolicy: &policyL3ReviewNoTagsNow, // Wants L3, Review, No Tags + name: "Success - Mixed Requirements (L3, Review, No Tags)", + branchPolicy: &policyL3Review, + tagPolicy: &noImmutableTagPolicy, + // Wants L3, Review, No Tags controls: slsa_types.Controls{continuityEnforcedEarlier, provenanceAvailableEarlier, reviewEnforcedEarlier}, // Satisfies L3 & Review expectedLevels: slsa_types.SourceVerifiedLevels{string(slsa_types.SlsaSourceLevel3), slsa_types.ReviewEnforced}, expectError: false, @@ -682,17 +701,17 @@ func TestEvaluateControls(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotLevels, err := evaluateControls(tt.branchPolicy, tt.controls) + gotLevels, err := evaluateBranchControls(tt.branchPolicy, tt.tagPolicy, tt.controls) if tt.expectError { if err == nil { - t.Errorf("evaluateControls() error = nil, want non-nil error containing %q", tt.expectedErrorContains) + t.Errorf("evaluateBranchControls() error = nil, want non-nil error containing %q", tt.expectedErrorContains) } else if !strings.Contains(err.Error(), tt.expectedErrorContains) { - t.Errorf("evaluateControls() error = %q, want error containing %q", err.Error(), tt.expectedErrorContains) + t.Errorf("evaluateBranchControls() error = %q, want error containing %q", err.Error(), tt.expectedErrorContains) } } else { if err != nil { - t.Errorf("evaluateControls() error = %v, want nil", err) + t.Errorf("evaluateBranchControls() error = %v, want nil", err) } } @@ -714,7 +733,7 @@ func TestEvaluateControls(t *testing.T) { } else if len(gotLevels) == 0 && tt.expectedLevels == nil { // similar to above, if got is empty and expected is nil } else { - t.Errorf("evaluateControls() gotLevels = %v, want %v", gotLevels, tt.expectedLevels) + t.Errorf("evaluateBranchControls() gotLevels = %v, want %v", gotLevels, tt.expectedLevels) } } }) @@ -726,9 +745,9 @@ func TestComputeImmutableTags(t *testing.T) { earlier := now.Add(-time.Hour) // Branch Policies - policyRequiresImmutableTagsNow := ProtectedBranch{ImmutableTags: true, Since: now} - policyRequiresImmutableTagsEarlier := ProtectedBranch{ImmutableTags: true, Since: earlier} - policyNotRequiresImmutableTags := ProtectedBranch{ImmutableTags: false, Since: now} + policyRequiresImmutableTagsNow := ProtectedTag{ImmutableTags: true, Since: now} + policyRequiresImmutableTagsEarlier := ProtectedTag{ImmutableTags: true, Since: earlier} + policyNotRequiresImmutableTags := ProtectedTag{ImmutableTags: false, Since: now} // Controls immutableTagsControlEnabledNow := slsa_types.Control{Name: slsa_types.ImmutableTags, Since: now} @@ -736,7 +755,7 @@ func TestComputeImmutableTags(t *testing.T) { tests := []struct { name string - branchPolicy *ProtectedBranch + tagPolicy *ProtectedTag controls slsa_types.Controls expectedImmutableEnforced bool expectError bool @@ -744,21 +763,21 @@ func TestComputeImmutableTags(t *testing.T) { }{ { name: "Policy requires immutable tags, control compliant (Policy.Since >= Control.Since)", - branchPolicy: &policyRequiresImmutableTagsNow, + tagPolicy: &policyRequiresImmutableTagsNow, controls: slsa_types.Controls{immutableTagsControlEnabledNow}, // Policy.Since == Control.Since expectedImmutableEnforced: true, expectError: false, }, { name: "Policy does not require immutable tags - control state irrelevant", - branchPolicy: &policyNotRequiresImmutableTags, + tagPolicy: &policyNotRequiresImmutableTags, controls: slsa_types.Controls{}, // Control state explicitly shown as irrelevant expectedImmutableEnforced: false, expectError: false, }, { name: "Policy requires immutable tags, control not present: fail", - branchPolicy: &policyRequiresImmutableTagsNow, + tagPolicy: &policyRequiresImmutableTagsNow, controls: slsa_types.Controls{}, // Immutable tags control missing expectedImmutableEnforced: false, expectError: true, @@ -766,7 +785,7 @@ func TestComputeImmutableTags(t *testing.T) { }, { name: "Policy requires immutable tags, control enabled, Policy.Since < Control.Since: fail", - branchPolicy: &policyRequiresImmutableTagsEarlier, // Policy.Since is 'earlier' + tagPolicy: &policyRequiresImmutableTagsEarlier, // Policy.Since is 'earlier' controls: slsa_types.Controls{immutableTagsControlEnabledNow}, // Control.Since is 'now' expectedImmutableEnforced: false, expectError: true, @@ -776,7 +795,7 @@ func TestComputeImmutableTags(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotEnforced, err := computeImmutableTags(tt.branchPolicy, tt.controls) + gotEnforced, err := computeImmutableTags(tt.tagPolicy, tt.controls) if tt.expectError { if err == nil { @@ -1099,67 +1118,56 @@ func TestComputeEligibleSince(t *testing.T) { } } -func TestGetBranchPolicy_Local_SpecificFound(t *testing.T) { - fixedTime := time.Unix(1678886400, 0) // March 15, 2023 00:00:00 UTC +func assertPolicyResultEquals(t *testing.T, ctx context.Context, ghConn *gh_control.GitHubConnection, pe *PolicyEvaluator, expectedPolicy *RepoPolicy, expectedBranchPolicy *ProtectedBranch, expectedPath string) { + rp, gotPath, err := pe.getPolicy(ctx, ghConn) - tests := []struct { - name string - branchName string - policyToCreate RepoPolicy - expectedBranch ProtectedBranch - }{ - { - name: "local policy exists with target branch", - branchName: "feature", - policyToCreate: RepoPolicy{ - ProtectedBranches: []ProtectedBranch{ - {Name: "feature", Since: fixedTime, TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, RequireReview: true, ImmutableTags: true}, - {Name: "main", Since: fixedTime, TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1}, // Another branch to ensure correct one is picked - }, - }, - expectedBranch: ProtectedBranch{ - Name: "feature", - Since: fixedTime, - TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel2, - RequireReview: true, - ImmutableTags: true, - }, - }, + if err != nil { + t.Fatalf("getPolicy() error = %v, want nil", err) + } + if gotPath != expectedPath { + t.Errorf("getPolicy() gotPath = %q, want %q (temp file path)", gotPath, expectedPath) + } + if expectedPolicy == nil { + if rp != nil { + t.Fatalf("getPolicy() expectedPolicy == nil but got non-nil policy %+v", rp) + } + return // quite while we're ahead } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - ghConn := &gh_control.GitHubConnection{Owner: "any", Repo: "any", Branch: tt.branchName} - p := &Policy{} + if rp == nil { + t.Fatalf("getPolicy() rp is nil but expectedPolicy is not") + } - policyFilePath := createTempPolicyFile(t, tt.policyToCreate) - defer os.Remove(policyFilePath) - p.UseLocalPolicy = policyFilePath + // TODO: check the rest of the contents of expectedPolicy? - gotBranch, gotPath, err := p.getBranchPolicy(ctx, ghConn) + gotPb := rp.getBranchPolicy(gh_control.GetBranchFromRef(ghConn.GetFullRef())) - if err != nil { - t.Fatalf("getBranchPolicy() error = %v, want nil", err) - } - if gotPath != policyFilePath { - t.Errorf("getBranchPolicy() gotPath = %q, want %q (temp file path)", gotPath, policyFilePath) - } - if gotBranch == nil { - // This check is important because tt.expectedBranch is non-zero in this test. - // assertProtectedBranchEquals would also fatalf, but this gives a slightly more direct message. - t.Fatalf("getBranchPolicy() gotBranch is nil, expected non-nil: %+v for test case %s", tt.expectedBranch, tt.name) - } - - message := fmt.Sprintf("Mismatch in TestGetBranchPolicy_Local_SpecificFound for test case '%s', branch '%s'", tt.name, tt.branchName) - assertProtectedBranchEquals(t, gotBranch, tt.expectedBranch, false, message) - }) + if expectedBranchPolicy == nil { + if gotPb != nil { + t.Fatalf("getPolicy() expectedBranchPolicy == nil but got non-nil branch policy %+v", rp) + } + return } + + assertProtectedBranchEquals(t, gotPb, *expectedBranchPolicy, false) } -func TestGetBranchPolicy_Local_DefaultCases(t *testing.T) { - fixedTime := time.Unix(1678886400, 0) // March 15, 2023 00:00:00 UTC +func TestGetPolicy_Local_SpecificFound(t *testing.T) { + pb := createTestBranchPolicy("feature") + policyToCreate := createTestPolicy(pb) + + ctx := context.Background() + ghConn := newTestGhBranchConnection("any", "any", "feature") + pe := &PolicyEvaluator{} + + policyFilePath := createTempPolicyFile(t, policyToCreate) + defer os.Remove(policyFilePath) + pe.UseLocalPolicy = policyFilePath + + assertPolicyResultEquals(t, ctx, ghConn, pe, &policyToCreate, &pb, policyFilePath) +} +func TestGetPolicy_Local_NotFoundCases(t *testing.T) { tests := []struct { name string branchName string @@ -1189,39 +1197,19 @@ func TestGetBranchPolicy_Local_DefaultCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - ghConn := &gh_control.GitHubConnection{Owner: "any", Repo: "any", Branch: tt.branchName} - p := &Policy{} + ghConn := newTestGhBranchConnection("any", "any", tt.branchName) + pe := &PolicyEvaluator{} policyFilePath := createTempPolicyFile(t, tt.policyToCreate) defer os.Remove(policyFilePath) - p.UseLocalPolicy = policyFilePath - - gotBranch, gotPath, err := p.getBranchPolicy(ctx, ghConn) + pe.UseLocalPolicy = policyFilePath - if err != nil { - t.Fatalf("getBranchPolicy() error = %v, want nil", err) - } - if gotPath != "DEFAULT" { - t.Errorf("getBranchPolicy() gotPath = %q, want 'DEFAULT'", gotPath) - } - if gotBranch == nil { - t.Fatalf("getBranchPolicy() gotBranch is nil, want default policy for branch %q", ghConn.Branch) - } - - expectedDefaultBranch := ProtectedBranch{ - Name: ghConn.Branch, // ghConn.Branch is populated from tt.branchName - TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, - RequireReview: false, - ImmutableTags: false, - // Since is implicitly its zero value (time.Time{}), and will be ignored by the helper - } - message := fmt.Sprintf("Mismatch in TestGetBranchPolicy_Local_DefaultCases for test case '%s', branch '%s'", tt.name, tt.branchName) - assertProtectedBranchEquals(t, gotBranch, expectedDefaultBranch, true, message) + assertPolicyResultEquals(t, ctx, ghConn, pe, &tt.policyToCreate, nil, policyFilePath) }) } } -func TestGetBranchPolicy_Local_ErrorCases(t *testing.T) { +func TestGetPolicy_Local_ErrorCases(t *testing.T) { tests := []struct { name string branchName string @@ -1245,8 +1233,8 @@ func TestGetBranchPolicy_Local_ErrorCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - ghConn := &gh_control.GitHubConnection{Owner: "any", Repo: "any", Branch: tt.branchName} - p := &Policy{} + ghConn := newTestGhBranchConnection("any", "any", tt.branchName) + pe := &PolicyEvaluator{} var policyFilePath string if tt.useLocalPolicyPath == "CREATE_TEMP" { @@ -1255,125 +1243,78 @@ func TestGetBranchPolicy_Local_ErrorCases(t *testing.T) { } policyFilePath = createTempPolicyFile(t, tt.policyFileContent) defer os.Remove(policyFilePath) // Ensure cleanup even if test expects error - p.UseLocalPolicy = policyFilePath + pe.UseLocalPolicy = policyFilePath } else { - p.UseLocalPolicy = tt.useLocalPolicyPath // For non-existent file + pe.UseLocalPolicy = tt.useLocalPolicyPath // For non-existent file } - gotBranch, gotPath, err := p.getBranchPolicy(ctx, ghConn) + gotRp, gotPath, err := pe.getPolicy(ctx, ghConn) if err == nil { - t.Errorf("getBranchPolicy() error = nil, want non-nil error") + t.Errorf("getPolicy() error = nil, want non-nil error") } - if gotBranch != nil { - t.Errorf("getBranchPolicy() gotBranch = %v, want nil", gotBranch) + if gotRp != nil { + t.Errorf("getPolicy() gotRp = %v, want nil", gotRp) } if gotPath != "" { - t.Errorf("getBranchPolicy() gotPath = %q, want \"\"", gotPath) + t.Errorf("getPolicy() gotPath = %q, want \"\"", gotPath) } }) } } -func TestGetBranchPolicy_Remote_SpecificFound(t *testing.T) { - fixedTime := time.Unix(1678886400, 0) // March 15, 2023 00:00:00 UTC - mockHTMLURL := "https://github.example.com/policy.json" - - tests := []struct { - name string - targetOwner string - targetRepo string - targetBranch string - mockPolicyContent RepoPolicy - expectedBranch ProtectedBranch - // expectedPath is always mockHTMLURL for this test function - }{ - { - name: "remote policy fetch success, branch found", - targetOwner: "test-owner", - targetRepo: "test-repo", - targetBranch: "main", - mockPolicyContent: RepoPolicy{ - ProtectedBranches: []ProtectedBranch{ - {Name: "main", Since: fixedTime, TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, RequireReview: true, ImmutableTags: true}, - {Name: "other", Since: fixedTime, TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1}, // Ensure correct branch is picked - }, - }, - expectedBranch: ProtectedBranch{ - Name: "main", - Since: fixedTime, - TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3, - RequireReview: true, - ImmutableTags: true, - }, - }, - } +func TestGetPolicy_Remote_SpecificFound(t *testing.T) { + mockPolicyPath := "https://github.example.com/policy.json" + targetOwner := "owner" + targetBranch := "feature" + targetRepo := "repo" + pb := createTestBranchPolicy(targetBranch) + expectedPolicy := createTestPolicy(pb) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - p := &Policy{UseLocalPolicy: ""} - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - validateMockServerRequestPath(t, r, tt.targetOwner, tt.targetRepo, tt.targetBranch) - w.WriteHeader(http.StatusOK) // Always OK for this test function - policyJSON, err := json.Marshal(tt.mockPolicyContent) - if err != nil { - t.Fatalf("Failed to marshal RepoPolicy for mock: %v", err) - } - encodedContent := base64.StdEncoding.EncodeToString(policyJSON) - mockFileContent := &github.RepositoryContent{ - Type: github.String("file"), - Encoding: github.String("base64"), - Content: github.String(encodedContent), - HTMLURL: github.String(mockHTMLURL), - } - respData, err := json.Marshal(mockFileContent) - if err != nil { - t.Fatalf("Failed to marshal mock RepositoryContent: %v", err) - } - _, _ = w.Write(respData) - }) - - ghConn, mockServer := setupMockGitHubTestEnv(t, tt.targetOwner, tt.targetRepo, tt.targetBranch, handler) - defer mockServer.Close() + ctx := context.Background() + pe := &PolicyEvaluator{UseLocalPolicy: ""} - gotBranch, gotPath, err := p.getBranchPolicy(ctx, ghConn) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + validateMockServerRequestPath(t, r, targetOwner, targetRepo, targetBranch) + w.WriteHeader(http.StatusOK) // Always OK for this test function + policyJSON, err := json.Marshal(expectedPolicy) + if err != nil { + t.Fatalf("Failed to marshal RepoPolicy for mock: %v", err) + } + encodedContent := base64.StdEncoding.EncodeToString(policyJSON) + mockFileContent := &github.RepositoryContent{ + Type: github.String("file"), + Encoding: github.String("base64"), + Content: github.String(encodedContent), + HTMLURL: github.String(mockPolicyPath), + } + respData, err := json.Marshal(mockFileContent) + if err != nil { + t.Fatalf("Failed to marshal mock RepositoryContent: %v", err) + } + _, _ = w.Write(respData) + }) - if err != nil { - t.Fatalf("getBranchPolicy() error = %v, want nil", err) - } - if gotPath != mockHTMLURL { - t.Errorf("getBranchPolicy() gotPath = %q, want %q", gotPath, mockHTMLURL) - } - if gotBranch == nil { - t.Fatalf("getBranchPolicy() gotBranch is nil, expected non-nil: %+v for test case %s", tt.expectedBranch, tt.name) - } + ghConn, mockServer := setupMockGitHubTestEnv(t, targetOwner, targetRepo, targetBranch, handler) + defer mockServer.Close() - message := fmt.Sprintf("Mismatch in TestGetBranchPolicy_Remote_SpecificFound for test case '%s', branch '%s'", tt.name, tt.targetBranch) - assertProtectedBranchEquals(t, gotBranch, tt.expectedBranch, false, message) - }) - } + assertPolicyResultEquals(t, ctx, ghConn, pe, &expectedPolicy, &pb, mockPolicyPath) } -func TestGetBranchPolicy_Remote_DefaultCases(t *testing.T) { - fixedTime := time.Unix(1678886400, 0) // March 15, 2023 00:00:00 UTC - mockHTMLURL := "https://github.example.com/policy.json" +func TestGetPolicy_Remote_NotFoundCases(t *testing.T) { + mockPolicyPath := "https://github.example.com/policy.json" + targetOwner := "test-owner" + targetRepo := "test-repo" tests := []struct { - name string - targetOwner string - targetRepo string - targetBranch string - mockHTTPStatus int - mockPolicyContent *RepoPolicy // Pointer to allow nil for 404 case - expectedPath string - // Default policy details are asserted directly in the test + name string + targetBranch string + mockHTTPStatus int + mockPolicyContent *RepoPolicy // Pointer to allow nil for 404 case + expectedPolicyPath string }{ { name: "remote policy fetch success, branch not found", - targetOwner: "test-owner", - targetRepo: "test-repo", targetBranch: "develop", mockHTTPStatus: http.StatusOK, mockPolicyContent: &RepoPolicy{ @@ -1381,44 +1322,38 @@ func TestGetBranchPolicy_Remote_DefaultCases(t *testing.T) { {Name: "main", Since: fixedTime, TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel3}, }, }, - expectedPath: "DEFAULT", // Changed from mockHTMLURL + expectedPolicyPath: mockPolicyPath, }, { - name: "remote policy fetch success, empty protected branches", - targetOwner: "test-owner", - targetRepo: "test-repo", - targetBranch: "main", - mockHTTPStatus: http.StatusOK, - mockPolicyContent: &RepoPolicy{ProtectedBranches: []ProtectedBranch{}}, - expectedPath: "DEFAULT", // Changed from mockHTMLURL + name: "remote policy fetch success, empty protected branches", + targetBranch: "main", + mockHTTPStatus: http.StatusOK, + mockPolicyContent: &RepoPolicy{ProtectedBranches: []ProtectedBranch{}}, + expectedPolicyPath: mockPolicyPath, }, { - name: "remote policy fetch success, nil protected branches", - targetOwner: "test-owner", - targetRepo: "test-repo", - targetBranch: "main", - mockHTTPStatus: http.StatusOK, - mockPolicyContent: &RepoPolicy{ProtectedBranches: nil}, - expectedPath: "DEFAULT", // Changed from mockHTMLURL + name: "remote policy fetch success, nil protected branches", + targetBranch: "main", + mockHTTPStatus: http.StatusOK, + mockPolicyContent: &RepoPolicy{ProtectedBranches: nil}, + expectedPolicyPath: mockPolicyPath, }, { - name: "remote policy API returns 404 Not Found", - targetOwner: "test-owner", - targetRepo: "test-repo", - targetBranch: "main", - mockHTTPStatus: http.StatusNotFound, - mockPolicyContent: nil, // No policy content for 404 - expectedPath: "DEFAULT", + name: "remote policy API returns 404 Not Found", + targetBranch: "main", + mockHTTPStatus: http.StatusNotFound, + mockPolicyContent: nil, // No policy content for 404 + expectedPolicyPath: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - p := &Policy{UseLocalPolicy: ""} + pe := &PolicyEvaluator{UseLocalPolicy: ""} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - validateMockServerRequestPath(t, r, tt.targetOwner, tt.targetRepo, tt.targetBranch) + validateMockServerRequestPath(t, r, targetOwner, targetRepo, tt.targetBranch) w.WriteHeader(tt.mockHTTPStatus) if tt.mockHTTPStatus == http.StatusOK && tt.mockPolicyContent != nil { policyJSON, err := json.Marshal(*tt.mockPolicyContent) @@ -1427,10 +1362,10 @@ func TestGetBranchPolicy_Remote_DefaultCases(t *testing.T) { } encodedContent := base64.StdEncoding.EncodeToString(policyJSON) mockFileContent := &github.RepositoryContent{ - Type: github.String("file"), - Encoding: github.String("base64"), - Content: github.String(encodedContent), - HTMLURL: github.String(mockHTMLURL), + Type: github.Ptr("file"), + Encoding: github.Ptr("base64"), + Content: github.Ptr(encodedContent), + HTMLURL: github.Ptr(mockPolicyPath), } respData, err := json.Marshal(mockFileContent) if err != nil { @@ -1440,35 +1375,15 @@ func TestGetBranchPolicy_Remote_DefaultCases(t *testing.T) { } }) - ghConn, mockServer := setupMockGitHubTestEnv(t, tt.targetOwner, tt.targetRepo, tt.targetBranch, handler) + ghConn, mockServer := setupMockGitHubTestEnv(t, targetOwner, targetRepo, tt.targetBranch, handler) defer mockServer.Close() - gotBranch, gotPath, err := p.getBranchPolicy(ctx, ghConn) - - if err != nil { - t.Fatalf("getBranchPolicy() error = %v, want nil", err) - } - if gotPath != tt.expectedPath { - t.Errorf("getBranchPolicy() gotPath = %q, want %q", gotPath, tt.expectedPath) - } - if gotBranch == nil { - t.Fatalf("getBranchPolicy() gotBranch is nil, want default policy for branch %q", ghConn.Branch) - } - - expectedDefaultBranch := ProtectedBranch{ - Name: ghConn.Branch, // ghConn.Branch is populated from tt.targetBranch - TargetSlsaSourceLevel: slsa_types.SlsaSourceLevel1, - RequireReview: false, - ImmutableTags: false, - // Since is implicitly its zero value (time.Time{}), and will be ignored by the helper - } - message := fmt.Sprintf("Mismatch in TestGetBranchPolicy_Remote_DefaultCases for test case '%s', branch '%s'", tt.name, tt.targetBranch) - assertProtectedBranchEquals(t, gotBranch, expectedDefaultBranch, true, message) + assertPolicyResultEquals(t, ctx, ghConn, pe, tt.mockPolicyContent, nil, tt.expectedPolicyPath) }) } } -func TestGetBranchPolicy_Remote_ServerError(t *testing.T) { +func TestGetPolicy_Remote_ServerError(t *testing.T) { ctx := context.Background() targetOwner := "test-owner" targetRepo := "test-repo" @@ -1482,22 +1397,22 @@ func TestGetBranchPolicy_Remote_ServerError(t *testing.T) { ghConn, mockServer := setupMockGitHubTestEnv(t, targetOwner, targetRepo, targetBranch, handler) defer mockServer.Close() - pol := Policy{UseLocalPolicy: ""} + pe := PolicyEvaluator{UseLocalPolicy: ""} // ghConn is now returned by setupMockGitHubTestEnv - branch, path, err := pol.getBranchPolicy(ctx, ghConn) + gotPolicy, gotPath, err := pe.getPolicy(ctx, ghConn) if err == nil { t.Errorf("Expected an error for server-side issues, got nil") } - if branch != nil { - t.Errorf("Expected branch to be nil on server error, got %v", branch) + if gotPolicy != nil { + t.Errorf("Expected policy to be nil on server error, got %v", gotPolicy) } - if path != "" { - t.Errorf("Expected path to be empty on server error, got %q", path) + if gotPath != "" { + t.Errorf("Expected path to be empty on server error, got %q", gotPath) } } -func TestGetBranchPolicy_Remote_MalformedJSON(t *testing.T) { +func TestGetPolicy_Remote_MalformedJSON(t *testing.T) { mockHTMLURL := "https://github.example.com/policy.json" // Still needed for one case tests := []struct { name string @@ -1544,18 +1459,17 @@ func TestGetBranchPolicy_Remote_MalformedJSON(t *testing.T) { ghConn, mockServer := setupMockGitHubTestEnv(t, targetOwner, targetRepo, targetBranch, handler) defer mockServer.Close() - pol := Policy{UseLocalPolicy: ""} - // ghConn is now returned by setupMockGitHubTestEnv + pe := PolicyEvaluator{UseLocalPolicy: ""} - branch, path, err := pol.getBranchPolicy(ctx, ghConn) + gotPolicy, gotPath, err := pe.getPolicy(ctx, ghConn) if err == nil { t.Errorf("Expected an error for malformed JSON, got nil") } - if branch != nil { - t.Errorf("Expected branch to be nil on malformed JSON, got %v", branch) + if gotPolicy != nil { + t.Errorf("Expected policy to be nil on malformed JSON, got %v", gotPolicy) } - if path != "" { // Path should be empty as we error out before using HTMLURL - t.Errorf("Expected path to be empty on malformed JSON, got %q", path) + if gotPath != "" { // Path should be empty as we error out before using HTMLURL + t.Errorf("Expected path to be empty on malformed JSON, got %q", gotPath) } }) } diff --git a/sourcetool/pkg/testsupport/mockverify.go b/sourcetool/pkg/testsupport/mockverify.go new file mode 100644 index 00000000..3a1bbfa7 --- /dev/null +++ b/sourcetool/pkg/testsupport/mockverify.go @@ -0,0 +1,29 @@ +package testsupport + +import ( + "fmt" + + spb "github.com/in-toto/attestation/go/v1" + "github.com/sigstore/sigstore-go/pkg/verify" + "google.golang.org/protobuf/encoding/protojson" +) + +type MockVerifier struct { +} + +func NewMockVerifier() *MockVerifier { + return &MockVerifier{} +} + +func (mv *MockVerifier) Verify(data string) (*verify.VerificationResult, error) { + var statement spb.Statement + err := protojson.Unmarshal([]byte(data), &statement) + if err != nil { + return nil, fmt.Errorf("error unmarshaling %s into statement", data) + } + + var vr verify.VerificationResult + vr.MediaType = "mockverifiermediatype" + vr.Statement = &statement + return &vr, nil +}