Skip to content

Commit 120b81b

Browse files
committed
🥧 🍮 🌶️ impl: query extractor
1 parent e0ecb41 commit 120b81b

File tree

7 files changed

+453
-0
lines changed

7 files changed

+453
-0
lines changed

cli/extractor.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
8+
"github.com/h24w-17/extractor"
9+
"github.com/spf13/cobra"
10+
)
11+
12+
var rootCmd = &cobra.Command{
13+
Use: "query-extractor",
14+
Long: "Statistically analyze the codebase and extract SQL queries",
15+
Args: cobra.ExactArgs(1),
16+
ValidArgs: []string{"path"},
17+
RunE: func(cmd *cobra.Command, args []string) error {
18+
path := args[0]
19+
20+
out, err := cmd.Flags().GetString("out")
21+
if err != nil {
22+
return fmt.Errorf("error getting out flag: %v", err)
23+
}
24+
25+
valid := extractor.IsValidDir(path)
26+
if !valid {
27+
return fmt.Errorf("invalid directory: %s", path)
28+
}
29+
files, err := extractor.ListAllGoFiles(path)
30+
if err != nil {
31+
return fmt.Errorf("error listing go files: %v", err)
32+
}
33+
34+
fmt.Printf("found %d go files\n", len(files))
35+
extractedQueries := []*extractor.ExtractedQuery{}
36+
for _, file := range files {
37+
relativePath, err := filepath.Rel(path, file)
38+
if err != nil {
39+
return fmt.Errorf("error getting relative path: %v", err)
40+
}
41+
queries, err := extractor.ExtractQueryFromFile(file, path)
42+
if err != nil {
43+
return fmt.Errorf("❌ %s: error while extracting: %v", relativePath, err)
44+
}
45+
fmt.Printf("✅ %s: %d queries extracted\n", relativePath, len(queries))
46+
extractedQueries = append(extractedQueries, queries...)
47+
}
48+
49+
err = extractor.WriteQueriesToFile(out, extractedQueries)
50+
if err != nil {
51+
return fmt.Errorf("error writing queries to file: %v", err)
52+
}
53+
54+
fmt.Printf("queries written to %s\n", out)
55+
56+
return nil
57+
},
58+
}
59+
60+
func main() {
61+
err := rootCmd.Execute()
62+
if err != nil {
63+
os.Exit(1)
64+
}
65+
}
66+
67+
func init() {
68+
rootCmd.Flags().StringP("out", "o", "extracted.sql", "Destination file that extracted queries will be written to")
69+
}

extractor/extractor.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package extractor
2+
3+
import (
4+
"fmt"
5+
"go/ast"
6+
"go/parser"
7+
"go/token"
8+
"path/filepath"
9+
"regexp"
10+
"strings"
11+
)
12+
13+
var sqlPattern = regexp.MustCompile(`(?i)\b(SELECT|INSERT|UPDATE|DELETE)\b`)
14+
var replacePattern = regexp.MustCompile(`\s+`)
15+
16+
type ExtractedQuery struct {
17+
file string
18+
pos int
19+
content string
20+
}
21+
22+
func ExtractQueryFromFile(path string, root string) ([]*ExtractedQuery, error) {
23+
fs := token.NewFileSet()
24+
node, err := parser.ParseFile(fs, path, nil, parser.AllErrors)
25+
if err != nil {
26+
return nil, fmt.Errorf("error parsing file: %v", err)
27+
}
28+
29+
// 結果を収集
30+
var results []*ExtractedQuery
31+
ast.Inspect(node, func(n ast.Node) bool {
32+
if n == nil {
33+
return false
34+
}
35+
// 文字列リテラルを抽出
36+
if lit, ok := n.(*ast.BasicLit); ok && lit.Kind == token.STRING {
37+
// SQLクエリらしき文字列を抽出
38+
value := strings.Trim(lit.Value, "\"`")
39+
value = strings.ReplaceAll(value, "\n", " ")
40+
value = replacePattern.ReplaceAllString(value, " ")
41+
pos := lit.Pos()
42+
if sqlPattern.MatchString(value) {
43+
pos := fs.Position(pos)
44+
relativePath, err := filepath.Rel(root, path)
45+
if err != nil {
46+
return false
47+
}
48+
results = append(results, &ExtractedQuery{
49+
file: relativePath,
50+
pos: pos.Line,
51+
content: value,
52+
})
53+
}
54+
return false
55+
}
56+
return true
57+
})
58+
59+
return results, nil
60+
}

extractor/extractor_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package extractor
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestExtractQueryFromFile(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
path string
13+
root string
14+
expected []*ExtractedQuery
15+
}{
16+
{
17+
name: "extract query from file",
18+
path: "testdata/extractor1.go",
19+
root: "testdata",
20+
expected: []*ExtractedQuery{
21+
{file: "extractor1.go", pos: 32, content: "SELECT id, name FROM users"},
22+
{file: "extractor1.go", pos: 44, content: "INSERT INTO users (name) VALUES (?)"},
23+
{file: "extractor1.go", pos: 59, content: "SELECT id, name FROM users WHERE id = ?"},
24+
{file: "extractor1.go", pos: 72, content: "UPDATE users SET name = ? WHERE id = ?"},
25+
{file: "extractor1.go", pos: 79, content: "DELETE FROM users WHERE id = ?"},
26+
{file: "extractor1.go", pos: 89, content: "SELECT id, user_id, title, body FROM posts"},
27+
{file: "extractor1.go", pos: 103, content: "INSERT INTO posts (user_id, title, body) VALUES (?, ?, ?)"},
28+
{file: "extractor1.go", pos: 118, content: "SELECT id, user_id, title, body FROM posts WHERE id = ?"},
29+
{file: "extractor1.go", pos: 133, content: "UPDATE posts SET user_id = ?, title = ?, body = ? WHERE id = ?"},
30+
{file: "extractor1.go", pos: 140, content: "DELETE FROM posts WHERE id = ?"},
31+
{file: "extractor1.go", pos: 150, content: "SELECT id, post_id, body FROM comments"},
32+
{file: "extractor1.go", pos: 163, content: "INSERT INTO comments (post_id, body) VALUES (?, ?)"},
33+
{file: "extractor1.go", pos: 178, content: "SELECT id, post_id, body FROM comments WHERE id = ?"},
34+
{file: "extractor1.go", pos: 191, content: "UPDATE comments SET body = ? WHERE id = ?"},
35+
{file: "extractor1.go", pos: 198, content: "DELETE FROM comments WHERE id = ?"},
36+
},
37+
},
38+
}
39+
40+
for _, test := range tests {
41+
t.Run(test.name, func(t *testing.T) {
42+
actual, err := ExtractQueryFromFile(test.path, test.root)
43+
assert.NoError(t, err)
44+
assert.Equal(t, test.expected, actual)
45+
})
46+
}
47+
}

extractor/io.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package extractor
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
"strings"
8+
)
9+
10+
func IsValidDir(dir string) bool {
11+
info, err := os.Stat(dir)
12+
if err != nil {
13+
return false
14+
}
15+
return info.IsDir()
16+
}
17+
18+
func ListAllGoFiles(dir string) ([]string, error) {
19+
var files []string
20+
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
21+
if err != nil {
22+
return err
23+
}
24+
if info.IsDir() {
25+
return nil
26+
}
27+
if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") {
28+
return nil
29+
}
30+
files = append(files, path)
31+
return nil
32+
})
33+
if err != nil {
34+
return []string{}, fmt.Errorf("error walking directory: %v", err)
35+
}
36+
return files, nil
37+
}
38+
39+
func WriteQueriesToFile(out string, queries []*ExtractedQuery) error {
40+
f, err := os.Create(out)
41+
if err != nil {
42+
return fmt.Errorf("error creating file: %v", err)
43+
}
44+
defer f.Close()
45+
46+
for _, query := range queries {
47+
_, err := f.WriteString(query.String() + "\n")
48+
if err != nil {
49+
return fmt.Errorf("error writing to file: %v", err)
50+
}
51+
}
52+
53+
return nil
54+
}
55+
56+
func (q *ExtractedQuery) String() string {
57+
return fmt.Sprintf("-- %s:%d\n%s", q.file, q.pos, q.content)
58+
}

0 commit comments

Comments
 (0)