Skip to content

Commit 6952bc2

Browse files
feat: Implement plugin RPC handler (#6)
* feat: Implement plugin RPC handler * make rpc package internal
1 parent ee94c71 commit 6952bc2

File tree

3 files changed

+157
-36
lines changed

3 files changed

+157
-36
lines changed

codegen/codegen.go

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,14 @@
11
package codegen
22

33
import (
4-
"bufio"
54
"context"
6-
"fmt"
7-
"io"
8-
"os"
95

6+
"github.com/sqlc-dev/sqlc-go/internal/rpc"
107
pb "github.com/sqlc-dev/sqlc-go/plugin"
11-
"google.golang.org/protobuf/proto"
128
)
139

1410
type Handler func(context.Context, *pb.GenerateRequest) (*pb.GenerateResponse, error)
1511

1612
func Run(h Handler) {
17-
if err := run(h); err != nil {
18-
fmt.Fprintf(os.Stderr, "error generating output: %s", err)
19-
os.Exit(2)
20-
}
21-
}
22-
23-
func run(h Handler) error {
24-
var req pb.GenerateRequest
25-
reqBlob, err := io.ReadAll(os.Stdin)
26-
if err != nil {
27-
return err
28-
}
29-
if err := proto.Unmarshal(reqBlob, &req); err != nil {
30-
return err
31-
}
32-
resp, err := h(context.Background(), &req)
33-
if err != nil {
34-
return err
35-
}
36-
respBlob, err := proto.Marshal(resp)
37-
if err != nil {
38-
return err
39-
}
40-
w := bufio.NewWriter(os.Stdout)
41-
if _, err := w.Write(respBlob); err != nil {
42-
return err
43-
}
44-
if err := w.Flush(); err != nil {
45-
return err
46-
}
47-
return nil
13+
rpc.Handle(&server{handler: h})
4814
}

codegen/server.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package codegen
2+
3+
import (
4+
"context"
5+
6+
pb "github.com/sqlc-dev/sqlc-go/plugin"
7+
)
8+
9+
type server struct {
10+
pb.UnimplementedCodegenServiceServer
11+
12+
handler Handler
13+
}
14+
15+
func (s *server) Generate(ctx context.Context, req *pb.GenerateRequest) (*pb.GenerateResponse, error) {
16+
return s.handler(ctx, req)
17+
}

internal/rpc/handler.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package rpc
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"fmt"
7+
"io"
8+
"os"
9+
"strings"
10+
11+
"google.golang.org/grpc"
12+
"google.golang.org/grpc/codes"
13+
"google.golang.org/grpc/status"
14+
"google.golang.org/protobuf/proto"
15+
"google.golang.org/protobuf/reflect/protoreflect"
16+
17+
"github.com/sqlc-dev/sqlc-go/plugin"
18+
pb "github.com/sqlc-dev/sqlc-go/plugin"
19+
)
20+
21+
func Handle(server pb.CodegenServiceServer) {
22+
if err := handle(server); err != nil {
23+
fmt.Fprintf(os.Stderr, "error generating output: %s", err)
24+
os.Exit(2)
25+
}
26+
}
27+
28+
func handle(server pb.CodegenServiceServer) error {
29+
handler := newStdioRPCHandler()
30+
pb.RegisterCodegenServiceServer(handler, server)
31+
return handler.Handle()
32+
}
33+
34+
type stdioRPCHandler struct {
35+
services map[string]*serviceInfo
36+
}
37+
38+
func newStdioRPCHandler() *stdioRPCHandler {
39+
return &stdioRPCHandler{services: map[string]*serviceInfo{}}
40+
}
41+
42+
type serviceInfo struct {
43+
serviceImpl any
44+
methods map[string]*grpc.MethodDesc
45+
}
46+
47+
func (s *stdioRPCHandler) RegisterService(sd *grpc.ServiceDesc, ss any) {
48+
// TODO some type checking, see e.g. grpc server.RegisterService()
49+
info := &serviceInfo{
50+
serviceImpl: ss,
51+
methods: make(map[string]*grpc.MethodDesc),
52+
}
53+
for i := range sd.Methods {
54+
d := &sd.Methods[i]
55+
info.methods[d.MethodName] = d
56+
}
57+
s.services[sd.ServiceName] = info
58+
}
59+
60+
func (s *stdioRPCHandler) Handle() error {
61+
var methodArg string
62+
if len(os.Args) < 2 {
63+
// For backwards compatibility with sqlc before v1.24.0
64+
methodArg = fmt.Sprintf("/%s/%s", pb.CodegenService_ServiceDesc.ServiceName, "Generate")
65+
} else {
66+
methodArg = os.Args[1]
67+
}
68+
69+
// Adapted from grpc server handleStream()
70+
71+
sm := methodArg
72+
if sm != "" && sm[0] == '/' {
73+
sm = sm[1:]
74+
}
75+
pos := strings.LastIndex(sm, "/")
76+
if pos == -1 {
77+
errDesc := fmt.Sprintf("malformed method name: %q", methodArg)
78+
return status.Error(codes.Unimplemented, errDesc)
79+
}
80+
service := sm[:pos]
81+
method := sm[pos+1:]
82+
83+
srv, knownService := s.services[service]
84+
if knownService {
85+
if md, ok := srv.methods[method]; ok {
86+
return s.processUnaryRPC(srv, md)
87+
}
88+
}
89+
90+
// Unknown service, or known server unknown method.
91+
var errDesc string
92+
if !knownService {
93+
errDesc = fmt.Sprintf("unknown service %v", service)
94+
} else {
95+
errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
96+
}
97+
98+
return status.Error(codes.Unimplemented, errDesc)
99+
}
100+
101+
func (s *stdioRPCHandler) processUnaryRPC(srv *serviceInfo, md *grpc.MethodDesc) error {
102+
reqBytes, err := io.ReadAll(os.Stdin)
103+
if err != nil {
104+
return err
105+
}
106+
107+
var resp protoreflect.ProtoMessage
108+
109+
// TODO make this generic
110+
switch md.MethodName {
111+
case "Generate":
112+
var req plugin.GenerateRequest
113+
if err := proto.Unmarshal(reqBytes, &req); err != nil {
114+
return err
115+
}
116+
service, ok := srv.serviceImpl.(pb.CodegenServiceServer)
117+
if !ok {
118+
return status.Errorf(codes.Internal, codes.Internal.String())
119+
}
120+
resp, err = service.Generate(context.Background(), &req)
121+
if err != nil {
122+
return err
123+
}
124+
}
125+
126+
respBytes, err := proto.Marshal(resp)
127+
if err != nil {
128+
return err
129+
}
130+
w := bufio.NewWriter(os.Stdout)
131+
if _, err := w.Write(respBytes); err != nil {
132+
return err
133+
}
134+
if err := w.Flush(); err != nil {
135+
return err
136+
}
137+
return nil
138+
}

0 commit comments

Comments
 (0)