diff --git a/cmd/app/main.go b/cmd/app/main.go index c613fe4..d56e922 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -25,7 +25,7 @@ import ( ) var ( - port = flag.Int("port", 9090, "TCP port for HTTP server to bind") + addr = flag.String("addr", "0.0.0.0:9090", "IP and TCP port for HTTP server to bind") serverID = flag.String("serverID", uuid.New().String(), "ID to identify the server. Must be globally unique within the cluster") discoveryType = flag.String("discoveryType", virtual.DiscoveryTypeLocalHost, "how the server should register itself with the discovery serice. Valid options: localhost|remote. Use localhost for local testing, use remote for multi-node setups") registryType = flag.String("registryBackend", "memory", "backend to use for the Registry. Validation options: memory|foundationdb") @@ -33,6 +33,8 @@ var ( shutdownTimeout = flag.Duration("shutdownTimeout", 0, "timeout until the server is forced to shutdown, without waiting actors and other components to close gracefully. By default is 0, which is infinite duration untill all actors are closed") logFormat = flag.String("logFormat", "text", "format to use for the logger. The formats it accepst are: 'text', 'json'") logLevel = flag.String("logLevel", "debug", "level to use for the logger. The levels it accepts are: 'info', 'debug', 'error', 'warn'") + websocketsEnabled = flag.Bool("websockets", false, "enable websockets endpoint") + websocketsAddr = flag.String("websocketsAddr", "0.0.0.0:9092", "websockets server address") ) func main() { @@ -68,11 +70,17 @@ func main() { client := virtual.NewHTTPClient() + port, err := utils.ParsePortFromAddr(*addr) + if err != nil { + log.Error("failed to parse addr", slog.Any("error", err), slog.String("addr", *addr)) + os.Exit(1) + } + ctx, cc := context.WithTimeout(context.Background(), 10*time.Second) environment, err := virtual.NewEnvironment(ctx, *serverID, reg, client, virtual.EnvironmentOptions{ Discovery: virtual.DiscoveryOptions{ DiscoveryType: *discoveryType, - Port: *port, + Port: port, }, Logger: log, }) @@ -82,9 +90,7 @@ func main() { os.Exit(1) } - var server virtualServer = virtual.NewServer(reg, environment) - - log.Info("server listening", slog.Int("port", *port)) + var server virtualServer = virtual.NewServer(log, reg, environment) go func(server virtualServer) { sig := waitForSignal() @@ -92,7 +98,18 @@ func main() { shutdown(log, server, *shutdownTimeout) }(server) - if err := server.Start(*port); err != nil && !errors.Is(err, http.ErrServerClosed) { + if *websocketsEnabled { + go func() { + log.Info("ws server listening", slog.String("addr", *websocketsAddr)) + if err := server.StartWebsocket(*websocketsAddr); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Error("received error", slog.Any("error", err), slog.String("subService", "httpWsServer")) + shutdown(log, server, *shutdownTimeout) + } + }() + } + + log.Info("http server listening", slog.String("addr", *addr)) + if err := server.Start(*addr); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Error("received error", slog.Any("error", err), slog.String("subService", "httpServer")) shutdown(log, server, *shutdownTimeout) os.Exit(1) @@ -100,7 +117,8 @@ func main() { } type virtualServer interface { - Start(int) error + Start(string) error + StartWebsocket(string) error Stop(context.Context) error } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index d6257e3..567223d 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -2,7 +2,9 @@ package utils import ( "fmt" + "net" "os" + "strconv" "golang.org/x/exp/slog" ) @@ -31,3 +33,17 @@ func ParseLog(logLevel, logFormat string) (*slog.Logger, error) { return nil, fmt.Errorf("invalid log format: %s", logFormat) } } + +func ParsePortFromAddr(addr string) (int, error) { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0, err + } + + return strconv.Atoi(portStr) +} + +func ParseHostFromAddr(addr string) (string, error) { + host, _, err := net.SplitHostPort(addr) + return host, err +} diff --git a/examples/dnsregistry/main.go b/examples/dnsregistry/main.go index 2de03d6..851870a 100644 --- a/examples/dnsregistry/main.go +++ b/examples/dnsregistry/main.go @@ -17,8 +17,7 @@ import ( ) var ( - host = flag.String("host", "localhost", "Hostname to perform DNS lookups against") - port = flag.Int("port", 9090, "TCP port for HTTP server to bind") + addr = flag.String("addr", "localhost:9090", "IP and TCP port for HTTP server to bind") logFormat = flag.String("logFormat", "text", "format to use for the logger. The formats it accepst are: 'text', 'json'") logLevel = flag.String("logLevel", "debug", "level to use for the logger. The levels it accepts are: 'info', 'debug', 'error', 'warn'") ) @@ -26,20 +25,26 @@ var ( func main() { flag.Parse() - if *host == "" { - flag.Usage() - slog.Error("host cannot be empty") + log, err := utils.ParseLog(*logLevel, *logFormat) + if err != nil { + slog.Error("failed to parse log", slog.Any("error", err)) os.Exit(1) } - log, err := utils.ParseLog(*logLevel, *logFormat) + port, err := utils.ParsePortFromAddr(*addr) if err != nil { - slog.Error("failed to parse log", slog.Any("error", err)) + log.Error("failed to parse port from addr", slog.Any("error", err), slog.String("addr", *addr)) + os.Exit(1) + } + + host, err := utils.ParseHostFromAddr(*addr) + if err != nil { + log.Error("failed to parse host from addr", slog.Any("error", err), slog.String("addr", *addr)) os.Exit(1) } env, registry, err := virtual.NewDNSRegistryEnvironment( - context.Background(), *host, *port, virtual.EnvironmentOptions{Logger: log}) + context.Background(), host, port, virtual.EnvironmentOptions{Logger: log}) if err != nil { log.Error("error creating virtual environment", slog.Any("error", err)) os.Exit(1) @@ -78,8 +83,8 @@ func main() { } }() - server := virtual.NewServer(registry, env) - if err := server.Start(*port); err != nil { + server := virtual.NewServer(log, registry, env) + if err := server.Start(*addr); err != nil { log.Error("error starting server", slog.Any("error", err)) os.Exit(1) } diff --git a/examples/file_cache/benchmark_test.go b/examples/file_cache/benchmark_test.go index fe63cf5..c3fcaf2 100644 --- a/examples/file_cache/benchmark_test.go +++ b/examples/file_cache/benchmark_test.go @@ -13,6 +13,7 @@ import ( "github.com/richardartoul/nola/virtual" "github.com/richardartoul/nola/virtual/registry/localregistry" "github.com/richardartoul/nola/virtual/types" + "golang.org/x/exp/slog" "github.com/DataDog/sketches-go/ddsketch" "github.com/stretchr/testify/require" @@ -56,9 +57,9 @@ func TestFileCacheBenchmark(t *testing.T) { NewFileCacheModule(chunkSize, fetchSize, fetcher, cache)) require.NoError(t, err) - server := virtual.NewServer(registry, env) + server := virtual.NewServer(slog.Default(), registry, env) go func() { - if err := server.Start(port); err != nil { + if err := server.Start(fmt.Sprintf("0.0.0.0:%d", port)); err != nil { panic(err) } }() diff --git a/go.mod b/go.mod index a9b9a11..98ab920 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/wasmerio/wasmer-go v1.0.4 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/sync v0.1.0 + nhooyr.io/websocket v1.8.7 ) require ( @@ -22,6 +23,8 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/klauspost/compress v1.10.3 // indirect github.com/kr/text v0.2.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index 0199bb1..3163ed3 100644 --- a/go.sum +++ b/go.sum @@ -17,21 +17,57 @@ github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczC github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= @@ -42,6 +78,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -53,6 +90,10 @@ github.com/tetratelabs/wazero v1.0.0-pre.6 h1:3DRqjuHazHyZmgWCgqu7nKgYIYNEi2+2RQ github.com/tetratelabs/wazero v1.0.0-pre.6/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo= github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg= github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/wapc/wapc-go v0.5.7 h1:ZPswSRFlg7JLyanvVndIY9YWJCONcVO8Zs+7pjsIQyA= github.com/wapc/wapc-go v0.5.7/go.mod h1:7+O5cEJaLqhnwE0Trrx9PceBpCNzMx2fNtyBBPseucY= github.com/wapc/wapc-guest-tinygo v0.3.3 h1:jLebiwjVSHLGnS+BRabQ6+XOV7oihVWAc05Hf1SbeR0= @@ -75,12 +116,15 @@ golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -96,6 +140,10 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= +nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= diff --git a/virtual/client.go b/virtual/client.go index 9022313..ccf3eec 100644 --- a/virtual/client.go +++ b/virtual/client.go @@ -26,7 +26,7 @@ func (h *httpClient) InvokeActorRemote( payload []byte, create types.CreateIfNotExist, ) (io.ReadCloser, error) { - ir := invokeActorDirectRequest{ + ir := types.InvokeActorDirectHttpRequest{ VersionStamp: versionStamp, ServerID: reference.ServerID(), ServerVersion: reference.ServerVersion(), diff --git a/virtual/server.go b/virtual/server.go index 25bd11c..414f706 100644 --- a/virtual/server.go +++ b/virtual/server.go @@ -13,29 +13,35 @@ import ( "github.com/richardartoul/nola/virtual/registry" "github.com/richardartoul/nola/virtual/types" + "golang.org/x/exp/slog" ) type server struct { + logger *slog.Logger + // Dependencies. registry registry.Registry environment Environment - server *http.Server + server *http.Server + wsServer *http.Server } // NewServer creates a new server for the actor virtual environment. func NewServer( + logger *slog.Logger, registry registry.Registry, environment Environment, ) *server { return &server{ + logger: logger.With(slog.String("module", "server")), registry: registry, environment: environment, } } // Start starts the server. -func (s *server) Start(port int) error { +func (s *server) Start(addr string) error { mux := http.NewServeMux() mux.HandleFunc("/api/v1/register-module", s.registerModule) mux.HandleFunc("/api/v1/invoke-actor", s.invoke) @@ -43,15 +49,24 @@ func (s *server) Start(port int) error { mux.HandleFunc("/api/v1/invoke-worker", s.invokeWorker) s.server = &http.Server{ - Addr: fmt.Sprintf(":%d", port), + Addr: addr, Handler: mux, } - if err := s.server.ListenAndServe(); err != nil { - return err + return s.server.ListenAndServe() +} + +// Start starts the server. +func (s *server) StartWebsocket(addr string) error { + mux := http.NewServeMux() + mux.HandleFunc("/api/v1/rpc/json", s.wsHandler) + + s.wsServer = &http.Server{ + Addr: addr, + Handler: mux, } - return nil + return s.wsServer.ListenAndServe() } func (s *server) Stop(ctx context.Context) error { @@ -61,6 +76,14 @@ func (s *server) Stop(ctx context.Context) error { } log.Print("successfully shut down HTTP server") + if s.wsServer != nil { + log.Print("shutting down websocket http server") + if err := s.wsServer.Shutdown(ctx); err != nil { + return fmt.Errorf("failed to shut down http server: %w", err) + } + log.Print("successfully shut down websocket HTTP server") + } + log.Print("closing environment") if err := s.environment.Close(ctx); err != nil { return fmt.Errorf("failed to close the environment: %w", err) @@ -86,9 +109,7 @@ func (s *server) registerModule(w http.ResponseWriter, r *http.Request) { return } - ctx, cc := context.WithTimeout(context.Background(), 60*time.Second) - defer cc() - result, err := s.registry.RegisterModule(ctx, namespace, moduleID, moduleBytes, registry.ModuleOptions{}) + result, err := s.handleRegisterModule(r.Context(), types.RegisterModuleHttpRequest{Namespace: namespace, ModuleID: moduleID, ModuleBytes: moduleBytes}) if err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -106,13 +127,10 @@ func (s *server) registerModule(w http.ResponseWriter, r *http.Request) { w.Write(marshaled) } -type invokeActorRequest struct { - ServerID string `json:"server_id"` - Namespace string `json:"namespace"` - types.InvokeActorRequest - // Same data as Payload (in types.InvokeActorRequest), but different field so it doesn't - // have to be encoded as base64. - PayloadJSON interface{} `json:"payload_json"` +func (s *server) handleRegisterModule(ctx context.Context, req types.RegisterModuleHttpRequest) (registry.RegisterModuleResult, error) { + ctx, cc := context.WithTimeout(ctx, 60*time.Second) + defer cc() + return s.registry.RegisterModule(ctx, req.Namespace, req.ModuleID, req.ModuleBytes, registry.ModuleOptions{}) } func (s *server) invoke(w http.ResponseWriter, r *http.Request) { @@ -128,7 +146,7 @@ func (s *server) invoke(w http.ResponseWriter, r *http.Request) { return } - var req invokeActorRequest + var req types.InvokeActorHttpRequest if err := json.Unmarshal(jsonBytes, &req); err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -146,10 +164,7 @@ func (s *server) invoke(w http.ResponseWriter, r *http.Request) { } // TODO: This should be configurable, probably in a header with some maximum. - ctx, cc := context.WithTimeout(context.Background(), 5*time.Second) - defer cc() - result, err := s.environment.InvokeActorStream( - ctx, req.Namespace, req.ActorID, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist) + result, err := s.handleInvoke(r.Context(), req) if err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -167,17 +182,11 @@ func (s *server) invoke(w http.ResponseWriter, r *http.Request) { } } -type invokeActorDirectRequest struct { - VersionStamp int64 `json:"version_stamp"` - ServerID string `json:"server_id"` - ServerVersion int64 `json:"server_version"` - Namespace string `json:"namespace"` - ModuleID string `json:"module_id"` - ActorID string `json:"actor_id"` - Generation uint64 `json:"generation"` - Operation string `json:"operation"` - Payload []byte `json:"payload"` - CreateIfNotExist types.CreateIfNotExist `json:"create_if_not_exist"` +func (s *server) handleInvoke(ctx context.Context, req types.InvokeActorHttpRequest) (io.ReadCloser, error) { + ctx, cc := context.WithTimeout(ctx, 5*time.Second) + defer cc() + return s.environment.InvokeActorStream( + ctx, req.Namespace, req.ActorID, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist) } func (s *server) invokeDirect(w http.ResponseWriter, r *http.Request) { @@ -193,7 +202,7 @@ func (s *server) invokeDirect(w http.ResponseWriter, r *http.Request) { return } - var req invokeActorDirectRequest + var req types.InvokeActorDirectHttpRequest if err := json.Unmarshal(jsonBytes, &req); err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -201,19 +210,7 @@ func (s *server) invokeDirect(w http.ResponseWriter, r *http.Request) { } // TODO: This should be configurable, probably in a header with some maximum. - ctx, cc := context.WithTimeout(context.Background(), 5*time.Second) - defer cc() - - ref, err := types.NewVirtualActorReference(req.Namespace, req.ModuleID, req.ActorID, uint64(req.Generation)) - if err != nil { - w.WriteHeader(500) - w.Write([]byte(err.Error())) - return - } - - result, err := s.environment.InvokeActorDirectStream( - ctx, req.VersionStamp, req.ServerID, req.ServerVersion, ref, - req.Operation, req.Payload, req.CreateIfNotExist) + result, err := s.handleInvokeDirect(r.Context(), req) if err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -231,14 +228,18 @@ func (s *server) invokeDirect(w http.ResponseWriter, r *http.Request) { } } -type invokeWorkerRequest struct { - Namespace string `json:"namespace"` - // TODO: Allow ModuleID to be omitted if the caller provides a WASMExecutable field which contains the - // actual WASM program that should be executed. - ModuleID string `json:"module_id"` - Operation string `json:"operation"` - Payload []byte `json:"payload"` - CreateIfNotExist types.CreateIfNotExist `json:"create_if_not_exist"` +func (s *server) handleInvokeDirect(ctx context.Context, req types.InvokeActorDirectHttpRequest) (io.ReadCloser, error) { + ctx, cc := context.WithTimeout(ctx, 5*time.Second) + defer cc() + + ref, err := types.NewVirtualActorReference(req.Namespace, req.ModuleID, req.ActorID, uint64(req.Generation)) + if err != nil { + return nil, err + } + + return s.environment.InvokeActorDirectStream( + ctx, req.VersionStamp, req.ServerID, req.ServerVersion, ref, + req.Operation, req.Payload, req.CreateIfNotExist) } func (s *server) invokeWorker(w http.ResponseWriter, r *http.Request) { @@ -254,7 +255,7 @@ func (s *server) invokeWorker(w http.ResponseWriter, r *http.Request) { return } - var req invokeWorkerRequest + var req types.InvokeWorkerHttpRequest if err := json.Unmarshal(jsonBytes, &req); err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -262,11 +263,7 @@ func (s *server) invokeWorker(w http.ResponseWriter, r *http.Request) { } // TODO: This should be configurable, probably in a header with some maximum. - ctx, cc := context.WithTimeout(context.Background(), 5*time.Second) - defer cc() - - result, err := s.environment.InvokeWorkerStream( - ctx, req.Namespace, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist) + result, err := s.handleInvokeWorker(r.Context(), req) if err != nil { w.WriteHeader(500) w.Write([]byte(err.Error())) @@ -284,6 +281,14 @@ func (s *server) invokeWorker(w http.ResponseWriter, r *http.Request) { } } +func (s *server) handleInvokeWorker(ctx context.Context, req types.InvokeWorkerHttpRequest) (io.ReadCloser, error) { + ctx, cc := context.WithTimeout(ctx, 5*time.Second) + defer cc() + + return s.environment.InvokeWorkerStream( + ctx, req.Namespace, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist) +} + // ensureHijackable and terminateConnection are used in conjunction to close tcp connections // for requests where we've started copying the response stream into the HTTP response body // after submitting an HTTP 200 status code, but then encounter an error reading from the diff --git a/virtual/types/http.go b/virtual/types/http.go new file mode 100644 index 0000000..7cef431 --- /dev/null +++ b/virtual/types/http.go @@ -0,0 +1,39 @@ +package types + +type RegisterModuleHttpRequest struct { + Namespace string `json:"namespace"` + ModuleID string `json:"module_id"` + ModuleBytes []byte `json:"module_bytes"` +} + +type InvokeActorHttpRequest struct { + ServerID string `json:"server_id"` + Namespace string `json:"namespace"` + InvokeActorRequest + // Same data as Payload (in types.InvokeActorRequest), but different field so it doesn't + // have to be encoded as base64. + PayloadJSON interface{} `json:"payload_json"` +} + +type InvokeActorDirectHttpRequest struct { + VersionStamp int64 `json:"version_stamp"` + ServerID string `json:"server_id"` + ServerVersion int64 `json:"server_version"` + Namespace string `json:"namespace"` + ModuleID string `json:"module_id"` + ActorID string `json:"actor_id"` + Generation uint64 `json:"generation"` + Operation string `json:"operation"` + Payload []byte `json:"payload"` + CreateIfNotExist CreateIfNotExist `json:"create_if_not_exist"` +} + +type InvokeWorkerHttpRequest struct { + Namespace string `json:"namespace"` + // TODO: Allow ModuleID to be omitted if the caller provides a WASMExecutable field which contains the + // actual WASM program that should be executed. + ModuleID string `json:"module_id"` + Operation string `json:"operation"` + Payload []byte `json:"payload"` + CreateIfNotExist CreateIfNotExist `json:"create_if_not_exist"` +} diff --git a/virtual/websockets.go b/virtual/websockets.go new file mode 100644 index 0000000..586af3a --- /dev/null +++ b/virtual/websockets.go @@ -0,0 +1,167 @@ +package virtual + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/richardartoul/nola/virtual/registry" + "github.com/richardartoul/nola/virtual/types" + "golang.org/x/exp/slog" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wsjson" +) + +var ErrUnknownMethod = errors.New("unknown method") + +func (s *server) wsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + c, err := websocket.Accept(w, r, nil) + if err != nil { + s.logger.Warn("failed to accept websocket connection", slog.Any("error", err)) + return + } + + var result any + + for { + var request jsonRpcRequest + err = wsjson.Read(ctx, c, &request) + if err != nil { + s.logger.Warn("failed to read websocket request", slog.Any("error", err)) + return + } + + switch request.Method { + case "register_module": + result, err = s.handleWsRegisterModule(ctx, request) + case "invoke": + result, err = s.handleWsInvoke(ctx, request) + case "invoke_direct": + result, err = s.handleWsInvokeDirect(ctx, request) + case "invoke_worker": + result, err = s.handleWsInvokeWorker(ctx, request) + default: + err = fmt.Errorf("%w: %s", ErrUnknownMethod, r.Method) + } + + response := jsonRpcResponse{VersionTag: request.VersionTag, ID: request.ID} + if err != nil { + response.Error.Code = websocket.StatusInternalError + response.Error.Message = err.Error() + } else { + response.Result = result + } + + if err := wsjson.Write(ctx, c, response); err != nil { + s.logger.Warn("failed to write websocket respose", slog.Any("error", err)) + return + } + } + +} + +func (s *server) handleWsRegisterModule(ctx context.Context, request jsonRpcRequest) (registry.RegisterModuleResult, error) { + var ( + params []types.RegisterModuleHttpRequest + msg types.RegisterModuleHttpRequest + ) + + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + return registry.RegisterModuleResult{}, err + } + + if n := len(params); n != 1 { + return registry.RegisterModuleResult{}, fmt.Errorf("invalid number of params: expected 1 - received: %d", n) + } + + return s.handleRegisterModule(ctx, msg) + +} + +func (s *server) handleWsInvoke(ctx context.Context, request jsonRpcRequest) ([]byte, error) { + var ( + params []types.InvokeActorHttpRequest + msg types.InvokeActorHttpRequest + ) + + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + return nil, err + } + + if n := len(params); n != 1 { + return nil, fmt.Errorf("invalid number of params: expected 1 - received: %d", n) + } + + result, err := s.handleInvoke(ctx, msg) + if err != nil { + return nil, err + } + + return io.ReadAll(result) +} + +func (s *server) handleWsInvokeDirect(ctx context.Context, request jsonRpcRequest) ([]byte, error) { + var ( + params []types.InvokeActorDirectHttpRequest + msg types.InvokeActorDirectHttpRequest + ) + + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + return nil, err + } + + if n := len(params); n != 1 { + return nil, fmt.Errorf("invalid number of params: expected 1 - received: %d", n) + } + + result, err := s.handleInvokeDirect(ctx, msg) + if err != nil { + return nil, err + } + return io.ReadAll(result) +} + +func (s *server) handleWsInvokeWorker(ctx context.Context, request jsonRpcRequest) ([]byte, error) { + var ( + params []types.InvokeWorkerHttpRequest + msg types.InvokeWorkerHttpRequest + ) + + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + return nil, err + } + + if n := len(params); n != 1 { + return nil, fmt.Errorf("invalid number of params: expected 1 - received: %d", n) + } + + result, err := s.handleInvokeWorker(ctx, msg) + if err != nil { + return nil, err + } + return io.ReadAll(result) +} + +type jsonRpcResponse struct { + VersionTag string `json:"jsonrpc"` + Result any `json:"result"` + Error *rpcError `json:"error"` + ID uint64 `json:"id"` +} + +type rpcError struct { + Code websocket.StatusCode `json:"code"` + Message string `json:"message"` +} + +type jsonRpcRequest struct { + VersionTag string `json:"jsonrpc"` + ID uint64 `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` +}