Skip to content

Commit 0a9571c

Browse files
committed
Make sure branch exists
1 parent 8b04cfc commit 0a9571c

File tree

4 files changed

+461
-36
lines changed

4 files changed

+461
-36
lines changed

tools/flakeguard/cmd/make_pr.go

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/go-git/go-git/v5"
13+
"github.com/go-git/go-git/v5/config"
1314
"github.com/go-git/go-git/v5/plumbing"
1415
"github.com/google/go-github/v72/github"
1516
"github.com/spf13/cobra"
@@ -80,13 +81,6 @@ func makePR(cmd *cobra.Command, args []string) error {
8081
return fmt.Errorf("failed to checkout default branch %s: %w", defaultBranch, err)
8182
}
8283

83-
fmt.Printf("Fetching latest changes from default branch '%s', tap your yubikey if it's blinking...", defaultBranch)
84-
err = repo.Fetch(&git.FetchOptions{})
85-
if err != nil && err != git.NoErrAlreadyUpToDate {
86-
return fmt.Errorf("failed to fetch latest: %w", err)
87-
}
88-
fmt.Println(" ✅")
89-
9084
fmt.Printf("Pulling latest changes from default branch '%s', tap your yubikey if it's blinking...", defaultBranch)
9185
err = targetRepoWorktree.Pull(&git.PullOptions{})
9286
if err != nil && err != git.NoErrAlreadyUpToDate {
@@ -104,28 +98,18 @@ func makePR(cmd *cobra.Command, args []string) error {
10498
return fmt.Errorf("failed to checkout new branch: %w", err)
10599
}
106100

107-
cleanUpBranch := true
108-
defer func() {
109-
if cleanUpBranch {
110-
fmt.Printf("Cleaning up branch %s...", branchName)
111-
// First checkout default branch
112-
err = targetRepoWorktree.Checkout(&git.CheckoutOptions{
113-
Branch: plumbing.NewBranchReferenceName(defaultBranch),
114-
Force: true, // Force checkout to discard any changes for a clean default branch
115-
})
116-
if err != nil {
117-
fmt.Printf("Failed to checkout default branch: %v\n", err)
118-
return
119-
}
120-
// Then delete the local branch
121-
err = repo.Storer.RemoveReference(plumbing.NewBranchReferenceName(branchName))
122-
if err != nil {
123-
fmt.Printf("Failed to remove local branch: %v\n", err)
124-
return
125-
}
126-
fmt.Println(" ✅")
127-
}
128-
}()
101+
// Push the new branch to GitHub before making any commits
102+
fmt.Printf("Pushing new branch '%s' to GitHub, tap your yubikey if it's blinking...", branchName)
103+
err = repo.Push(&git.PushOptions{
104+
RemoteName: "origin",
105+
RefSpecs: []config.RefSpec{
106+
config.RefSpec(fmt.Sprintf("refs/heads/%s:refs/heads/%s", branchName, branchName)),
107+
},
108+
})
109+
if err != nil {
110+
return fmt.Errorf("failed to push new branch to GitHub: %w", err)
111+
}
112+
fmt.Println(" ✅")
129113

130114
if len(currentlyFlakyEntries) == 0 {
131115
fmt.Println("No flaky tests found!")
@@ -148,11 +132,6 @@ func makePR(cmd *cobra.Command, args []string) error {
148132
return fmt.Errorf("failed to modify code to skip tests: %w", err)
149133
}
150134

151-
_, err = targetRepoWorktree.Add(".")
152-
if err != nil {
153-
return fmt.Errorf("failed to add changes: %w", err)
154-
}
155-
156135
fmt.Print("Committing changes, tap your yubikey if it's blinking...")
157136
sha, err := flake_git.MakeSignedCommit(repoPath, fmt.Sprintf("Skips flaky %d tests: %s", len(testsToSkip), strings.Join(jiraTickets, ", ")), branchName, githubToken)
158137
if err != nil {
@@ -250,7 +229,6 @@ func makePR(cmd *cobra.Command, args []string) error {
250229
return fmt.Errorf("failed to create PR, got bad status: %s\n%s", resp.Status, string(body))
251230
}
252231

253-
cleanUpBranch = false
254232
fmt.Printf("PR created! https://github.com/%s/%s/pull/%d\n", owner, repoName, createdPR.GetNumber())
255233
return nil
256234
}

tools/flakeguard/git/git.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/go-git/go-git/v5"
1616
"github.com/go-git/go-git/v5/plumbing"
17+
"github.com/google/go-github/v72/github"
1718
"github.com/shurcooL/githubv4"
1819
"github.com/smartcontractkit/chainlink-testing-framework/tools/flakeguard/utils"
1920
"golang.org/x/oauth2"
@@ -410,3 +411,247 @@ func base64EncodeFile(path string) (string, error) {
410411
}
411412
return buf.String(), nil
412413
}
414+
415+
// GitHubFileInfo represents file information from GitHub API
416+
type GitHubFileInfo struct {
417+
Name string `json:"name"`
418+
Path string `json:"path"`
419+
Type string `json:"type"` // "file" or "dir"
420+
DownloadURL string `json:"download_url"`
421+
SHA string `json:"sha"`
422+
}
423+
424+
// GitHubRepoStructure represents the analyzed repository structure
425+
type GitHubRepoStructure struct {
426+
GoModDirs []string // Directories containing go.mod files
427+
TestFiles map[string][]string // Package path -> list of test files
428+
PackageFiles map[string][]string // Package path -> list of go files
429+
}
430+
431+
// DiscoverRepoStructureViaGitHub analyzes repository structure using GitHub API
432+
// This replaces the need for local filesystem operations
433+
func DiscoverRepoStructureViaGitHub(owner, repo, ref, githubToken string) (*GitHubRepoStructure, error) {
434+
ctx := context.Background()
435+
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: githubToken})
436+
tc := oauth2.NewClient(ctx, ts)
437+
client := github.NewClient(tc)
438+
439+
structure := &GitHubRepoStructure{
440+
GoModDirs: []string{},
441+
TestFiles: make(map[string][]string),
442+
PackageFiles: make(map[string][]string),
443+
}
444+
445+
// Recursively walk the repository structure
446+
err := walkGitHubDirectory(ctx, client, owner, repo, ref, "", structure)
447+
if err != nil {
448+
return nil, fmt.Errorf("failed to walk repository structure: %w", err)
449+
}
450+
451+
return structure, nil
452+
}
453+
454+
// walkGitHubDirectory recursively walks through GitHub repository directories
455+
func walkGitHubDirectory(ctx context.Context, client *github.Client, owner, repo, ref, path string, structure *GitHubRepoStructure) error {
456+
_, directoryContent, _, err := client.Repositories.GetContents(ctx, owner, repo, path, &github.RepositoryContentGetOptions{Ref: ref})
457+
if err != nil {
458+
return fmt.Errorf("failed to get directory contents for %s: %w", path, err)
459+
}
460+
461+
var goFiles []string
462+
var testFiles []string
463+
464+
for _, content := range directoryContent {
465+
if content.GetType() == "file" {
466+
fileName := content.GetName()
467+
filePath := content.GetPath()
468+
469+
// Check for go.mod files
470+
if fileName == "go.mod" {
471+
structure.GoModDirs = append(structure.GoModDirs, path)
472+
}
473+
474+
// Check for Go files
475+
if strings.HasSuffix(fileName, ".go") {
476+
if strings.HasSuffix(fileName, "_test.go") {
477+
testFiles = append(testFiles, filePath)
478+
} else {
479+
goFiles = append(goFiles, filePath)
480+
}
481+
}
482+
} else if content.GetType() == "dir" {
483+
// Recursively walk subdirectories
484+
err := walkGitHubDirectory(ctx, client, owner, repo, ref, content.GetPath(), structure)
485+
if err != nil {
486+
return err
487+
}
488+
}
489+
}
490+
491+
// If this directory has Go files, determine the package path
492+
if len(goFiles) > 0 || len(testFiles) > 0 {
493+
// For simplicity, use the directory path as package identifier
494+
// In a real implementation, you might want to parse the package declaration
495+
packagePath := path
496+
if packagePath == "" {
497+
packagePath = "." // root package
498+
}
499+
500+
if len(goFiles) > 0 {
501+
structure.PackageFiles[packagePath] = goFiles
502+
}
503+
if len(testFiles) > 0 {
504+
structure.TestFiles[packagePath] = testFiles
505+
}
506+
}
507+
508+
return nil
509+
}
510+
511+
// GetFileContentsFromGitHub fetches file contents from GitHub
512+
func GetFileContentsFromGitHub(owner, repo, ref, path, githubToken string) (string, error) {
513+
ctx := context.Background()
514+
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: githubToken})
515+
tc := oauth2.NewClient(ctx, ts)
516+
client := github.NewClient(tc)
517+
518+
fileContent, _, _, err := client.Repositories.GetContents(ctx, owner, repo, path, &github.RepositoryContentGetOptions{Ref: ref})
519+
if err != nil {
520+
return "", fmt.Errorf("failed to get file contents for %s: %w", path, err)
521+
}
522+
523+
if fileContent == nil {
524+
return "", fmt.Errorf("file content is nil for %s", path)
525+
}
526+
527+
content, err := fileContent.GetContent()
528+
if err != nil {
529+
return "", fmt.Errorf("failed to decode file content for %s: %w", path, err)
530+
}
531+
532+
return content, nil
533+
}
534+
535+
// FindPackageForTest finds which package a test belongs to using GitHub API
536+
func FindPackageForTest(owner, repo, ref, testPackageImportPath, testName, githubToken string, structure *GitHubRepoStructure) (string, []string, error) {
537+
// Convert import path to directory path
538+
// This is a simplified approach - you might need more sophisticated logic
539+
packageDir := strings.ReplaceAll(testPackageImportPath, "/", "/")
540+
541+
// Find test files in the package directory
542+
testFiles, exists := structure.TestFiles[packageDir]
543+
if !exists {
544+
// Try alternative mappings or search through all test files
545+
for dir, files := range structure.TestFiles {
546+
// Check if this directory might contain the package we're looking for
547+
if strings.Contains(dir, packageDir) || strings.HasSuffix(testPackageImportPath, filepath.Base(dir)) {
548+
testFiles = files
549+
packageDir = dir
550+
break
551+
}
552+
}
553+
}
554+
555+
if len(testFiles) == 0 {
556+
return "", nil, fmt.Errorf("no test files found for package %s", testPackageImportPath)
557+
}
558+
559+
return packageDir, testFiles, nil
560+
}
561+
562+
// CreateBranchOnGitHub creates a new branch on GitHub from the base branch
563+
func CreateBranchOnGitHub(owner, repo, branchName, baseBranch, githubToken string) error {
564+
ctx := context.Background()
565+
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: githubToken})
566+
tc := oauth2.NewClient(ctx, ts)
567+
client := github.NewClient(tc)
568+
569+
// Get the reference of the base branch
570+
baseRef, _, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+baseBranch)
571+
if err != nil {
572+
return fmt.Errorf("failed to get base branch reference: %w", err)
573+
}
574+
575+
// Create new branch reference
576+
newRef := &github.Reference{
577+
Ref: github.Ptr("refs/heads/" + branchName),
578+
Object: &github.GitObject{
579+
SHA: baseRef.Object.SHA,
580+
},
581+
}
582+
583+
_, _, err = client.Git.CreateRef(ctx, owner, repo, newRef)
584+
if err != nil {
585+
return fmt.Errorf("failed to create branch: %w", err)
586+
}
587+
588+
return nil
589+
}
590+
591+
// CommitFilesToGitHub commits multiple files to a GitHub branch using GraphQL API
592+
func CommitFilesToGitHub(owner, repo, branchName string, files map[string]string, commitMsg, githubToken string) (string, error) {
593+
tok := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: githubToken})
594+
token := oauth2.NewClient(context.Background(), tok)
595+
graphqlClient := githubv4.NewClient(token)
596+
597+
// Prepare file additions
598+
additions := []githubv4.FileAddition{}
599+
for filePath, content := range files {
600+
// Base64 encode the content
601+
encoded := base64.StdEncoding.EncodeToString([]byte(content))
602+
additions = append(additions, githubv4.FileAddition{
603+
Path: githubv4.String(filePath),
604+
Contents: githubv4.Base64String(encoded),
605+
})
606+
}
607+
608+
var m struct {
609+
CreateCommitOnBranch struct {
610+
Commit struct {
611+
URL string `graphql:"url"`
612+
OID string `graphql:"oid"`
613+
}
614+
} `graphql:"createCommitOnBranch(input:$input)"`
615+
}
616+
617+
splitMsg := strings.SplitN(commitMsg, "\n", 2)
618+
headline := splitMsg[0]
619+
body := ""
620+
if len(splitMsg) > 1 {
621+
body = splitMsg[1]
622+
}
623+
624+
// Get current HEAD SHA of the branch
625+
ctx := context.Background()
626+
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: githubToken})
627+
tc := oauth2.NewClient(ctx, ts)
628+
restClient := github.NewClient(tc)
629+
630+
branchRef, _, err := restClient.Git.GetRef(ctx, owner, repo, "refs/heads/"+branchName)
631+
if err != nil {
632+
return "", fmt.Errorf("failed to get branch reference: %w", err)
633+
}
634+
expectedHeadOid := branchRef.Object.GetSHA()
635+
636+
// Create the GraphQL input
637+
input := githubv4.CreateCommitOnBranchInput{
638+
Branch: githubv4.CommittableBranch{
639+
RepositoryNameWithOwner: githubv4.NewString(githubv4.String(fmt.Sprintf("%s/%s", owner, repo))),
640+
BranchName: githubv4.NewString(githubv4.String(branchName)),
641+
},
642+
Message: githubv4.CommitMessage{
643+
Headline: githubv4.String(headline),
644+
Body: githubv4.NewString(githubv4.String(body)),
645+
},
646+
FileChanges: &githubv4.FileChanges{
647+
Additions: &additions,
648+
},
649+
ExpectedHeadOid: githubv4.GitObjectID(expectedHeadOid),
650+
}
651+
652+
if err := graphqlClient.Mutate(context.Background(), &m, input, nil); err != nil {
653+
return "", fmt.Errorf("failed to create commit: %w", err)
654+
}
655+
656+
return m.CreateCommitOnBranch.Commit.OID, nil
657+
}

0 commit comments

Comments
 (0)