Skip to content

Commit c982020

Browse files
committed
Add list of allowed modules
1 parent d134d37 commit c982020

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

cmd/server/server.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"log"
1010
"net"
1111
"os"
12+
"slices"
13+
"strings"
1214

1315
"github.com/miekg/pkcs11"
1416
p11 "github.com/ryarnyah/pkcs11-go-proxy/pkcs11"
@@ -19,16 +21,23 @@ import (
1921
)
2022

2123
// ErrCtxNotFound raised when context can't be found.
22-
var ErrCtxNotFound = errors.New("Context not found")
24+
var ErrCtxNotFound = errors.New("context not found")
25+
26+
// ErrModuleNotAllowed raised when module is not allowlist.
27+
var ErrModuleNotAllowed = errors.New("module not allowed")
2328

2429
type pkcs11Server struct {
2530
ctxs map[string]*pkcs11.Ctx
2631

32+
allowedModules []string
2733
p11.UnimplementedPKCS11Server
2834
}
2935

3036
// New creates a new context and initializes the module/library for use.
3137
func (m *pkcs11Server) New(ctx context.Context, in *p11.NewRequest) (*p11.NewResponse, error) {
38+
if len(m.allowedModules) > 0 && !slices.Contains(m.allowedModules, in.GetModule()) {
39+
return nil, ErrModuleNotAllowed
40+
}
3241
c, ok := m.ctxs[in.GetModule()]
3342
if ok {
3443
c.Finalize()
@@ -1021,6 +1030,11 @@ func main() {
10211030
c = credentials.NewTLS(tlsConfig)
10221031
}
10231032

1033+
allowedModules := []string{}
1034+
if os.Getenv("PKCS11_PROXY_ALLOWED_MODULES") != "" {
1035+
allowedModules = strings.Split(os.Getenv("PKCS11_PROXY_ALLOWED_MODULES"), ";")
1036+
}
1037+
10241038
errHandler := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
10251039
resp, err := handler(ctx, req)
10261040
if err != nil {
@@ -1030,7 +1044,8 @@ func main() {
10301044
}
10311045
s := grpc.NewServer(grpc.Creds(c), grpc.UnaryInterceptor(errHandler))
10321046
server := &pkcs11Server{
1033-
ctxs: make(map[string]*pkcs11.Ctx, 0),
1047+
ctxs: make(map[string]*pkcs11.Ctx, 0),
1048+
allowedModules: allowedModules,
10341049
}
10351050
p11.RegisterPKCS11Server(s, server)
10361051
if err := s.Serve(listener); err != nil {

0 commit comments

Comments
 (0)