diff --git a/README.md b/README.md index 411c2138..12bfc3ed 100644 --- a/README.md +++ b/README.md @@ -52,14 +52,12 @@ TinyGo's official release of WASI target will come soon, and after that you coul just follow https://tinygo.org/getting-started/ to install the requirement on any platform. Stay tuned! -### compatible Envoy builds - -| proxy-wasm-go-sdk| proxy-wasm ABI version | envoyproxy/envoy-wasm| istio/proxyv2| -|:-------------:|:-------------:|:-------------:|:-------------:| -| main | 0.2.0| N/A | v1.17.x | -| v0.0.4 | 0.2.0| N/A | v1.17.x | -| v0.0.3 | 0.2.0| N/A | v1.17.x | -| v0.0.2 | 0.1.0|release/v1.15 | N/A | +### compatible ABI / Envoy builds + +| proxy-wasm-go-sdk| proxy-wasm ABI version |istio/proxyv2| +|:-------------:|:-------------:|:-------------:| +| main | 0.2.0| v1.17.x | +| v0.0.4 | 0.2.0| v1.17.x | ## run examples diff --git a/examples/helloworld/main.go b/examples/helloworld/main.go index 26f3b6f8..bda93005 100644 --- a/examples/helloworld/main.go +++ b/examples/helloworld/main.go @@ -26,7 +26,7 @@ func main() { type helloWorld struct { // you must embed the default context so that you need not to reimplement all the methods by yourself - proxywasm.DefaultContext + proxywasm.DefaultRootContext contextID uint32 } @@ -37,7 +37,7 @@ func newHelloWorld(contextID uint32) proxywasm.RootContext { // override func (ctx *helloWorld) OnVMStart(int) bool { proxywasm.LogInfo("proxy_on_vm_start from Go!") - if err := proxywasm.HostCallSetTickPeriodMilliSeconds(tickMilliseconds); err != nil { + if err := proxywasm.SetTickPeriodMilliSeconds(tickMilliseconds); err != nil { proxywasm.LogCriticalf("failed to set tick period: %v", err) } return true @@ -45,6 +45,6 @@ func (ctx *helloWorld) OnVMStart(int) bool { // override func (ctx *helloWorld) OnTick() { - t := proxywasm.HostCallGetCurrentTime() + t := proxywasm.GetCurrentTime() proxywasm.LogInfof("OnTick on %d, it's %d", ctx.contextID, t) } diff --git a/examples/helloworld/main_test.go b/examples/helloworld/main_test.go index 06432e02..f86ab389 100644 --- a/examples/helloworld/main_test.go +++ b/examples/helloworld/main_test.go @@ -3,6 +3,9 @@ package main import ( "strings" "testing" + "time" + + "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" @@ -11,23 +14,29 @@ import ( ) func TestHelloWorld_OnTick(t *testing.T) { - ctx := newHelloWorld(100) - host, done := proxytest.NewRootFilterHost(ctx, nil, nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation - ctx.OnTick() + opt := proxytest.NewEmulatorOption(). + WithNewRootContext(newHelloWorld) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation + + host.StartVM() // call OnVMStart + + time.Sleep(time.Duration(tickMilliseconds) * 4 * time.Millisecond) logs := host.GetLogs(types.LogLevelInfo) + require.Greater(t, len(logs), 0) msg := logs[len(logs)-1] assert.True(t, strings.Contains(msg, "OnTick on")) } func TestHelloWorld_OnVMStart(t *testing.T) { - ctx := newHelloWorld(0) - host, done := proxytest.NewRootFilterHost(ctx, nil, nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewRootContext(newHelloWorld) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - host.StartVM() + host.StartVM() // call OnVMStart logs := host.GetLogs(types.LogLevelInfo) msg := logs[len(logs)-1] diff --git a/examples/http_auth_random/main.go b/examples/http_auth_random/main.go index 7af6da6a..e2c9fce7 100644 --- a/examples/http_auth_random/main.go +++ b/examples/http_auth_random/main.go @@ -29,7 +29,7 @@ func main() { type httpAuthRandom struct { // you must embed the default context so that you need not to reimplement all the methods by yourself - proxywasm.DefaultContext + proxywasm.DefaultHttpContext contextID uint32 } @@ -39,7 +39,7 @@ func newContext(contextID uint32) proxywasm.HttpContext { // override default func (ctx *httpAuthRandom) OnHttpRequestHeaders(int, bool) types.Action { - hs, err := proxywasm.HostCallGetHttpRequestHeaders() + hs, err := proxywasm.GetHttpRequestHeaders() if err != nil { proxywasm.LogCriticalf("failed to get request headers: %v", err) return types.ActionContinue @@ -48,19 +48,18 @@ func (ctx *httpAuthRandom) OnHttpRequestHeaders(int, bool) types.Action { proxywasm.LogInfof("request header: %s: %s", h[0], h[1]) } - if _, err := proxywasm.HostCallDispatchHttpCall( - clusterName, hs, "", [][2]string{}, 50000); err != nil { + if _, err := proxywasm.DispatchHttpCall(clusterName, hs, "", [][2]string{}, + 50000, httpCallResponseCallback); err != nil { proxywasm.LogCriticalf("dipatch httpcall failed: %v", err) + return types.ActionContinue } proxywasm.LogInfof("http call dispatched to %s", clusterName) - return types.ActionPause } -// override default -func (ctx *httpAuthRandom) OnHttpCallResponse(_ int, bodySize int, _ int) { - hs, err := proxywasm.HostCallGetHttpCallResponseHeaders() +func httpCallResponseCallback(_ int, bodySize int, _ int) { + hs, err := proxywasm.GetHttpCallResponseHeaders() if err != nil { proxywasm.LogCriticalf("failed to get response body: %v", err) @@ -71,34 +70,29 @@ func (ctx *httpAuthRandom) OnHttpCallResponse(_ int, bodySize int, _ int) { proxywasm.LogInfof("response header from %s: %s: %s", clusterName, h[0], h[1]) } - b, err := proxywasm.HostCallGetHttpCallResponseBody(0, bodySize) + b, err := proxywasm.GetHttpCallResponseBody(0, bodySize) if err != nil { proxywasm.LogCriticalf("failed to get response body: %v", err) - proxywasm.HostCallResumeHttpRequest() + proxywasm.ResumeHttpRequest() return } s := fnv.New32a() if _, err := s.Write(b); err != nil { proxywasm.LogCriticalf("failed to calculate hash: %v", err) - proxywasm.HostCallResumeHttpRequest() + proxywasm.ResumeHttpRequest() return } if s.Sum32()%2 == 0 { proxywasm.LogInfo("access granted") - proxywasm.HostCallResumeHttpRequest() + proxywasm.ResumeHttpRequest() return } msg := "access forbidden" proxywasm.LogInfo(msg) - proxywasm.HostCallSendHttpResponse(403, [][2]string{ + proxywasm.SendHttpResponse(403, [][2]string{ {"powered-by", "proxy-wasm-go-sdk!!"}, }, msg) } - -// override default -func (ctx *httpAuthRandom) OnLog() { - proxywasm.LogInfof("%d finished", ctx.contextID) -} diff --git a/examples/http_auth_random/main_test.go b/examples/http_auth_random/main_test.go index 3a8b8563..5e821cb4 100644 --- a/examples/http_auth_random/main_test.go +++ b/examples/http_auth_random/main_test.go @@ -11,14 +11,20 @@ import ( ) func TestHttpAuthRandom_OnHttpRequestHeaders(t *testing.T) { - host, done := proxytest.NewHttpFilterHost(newContext) - defer done() + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() - id := host.InitContext() - host.PutRequestHeaders(id, [][2]string{{"key", "value"}}) // OnHttpRequestHeaders called + contextID := host.HttpFilterInitContext() + host.HttpFilterPutRequestHeaders(contextID, [][2]string{{"key", "value"}}) // OnHttpRequestHeaders called - require.True(t, host.IsDispatchCalled(id)) // check if http call is dispatched - require.Equal(t, types.ActionPause, host.GetCurrentAction(id)) // check if the current action is pause + attrs := host.GetCalloutAttributesFromContext(contextID) + require.Equal(t, len(attrs), 1) // verify DispatchHttpCall is called + + require.Equal(t, "httpbin", attrs[0].Upstream) + require.Equal(t, types.ActionPause, + host.HttpFilterGetCurrentStreamAction(contextID)) // check if the current action is pause logs := host.GetLogs(types.LogLevelInfo) require.GreaterOrEqual(t, len(logs), 2) @@ -28,8 +34,10 @@ func TestHttpAuthRandom_OnHttpRequestHeaders(t *testing.T) { } func TestHttpAuthRandom_OnHttpCallResponse(t *testing.T) { - host, done := proxytest.NewHttpFilterHost(newContext) - defer done() + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // http://httpbin.org/uuid headers := [][2]string{ @@ -40,20 +48,30 @@ func TestHttpAuthRandom_OnHttpCallResponse(t *testing.T) { } // access granted body - id := host.InitContext() + contextID := host.HttpFilterInitContext() + host.HttpFilterPutRequestHeaders(contextID, nil) // OnHttpRequestHeaders called + body := []byte(`{"uuid": "7b10a67a-1c67-4199-835b-cbefcd4a63d4"}`) - host.PutCalloutResponse(id, headers, nil, body) - assert.Nil(t, host.GetSentLocalResponse(id)) + attrs := host.GetCalloutAttributesFromContext(contextID) + require.Equal(t, len(attrs), 1) // verify DispatchHttpCall is called + + host.PutCalloutResponse(attrs[0].CalloutID, headers, nil, body) + assert.Nil(t, host.HttpFilterGetSentLocalResponse(contextID)) logs := host.GetLogs(types.LogLevelInfo) require.Greater(t, len(logs), 1) assert.Equal(t, "access granted", logs[len(logs)-1]) // access denied body - id = host.InitContext() + contextID = host.HttpFilterInitContext() + host.HttpFilterPutRequestHeaders(contextID, nil) // OnHttpRequestHeaders called + body = []byte(`{"uuid": "aaaaaaaa-1c67-4199-835b-cbefcd4a63d4"}`) - host.PutCalloutResponse(id, headers, nil, body) - localResponse := host.GetSentLocalResponse(id) // check local responses + attrs = host.GetCalloutAttributesFromContext(contextID) + require.Equal(t, len(attrs), 1) // verify DispatchHttpCall is called + + host.PutCalloutResponse(attrs[0].CalloutID, headers, nil, body) + localResponse := host.HttpFilterGetSentLocalResponse(contextID) // check local responses assert.NotNil(t, localResponse) logs = host.GetLogs(types.LogLevelInfo) assert.Equal(t, "access forbidden", logs[len(logs)-1]) diff --git a/examples/http_headers/main.go b/examples/http_headers/main.go index 394ec31e..77fd8025 100644 --- a/examples/http_headers/main.go +++ b/examples/http_headers/main.go @@ -25,7 +25,7 @@ func main() { type httpHeaders struct { // you must embed the default context so that you need not to reimplement all the methods by yourself - proxywasm.DefaultContext + proxywasm.DefaultHttpContext contextID uint32 } @@ -35,7 +35,7 @@ func newContext(contextID uint32) proxywasm.HttpContext { // override func (ctx *httpHeaders) OnHttpRequestHeaders(int, bool) types.Action { - hs, err := proxywasm.HostCallGetHttpRequestHeaders() + hs, err := proxywasm.GetHttpRequestHeaders() if err != nil { proxywasm.LogCriticalf("failed to get request headers: %v", err) } @@ -48,7 +48,7 @@ func (ctx *httpHeaders) OnHttpRequestHeaders(int, bool) types.Action { // override func (ctx *httpHeaders) OnHttpResponseHeaders(int, bool) types.Action { - hs, err := proxywasm.HostCallGetHttpResponseHeaders() + hs, err := proxywasm.GetHttpResponseHeaders() if err != nil { proxywasm.LogCriticalf("failed to get request headers: %v", err) } @@ -60,6 +60,6 @@ func (ctx *httpHeaders) OnHttpResponseHeaders(int, bool) types.Action { } // override -func (ctx *httpHeaders) OnLog() { +func (ctx *httpHeaders) OnHttpStreamDone() { proxywasm.LogInfof("%d finished", ctx.contextID) } diff --git a/examples/http_headers/main_test.go b/examples/http_headers/main_test.go index c67f9f93..7143cdcc 100644 --- a/examples/http_headers/main_test.go +++ b/examples/http_headers/main_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -11,31 +12,40 @@ import ( ) func TestHttpHeaders_OnHttpRequestHeaders(t *testing.T) { - host, done := proxytest.NewHttpFilterHost(newContext) - defer done() - id := host.InitContext() + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() + id := host.HttpFilterInitContext() hs := [][2]string{{"key1", "value1"}, {"key2", "value2"}} - host.PutRequestHeaders(id, hs) // call OnHttpRequestHeaders + host.HttpFilterPutRequestHeaders(id, hs) // call OnHttpRequestHeaders + + host.HttpFilterCompleteHttpStream(id) logs := host.GetLogs(types.LogLevelInfo) require.Greater(t, len(logs), 1) - assert.Equal(t, "request header --> key2: value2", logs[len(logs)-1]) - assert.Equal(t, "request header --> key1: value1", logs[len(logs)-2]) + assert.Equal(t, fmt.Sprintf("%d finished", id), logs[len(logs)-1]) + assert.Equal(t, "request header --> key2: value2", logs[len(logs)-2]) + assert.Equal(t, "request header --> key1: value1", logs[len(logs)-3]) } func TestHttpHeaders_OnHttpResponseHeaders(t *testing.T) { - host, done := proxytest.NewHttpFilterHost(newContext) - defer done() - id := host.InitContext() + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() + id := host.HttpFilterInitContext() hs := [][2]string{{"key1", "value1"}, {"key2", "value2"}} - host.PutResponseHeaders(id, hs) // call OnHttpResponseHeaders + host.HttpFilterPutResponseHeaders(id, hs) // call OnHttpRequestHeaders + host.HttpFilterCompleteHttpStream(id) // call OnHttpStreamDone logs := host.GetLogs(types.LogLevelInfo) require.Greater(t, len(logs), 1) - assert.Equal(t, "response header <-- key2: value2", logs[len(logs)-1]) - assert.Equal(t, "response header <-- key1: value1", logs[len(logs)-2]) + assert.Equal(t, fmt.Sprintf("%d finished", id), logs[len(logs)-1]) + assert.Equal(t, "response header <-- key2: value2", logs[len(logs)-2]) + assert.Equal(t, "response header <-- key1: value1", logs[len(logs)-3]) } diff --git a/examples/metrics/main.go b/examples/metrics/main.go index 571c4095..ce3769fc 100644 --- a/examples/metrics/main.go +++ b/examples/metrics/main.go @@ -20,38 +20,44 @@ import ( ) func main() { - proxywasm.SetNewRootContext(func(uint32) proxywasm.RootContext { return metric{} }) - proxywasm.SetNewHttpContext(func(uint32) proxywasm.HttpContext { return metric{} }) + proxywasm.SetNewRootContext(newRootContext) + proxywasm.SetNewHttpContext(newHttpContext) } var counter proxywasm.MetricCounter const metricsName = "proxy_wasm_go.request_counter" -type metric struct{ proxywasm.DefaultContext } +type metricRootContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultRootContext +} + +func newRootContext(uint32) proxywasm.RootContext { + return &metricRootContext{} +} // override -func (ctx metric) OnVMStart(int) bool { - ct, err := proxywasm.DefineCounterMetric(metricsName) - if err != nil { - proxywasm.LogCriticalf("error defining metrics: %v", err) - } - counter = ct +func (ctx *metricRootContext) OnVMStart(int) bool { + counter = proxywasm.DefineCounterMetric(metricsName) return true } -// override -func (ctx metric) OnHttpRequestHeaders(int, bool) types.Action { - prev, err := counter.Get() - if err != nil { - proxywasm.LogCriticalf("error retrieving previous metric: %v", err) - } +type metricHttpContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultHttpContext +} + +func newHttpContext(uint32) proxywasm.HttpContext { + return &metricHttpContext{} +} +// override +func (ctx *metricHttpContext) OnHttpRequestHeaders(int, bool) types.Action { + prev := counter.Get() proxywasm.LogInfof("previous value of %s: %d", metricsName, prev) - if err := counter.Increment(1); err != nil { - proxywasm.LogCriticalf("error incrementing metrics %v", err) - } + counter.Increment(1) proxywasm.LogInfo("incremented") return types.ActionContinue } diff --git a/examples/metrics/main_test.go b/examples/metrics/main_test.go index 222e77a2..2a62861d 100644 --- a/examples/metrics/main_test.go +++ b/examples/metrics/main_test.go @@ -11,16 +11,18 @@ import ( ) func TestMetric(t *testing.T) { + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newHttpContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - ctx := metric{} - host, done := proxytest.NewRootFilterHost(ctx, nil, nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation - - host.StartVM() // define metric + host.StartVM() // call OnVMStart: define metric + contextID := host.HttpFilterInitContext() exp := uint64(3) for i := uint64(0); i < exp; i++ { - ctx.OnHttpRequestHeaders(0, false) // increment + host.HttpFilterPutRequestHeaders(contextID, nil) } logs := host.GetLogs(types.LogLevelInfo) @@ -28,7 +30,6 @@ func TestMetric(t *testing.T) { assert.Equal(t, "incremented", logs[len(logs)-1]) - value, err := counter.Get() - require.NoError(t, err) + value := counter.Get() assert.Equal(t, uint64(3), value) } diff --git a/examples/network/main.go b/examples/network/main.go index e052c974..fd8cc131 100644 --- a/examples/network/main.go +++ b/examples/network/main.go @@ -25,58 +25,72 @@ var ( ) func main() { - proxywasm.SetNewStreamContext(func(contextID uint32) proxywasm.StreamContext { return context{} }) - proxywasm.SetNewRootContext(func(contextID uint32) proxywasm.RootContext { return context{} }) + proxywasm.SetNewRootContext(newRootContext) + proxywasm.SetNewStreamContext(newNetworkContext) } -type context struct{ proxywasm.DefaultContext } +type rootContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultRootContext +} -func (ctx context) OnVMStart(int) bool { - var err error - counter, err = proxywasm.DefineCounterMetric(connectionCounterName) - if err != nil { - proxywasm.LogCriticalf("failed to initialize connection counter: %v", err) - } +func newRootContext(uint32) proxywasm.RootContext { + return &rootContext{} +} + +func (ctx *rootContext) OnVMStart(int) bool { + counter = proxywasm.DefineCounterMetric(connectionCounterName) return true } -func (ctx context) OnNewConnection() types.Action { +type networkContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultStreamContext +} + +func newNetworkContext(uint32) proxywasm.StreamContext { + return &networkContext{} +} + +func (ctx *networkContext) OnNewConnection() types.Action { proxywasm.LogInfo("new connection!") return types.ActionContinue } -func (ctx context) OnDownstreamData(dataSize int, _ bool) types.Action { +func (ctx *networkContext) OnDownstreamData(dataSize int, _ bool) types.Action { if dataSize == 0 { return types.ActionContinue } - data, err := proxywasm.HostCallGetDownStreamData(0, dataSize) + data, err := proxywasm.GetDownStreamData(0, dataSize) if err != nil && err != types.ErrorStatusNotFound { - proxywasm.LogCritical(err.Error()) + proxywasm.LogCriticalf("failed to get downstream data: %v", err) + return types.ActionContinue } proxywasm.LogInfof(">>>>>> downstream data received >>>>>>\n%s", string(data)) return types.ActionContinue } -func (ctx context) OnDownstreamClose(types.PeerType) { +func (ctx *networkContext) OnDownstreamClose(types.PeerType) { proxywasm.LogInfo("downstream connection close!") return } -func (ctx context) OnUpstreamData(dataSize int, _ bool) types.Action { +func (ctx *networkContext) OnUpstreamData(dataSize int, _ bool) types.Action { if dataSize == 0 { return types.ActionContinue } - ret, err := proxywasm.HostCallGetProperty([]string{"upstream", "address"}) + ret, err := proxywasm.GetProperty([]string{"upstream", "address"}) if err != nil { - proxywasm.LogCritical(err.Error()) + proxywasm.LogCriticalf("failed to get downstream data: %v", err) + return types.ActionContinue } proxywasm.LogInfof("remote address: %s", string(ret)) - data, err := proxywasm.HostCallGetUpstreamData(0, dataSize) + data, err := proxywasm.GetUpstreamData(0, dataSize) if err != nil && err != types.ErrorStatusNotFound { proxywasm.LogCritical(err.Error()) } @@ -85,11 +99,7 @@ func (ctx context) OnUpstreamData(dataSize int, _ bool) types.Action { return types.ActionContinue } -func (ctx context) OnDone() bool { - err := counter.Increment(1) - if err != nil { - proxywasm.LogCriticalf("failed to increment connection counter: %v", err) - } +func (ctx *networkContext) OnStreamDone() { + counter.Increment(1) proxywasm.LogInfo("connection complete!") - return true } diff --git a/examples/network/main_test.go b/examples/network/main_test.go index 5a8a0e55..00226b94 100644 --- a/examples/network/main_test.go +++ b/examples/network/main_test.go @@ -21,30 +21,33 @@ import ( "github.com/stretchr/testify/require" "github.com/tetratelabs/proxy-wasm-go-sdk/proxytest" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -func newStreamContext(uint32) proxywasm.StreamContext { - return context{} -} - func TestNetwork_OnNewConnection(t *testing.T) { - host, done := proxytest.NewNetworkFilterHost(newStreamContext) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewStreamContext(newNetworkContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation + + host.StartVM() // call OnVMStart: init metric - _ = host.InitConnection() // OnNewConnection is called + _ = host.NetworkFilterInitConnection() // OnNewConnection is called logs := host.GetLogs(types.LogLevelInfo) // retrieve logs emitted to Envoy assert.Equal(t, logs[0], "new connection!") } func TestNetwork_OnDownstreamClose(t *testing.T) { - host, done := proxytest.NewNetworkFilterHost(newStreamContext) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewStreamContext(newNetworkContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - contextID := host.InitConnection() // OnNewConnection is called - host.CloseDownstreamConnection(contextID) // OnDownstreamClose is called + contextID := host.NetworkFilterInitConnection() // OnNewConnection is called + host.NetworkFilterCloseDownstreamConnection(contextID) // OnDownstreamClose is called logs := host.GetLogs(types.LogLevelInfo) // retrieve logs emitted to Envoy require.Len(t, logs, 2) @@ -52,47 +55,55 @@ func TestNetwork_OnDownstreamClose(t *testing.T) { } func TestNetwork_OnDownstreamData(t *testing.T) { - host, done := proxytest.NewNetworkFilterHost(newStreamContext) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewStreamContext(newNetworkContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - contextID := host.InitConnection() // OnNewConnection is called + contextID := host.NetworkFilterInitConnection() // OnNewConnection is called msg := "this is downstream data" data := []byte(msg) - host.PutDownstreamData(contextID, data) // OnDownstreamData is called + host.NetworkFilterPutDownstreamData(contextID, data) // OnDownstreamData is called logs := host.GetLogs(types.LogLevelInfo) // retrieve logs emitted to Envoy assert.Equal(t, ">>>>>> downstream data received >>>>>>\n"+msg, logs[len(logs)-1]) } func TestNetwork_OnUpstreamData(t *testing.T) { - host, done := proxytest.NewNetworkFilterHost(newStreamContext) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewStreamContext(newNetworkContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - contextID := host.InitConnection() // OnNewConnection is called + contextID := host.NetworkFilterInitConnection() // OnNewConnection is called msg := "this is upstream data" data := []byte(msg) - host.PutUpstreamData(contextID, data) // OnUpstreamData is called + host.NetworkFilterPutUpstreamData(contextID, data) // OnUpstreamData is called logs := host.GetLogs(types.LogLevelInfo) // retrieve logs emitted to Envoy assert.Equal(t, "<<<<<< upstream data received <<<<<<\n"+msg, logs[len(logs)-1]) } func TestNetwork_counter(t *testing.T) { - host, done := proxytest.NewNetworkFilterHost(newStreamContext) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewStreamContext(newNetworkContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - context{}.OnVMStart(0) // init metric + host.StartVM() // call OnVMStart: init metric - contextID := host.InitConnection() - host.CompleteConnection(contextID) // call OnDone on contextID -> increment the connection counter + contextID := host.NetworkFilterInitConnection() + host.NetworkFilterCompleteConnection(contextID) // call OnStreamDone on contextID -> increment the connection counter logs := host.GetLogs(types.LogLevelInfo) require.Greater(t, len(logs), 0) assert.Equal(t, "connection complete!", logs[len(logs)-1]) - actual, err := counter.Get() - require.NoError(t, err) + actual := counter.Get() assert.Equal(t, uint64(1), actual) } diff --git a/examples/shared_data/main.go b/examples/shared_data/main.go index 63a98641..f4116f81 100644 --- a/examples/shared_data/main.go +++ b/examples/shared_data/main.go @@ -20,32 +20,50 @@ import ( ) func main() { - proxywasm.SetNewRootContext(func(uint32) proxywasm.RootContext { return data{} }) - proxywasm.SetNewHttpContext(func(uint32) proxywasm.HttpContext { return data{} }) + proxywasm.SetNewRootContext(newRootContext) + proxywasm.SetNewHttpContext(newHttpContext) } -type data struct{ proxywasm.DefaultContext } +type ( + sharedDataRootContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultRootContext + } + + sharedDataHttpContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultHttpContext + } +) + +func newRootContext(uint32) proxywasm.RootContext { + return &sharedDataRootContext{} +} + +func newHttpContext(uint32) proxywasm.HttpContext { + return &sharedDataHttpContext{} +} const sharedDataKey = "shared_data_key" // override -func (ctx data) OnVMStart(int) bool { - if err := proxywasm.HostCallSetSharedData(sharedDataKey, []byte{0}, 0); err != nil { +func (ctx *sharedDataRootContext) OnVMStart(int) bool { + if err := proxywasm.SetSharedData(sharedDataKey, []byte{0}, 0); err != nil { proxywasm.LogWarnf("error setting shared data on OnVMStart: %v", err) } return true } // override -func (ctx data) OnHttpRequestHeaders(int, bool) types.Action { - value, cas, err := proxywasm.HostCallGetSharedData(sharedDataKey) +func (ctx *sharedDataHttpContext) OnHttpRequestHeaders(int, bool) types.Action { + value, cas, err := proxywasm.GetSharedData(sharedDataKey) if err != nil { proxywasm.LogWarnf("error getting shared data on OnHttpRequestHeaders: %v", err) return types.ActionContinue } value[0]++ - if err := proxywasm.HostCallSetSharedData(sharedDataKey, value, cas); err != nil { + if err := proxywasm.SetSharedData(sharedDataKey, value, cas); err != nil { proxywasm.LogWarnf("error setting shared data on OnHttpRequestHeaders: %v", err) return types.ActionContinue } diff --git a/examples/shared_data/main_test.go b/examples/shared_data/main_test.go index 5113057b..ac76a8d8 100644 --- a/examples/shared_data/main_test.go +++ b/examples/shared_data/main_test.go @@ -25,20 +25,22 @@ import ( ) func TestData(t *testing.T) { - ctx := data{} - host, done := proxytest.NewRootFilterHost(ctx, nil, nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newHttpContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation host.StartVM() // set initial value + contextID := host.HttpFilterInitContext() + host.HttpFilterPutRequestHeaders(contextID, nil) // OnHttpRequestHeaders is called - ctx.OnHttpRequestHeaders(0, false) // update logs := host.GetLogs(types.LogLevelInfo) require.Greater(t, len(logs), 0) assert.Equal(t, "shared value: 1", logs[len(logs)-1]) - - ctx.OnHttpRequestHeaders(0, false) // update - ctx.OnHttpRequestHeaders(0, false) // update + host.HttpFilterPutRequestHeaders(contextID, nil) // OnHttpRequestHeaders is called + host.HttpFilterPutRequestHeaders(contextID, nil) // OnHttpRequestHeaders is called logs = host.GetLogs(types.LogLevelInfo) assert.Equal(t, "shared value: 3", logs[len(logs)-1]) diff --git a/examples/shared_queue/main.go b/examples/shared_queue/main.go index 98e1f00b..a3d2ade7 100644 --- a/examples/shared_queue/main.go +++ b/examples/shared_queue/main.go @@ -19,30 +19,37 @@ import ( "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -func main() { - proxywasm.SetNewRootContext(func(uint32) proxywasm.RootContext { return queue{} }) - proxywasm.SetNewHttpContext(func(uint32) proxywasm.HttpContext { return queue{} }) -} - -type queue struct{ proxywasm.DefaultContext } - const ( queueName = "proxy_wasm_go.queue" tickMilliseconds uint32 = 100 ) +func main() { + proxywasm.SetNewRootContext(newRootContext) + proxywasm.SetNewHttpContext(newHttpContext) +} + +type queueRootContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultRootContext +} + +func newRootContext(uint32) proxywasm.RootContext { + return &queueRootContext{} +} + var queueID uint32 // override -func (ctx queue) OnVMStart(int) bool { - qID, err := proxywasm.HostCallRegisterSharedQueue(queueName) +func (ctx *queueRootContext) OnVMStart(int) bool { + qID, err := proxywasm.RegisterSharedQueue(queueName) if err != nil { panic(err.Error()) } queueID = qID proxywasm.LogInfof("queue registered, name: %s, id: %d", queueName, qID) - if err := proxywasm.HostCallSetTickPeriodMilliSeconds(tickMilliseconds); err != nil { + if err := proxywasm.SetTickPeriodMilliSeconds(tickMilliseconds); err != nil { proxywasm.LogCriticalf("failed to set tick period: %v", err) } proxywasm.LogInfof("set tick period milliseconds: %d", tickMilliseconds) @@ -50,18 +57,8 @@ func (ctx queue) OnVMStart(int) bool { } // override -func (ctx queue) OnHttpRequestHeaders(int, bool) types.Action { - for _, msg := range []string{"hello", "world", "hello", "proxy-wasm"} { - if err := proxywasm.HostCallEnqueueSharedQueue(queueID, []byte(msg)); err != nil { - proxywasm.LogCriticalf("error queueing: %v", err) - } - } - return types.ActionContinue -} - -// override -func (ctx queue) OnTick() { - data, err := proxywasm.HostCallDequeueSharedQueue(queueID) +func (ctx *queueRootContext) OnTick() { + data, err := proxywasm.DequeueSharedQueue(queueID) switch err { case types.ErrorStatusEmpty: return @@ -71,3 +68,22 @@ func (ctx queue) OnTick() { proxywasm.LogCriticalf("error retrieving data from queue %d: %v", queueID, err) } } + +type queueHttpContext struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultHttpContext +} + +func newHttpContext(uint32) proxywasm.HttpContext { + return &queueHttpContext{} +} + +// override +func (ctx *queueHttpContext) OnHttpRequestHeaders(int, bool) types.Action { + for _, msg := range []string{"hello", "world", "hello", "proxy-wasm"} { + if err := proxywasm.EnqueueSharedQueue(queueID, []byte(msg)); err != nil { + proxywasm.LogCriticalf("error queueing: %v", err) + } + } + return types.ActionContinue +} diff --git a/examples/shared_queue/main_test.go b/examples/shared_queue/main_test.go index 2a77f69c..34fe3c62 100644 --- a/examples/shared_queue/main_test.go +++ b/examples/shared_queue/main_test.go @@ -17,6 +17,7 @@ package main import ( "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -26,24 +27,27 @@ import ( ) func TestQueue(t *testing.T) { - ctx := queue{} - host, done := proxytest.NewRootFilterHost(ctx, nil, nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithNewHttpContext(newHttpContext). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation host.StartVM() // register the queue,set tick period logs := host.GetLogs(types.LogLevelInfo) + require.Greater(t, len(logs), 0) assert.Equal(t, logs[0], fmt.Sprintf("queue registered, name: %s, id: %d", queueName, queueID)) assert.Equal(t, tickMilliseconds, host.GetTickPeriod()) - ctx.OnHttpRequestHeaders(0, false) // call enqueue + contextID := host.HttpFilterInitContext() + host.HttpFilterPutRequestHeaders(contextID, nil) // call enqueue assert.Equal(t, 4, host.GetQueueSize(queueID)) - for i := 0; i < 4; i++ { - ctx.OnTick() // dequeue - } + time.Sleep(time.Duration(tickMilliseconds*5) * time.Millisecond) + logs = host.GetLogs(types.LogLevelInfo) - require.Greater(t, len(logs), 4) + require.Greater(t, len(logs), 5) assert.Equal(t, "dequeued data: hello", logs[len(logs)-4]) assert.Equal(t, "dequeued data: world", logs[len(logs)-3]) diff --git a/examples/vm_plugin_configuration/main.go b/examples/vm_plugin_configuration/main.go index a99db3da..63941f66 100644 --- a/examples/vm_plugin_configuration/main.go +++ b/examples/vm_plugin_configuration/main.go @@ -19,14 +19,21 @@ import ( ) func main() { - proxywasm.SetNewRootContext(func(uint32) proxywasm.RootContext { return context{} }) + proxywasm.SetNewRootContext(newRootContext) } -type context struct{ proxywasm.DefaultContext } +type context struct { + // you must embed the default context so that you need not to reimplement all the methods by yourself + proxywasm.DefaultRootContext +} + +func newRootContext(contextID uint32) proxywasm.RootContext { + return &context{} +} // override func (ctx context) OnVMStart(vmConfigurationSize int) bool { - data, err := proxywasm.HostCallGetVMConfiguration(vmConfigurationSize) + data, err := proxywasm.GetVMConfiguration(vmConfigurationSize) if err != nil { proxywasm.LogCriticalf("error reading vm configuration: %v", err) } @@ -35,8 +42,8 @@ func (ctx context) OnVMStart(vmConfigurationSize int) bool { return true } -func (ctx context) OnConfigure(pluginConfigurationSize int) bool { - data, err := proxywasm.HostCallGetPluginConfiguration(pluginConfigurationSize) +func (ctx context) OnPluginStart(pluginConfigurationSize int) bool { + data, err := proxywasm.GetPluginConfiguration(pluginConfigurationSize) if err != nil { proxywasm.LogCriticalf("error reading plugin configuration: %v", err) } diff --git a/examples/vm_plugin_configuration/main_test.go b/examples/vm_plugin_configuration/main_test.go index b0263226..4193c001 100644 --- a/examples/vm_plugin_configuration/main_test.go +++ b/examples/vm_plugin_configuration/main_test.go @@ -18,34 +18,43 @@ import ( "strings" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" "github.com/tetratelabs/proxy-wasm-go-sdk/proxytest" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -func TestContext_OnConfigure(t *testing.T) { +func TestContext_OnPluginStart(t *testing.T) { pluginConfigData := `{"name": "tinygo plugin configuration"}` - ctx := context{} - host, done := proxytest.NewRootFilterHost(ctx, []byte(pluginConfigData), nil) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation - host.ConfigurePlugin() // invoke OnConfigure + opt := proxytest.NewEmulatorOption(). + WithPluginConfiguration([]byte(pluginConfigData)). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the emulation lock so that other test cases can insert their own host emulation + + host.StartPlugin() // invoke OnPluginStart logs := host.GetLogs(types.LogLevelInfo) + require.Greater(t, len(logs), 0) msg := logs[len(logs)-1] assert.True(t, strings.Contains(msg, pluginConfigData)) } func TestContext_OnVMStart(t *testing.T) { vmConfigData := `{"name": "tinygo vm configuration"}` - ctx := context{} - host, done := proxytest.NewRootFilterHost(ctx, nil, []byte(vmConfigData)) - defer done() // release the host emulation lock so that other test cases can insert their own host emulation + opt := proxytest.NewEmulatorOption(). + WithVMConfiguration([]byte(vmConfigData)). + WithNewRootContext(newRootContext) + host := proxytest.NewHostEmulator(opt) + defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation - host.StartVM() // invoke OnConfigure + host.StartVM() // invoke OnVMStart logs := host.GetLogs(types.LogLevelInfo) + require.Greater(t, len(logs), 0) msg := logs[len(logs)-1] assert.True(t, strings.Contains(msg, vmConfigData)) } diff --git a/go.sum b/go.sum index 56d62e7c..114c3d0a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/mathetake/proxy-wasm-go v0.0.4 h1:bWU1/hqnUpE7RpvhmmvNSelOcdRKVgawVwqFHA1H61E= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/proxytest/base.go b/proxytest/base.go deleted file mode 100644 index cbde71e7..00000000 --- a/proxytest/base.go +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright 2020 Tetrate -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxytest - -import ( - "log" - "sync" - - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" -) - -var hostMux = sync.Mutex{} - -type baseHost struct { - rawhostcall.DefaultProxyWAMSHost - currentContextID uint32 - - logs [types.LogLevelMax][]string - tickPeriod uint32 - - queues map[uint32][][]byte - queueNameID map[string]uint32 - - sharedDataKVS map[string]*sharedData - - metricIDToValue map[uint32]uint64 - metricIDToType map[uint32]types.MetricType - metricNameToID map[string]uint32 - - calloutCallbackCaller func(contextID uint32, numHeaders, bodySize, numTrailers int) - calloutResponse map[uint32]struct { - headers, trailers [][2]string - body []byte - } - callouts map[uint32]struct{} -} - -type sharedData struct { - data []byte - cas uint32 -} - -func newBaseHost(f func(contextID uint32, numHeaders, bodySize, numTrailers int)) *baseHost { - return &baseHost{ - queues: map[uint32][][]byte{}, - queueNameID: map[string]uint32{}, - sharedDataKVS: map[string]*sharedData{}, - metricIDToValue: map[uint32]uint64{}, - metricIDToType: map[uint32]types.MetricType{}, - metricNameToID: map[string]uint32{}, - calloutCallbackCaller: f, - calloutResponse: map[uint32]struct { - headers, trailers [][2]string - body []byte - }{}, - callouts: map[uint32]struct{}{}, - } -} - -func (b *baseHost) ProxyLog(logLevel types.LogLevel, messageData *byte, messageSize int) types.Status { - str := proxywasm.RawBytePtrToString(messageData, messageSize) - - log.Printf("proxy_log: %s", str) - // TODO: exit if loglevel == fatal? - - b.logs[logLevel] = append(b.logs[logLevel], str) - return types.StatusOK -} - -func (b *baseHost) GetLogs(level types.LogLevel) []string { - if level >= types.LogLevelMax { - log.Fatalf("invalid log level: %d", level) - } - return b.logs[level] -} - -func (b *baseHost) ProxySetTickPeriodMilliseconds(period uint32) types.Status { - b.tickPeriod = period - return types.StatusOK -} - -func (b *baseHost) GetTickPeriod() uint32 { - return b.tickPeriod -} - -func (b *baseHost) ProxyRegisterSharedQueue(nameData *byte, nameSize int, returnID *uint32) types.Status { - name := proxywasm.RawBytePtrToString(nameData, nameSize) - if id, ok := b.queueNameID[name]; ok { - *returnID = id - return types.StatusOK - } - - id := uint32(len(b.queues)) - b.queues[id] = [][]byte{} - b.queueNameID[name] = id - *returnID = id - return types.StatusOK -} - -func (b *baseHost) ProxyDequeueSharedQueue(queueID uint32, returnValueData **byte, returnValueSize *int) types.Status { - queue, ok := b.queues[queueID] - if !ok { - log.Printf("queue %d is not found", queueID) - return types.StatusNotFound - } else if len(queue) == 0 { - log.Printf("queue %d is empty", queueID) - return types.StatusEmpty - } - - data := queue[0] - *returnValueData = &data[0] - *returnValueSize = len(data) - b.queues[queueID] = queue[1:] - return types.StatusOK -} - -func (b *baseHost) ProxyEnqueueSharedQueue(queueID uint32, valueData *byte, valueSize int) types.Status { - queue, ok := b.queues[queueID] - if !ok { - log.Printf("queue %d is not found", queueID) - return types.StatusNotFound - } - - b.queues[queueID] = append(queue, proxywasm.RawBytePtrToByteSlice(valueData, valueSize)) - - // TODO: should call OnQueueReady? - - return types.StatusOK -} - -func (b *baseHost) GetQueueSize(queueID uint32) int { - return len(b.queues[queueID]) -} - -func (b *baseHost) ProxyGetSharedData(keyData *byte, keySize int, - returnValueData **byte, returnValueSize *int, returnCas *uint32) types.Status { - key := proxywasm.RawBytePtrToString(keyData, keySize) - - value, ok := b.sharedDataKVS[key] - if !ok { - return types.StatusNotFound - } - - *returnValueSize = len(value.data) - *returnValueData = &value.data[0] - *returnCas = value.cas - return types.StatusOK -} - -func (b *baseHost) ProxySetSharedData(keyData *byte, keySize int, - valueData *byte, valueSize int, cas uint32) types.Status { - key := proxywasm.RawBytePtrToString(keyData, keySize) - value := proxywasm.RawBytePtrToByteSlice(valueData, valueSize) - - prev, ok := b.sharedDataKVS[key] - if !ok { - b.sharedDataKVS[key] = &sharedData{ - data: value, - cas: cas + 1, - } - return types.StatusOK - } - - if prev.cas != cas { - return types.StatusCasMismatch - } - - b.sharedDataKVS[key].cas = cas + 1 - b.sharedDataKVS[key].data = value - return types.StatusOK -} - -func (b *baseHost) ProxyDefineMetric(metricType types.MetricType, - metricNameData *byte, metricNameSize int, returnMetricIDPtr *uint32) types.Status { - name := proxywasm.RawBytePtrToString(metricNameData, metricNameSize) - id, ok := b.metricNameToID[name] - if !ok { - id = uint32(len(b.metricNameToID)) - b.metricNameToID[name] = id - b.metricIDToValue[id] = 0 - b.metricIDToType[id] = metricType - } - *returnMetricIDPtr = id - return types.StatusOK -} - -func (b *baseHost) ProxyIncrementMetric(metricID uint32, offset int64) types.Status { - // TODO: check metric type - - val, ok := b.metricIDToValue[metricID] - if !ok { - return types.StatusBadArgument - } - - b.metricIDToValue[metricID] = val + uint64(offset) - return types.StatusOK -} - -func (b *baseHost) ProxyRecordMetric(metricID uint32, value uint64) types.Status { - // TODO: check metric type - - _, ok := b.metricIDToValue[metricID] - if !ok { - return types.StatusBadArgument - } - b.metricIDToValue[metricID] = value - return types.StatusOK -} - -func (b *baseHost) ProxyGetMetric(metricID uint32, returnMetricValue *uint64) types.Status { - value, ok := b.metricIDToValue[metricID] - if !ok { - return types.StatusBadArgument - } - *returnMetricValue = value - return types.StatusOK -} - -func (b *baseHost) getBuffer(bt types.BufferType, start int, maxSize int, - returnBufferData **byte, returnBufferSize *int) types.Status { - if bt != types.BufferTypeHttpCallResponseBody { - panic("unimplemented") - } - - res, ok := b.calloutResponse[b.currentContextID] - if !ok { - log.Fatalf("callout response unregistered for %d", b.currentContextID) - } - - *returnBufferData = &res.body[0] - *returnBufferSize = len(res.body) - return types.StatusOK -} - -func (b *baseHost) getMapValue(mapType types.MapType, keyData *byte, - keySize int, returnValueData **byte, returnValueSize *int) types.Status { - res, ok := b.calloutResponse[b.currentContextID] - if !ok { - log.Fatalf("callout response unregistered for %d", b.currentContextID) - } - - key := proxywasm.RawBytePtrToString(keyData, keySize) - - var hs [][2]string - switch mapType { - case types.MapTypeHttpCallResponseHeaders: - hs = res.headers - case types.MapTypeHttpCallResponseTrailers: - hs = res.trailers - default: - panic("unimplemented") - } - - for _, h := range hs { - if h[0] == key { - v := []byte(h[1]) - *returnValueData = &v[0] - *returnValueSize = len(v) - return types.StatusOK - } - } - - return types.StatusNotFound -} - -func (b *baseHost) ProxyHttpCall(upstreamData *byte, upstreamSize int, headerData *byte, headerSize int, bodyData *byte, - bodySize int, trailersData *byte, trailersSize int, timeout uint32, _ *uint32) types.Status { - upstream := proxywasm.RawBytePtrToString(upstreamData, upstreamSize) - body := proxywasm.RawBytePtrToString(bodyData, bodySize) - headers := proxywasm.DeserializeMap(proxywasm.RawBytePtrToByteSlice(headerData, headerSize)) - trailers := proxywasm.DeserializeMap(proxywasm.RawBytePtrToByteSlice(trailersData, trailersSize)) - - log.Printf("[http callout to %s] timeout: %d", upstream, timeout) - log.Printf("[http callout to %s] headers: %v", upstream, headers) - log.Printf("[http callout to %s] body: %s", upstream, body) - log.Printf("[http callout to %s] trailers: %v", upstream, trailers) - - b.callouts[b.currentContextID] = struct{}{} - return types.StatusOK -} - -func (b *baseHost) PutCalloutResponse(contextID uint32, headers, trailers [][2]string, body []byte) { - b.calloutResponse[contextID] = struct { - headers, trailers [][2]string - body []byte - }{headers: headers, trailers: trailers, body: body} - - b.currentContextID = contextID - b.calloutCallbackCaller(contextID, len(headers), len(body), len(trailers)) - delete(b.calloutResponse, contextID) -} - -func (b *baseHost) IsDispatchCalled(contextID uint32) bool { - _, ok := b.callouts[contextID] - return ok -} - -func (b *baseHost) getMapPairs(mapType types.MapType, returnValueData **byte, returnValueSize *int) types.Status { - res, ok := b.calloutResponse[b.currentContextID] - if !ok { - log.Fatalf("callout response unregistered for %d", b.currentContextID) - } - - var raw []byte - switch mapType { - case types.MapTypeHttpCallResponseHeaders: - raw = proxywasm.SerializeMap(res.headers) - case types.MapTypeHttpCallResponseTrailers: - raw = proxywasm.SerializeMap(res.trailers) - default: - panic("unimplemented") - } - - *returnValueData = &raw[0] - *returnValueSize = len(raw) - return types.StatusOK -} diff --git a/proxytest/base_test.go b/proxytest/base_test.go deleted file mode 100644 index cd0d75c4..00000000 --- a/proxytest/base_test.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2020 Tetrate -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxytest - -// TODO: diff --git a/proxytest/http.go b/proxytest/http.go index 0022cf2d..754b8bee 100644 --- a/proxytest/http.go +++ b/proxytest/http.go @@ -18,153 +18,48 @@ import ( "log" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -type HttpFilterHost struct { - *baseHost - - newContext func(contextID uint32) proxywasm.HttpContext - contexts map[uint32]*httpContextState -} - -type httpContextState struct { - context proxywasm.HttpContext - requestHeaders, responseHeaders, - requestTrailers, responseTrailers [][2]string - requestBody, responseBody []byte - - action types.Action - sentLocalResponse *LocalHttpResponse -} - -type LocalHttpResponse struct { - StatusCode uint32 - StatusCodeDetail string - Data []byte - Headers [][2]string - GRPCStatus int32 -} - -func NewHttpFilterHost(f func(contextID uint32) proxywasm.HttpContext) (*HttpFilterHost, func()) { - host := &HttpFilterHost{ - newContext: f, - contexts: map[uint32]*httpContextState{}, - } - - host.baseHost = newBaseHost(func(contextID uint32, numHeaders, bodySize, numTrailers int) { - ctx, ok := host.contexts[contextID] - if !ok { - log.Fatalf("invalid context id for callback: %d", contextID) - } - - ctx.context.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - }) - hostMux.Lock() - rawhostcall.RegisterMockWASMHost(host) - return host, func() { - hostMux.Unlock() - } -} - -func (h *HttpFilterHost) InitContext() uint32 { - contextID := uint32(len(h.contexts)) + 1 - ctx := h.newContext(contextID) - - h.contexts[contextID] = &httpContextState{ - context: ctx, - action: types.ActionContinue, +type ( + httpHostEmulator struct { + httpStreams map[uint32]*httpStreamState } - return contextID -} + httpStreamState struct { + requestHeaders, responseHeaders, + requestTrailers, responseTrailers [][2]string + requestBody, responseBody []byte -func (h *HttpFilterHost) GetAction(contextID uint32) types.Action { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) + action types.Action + sentLocalResponse *LocalHttpResponse } - return cs.action -} - -func (h *HttpFilterHost) PutRequestHeaders(contextID uint32, headers [][2]string) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - - cs.requestHeaders = headers - h.currentContextID = contextID - cs.action = cs.context.OnHttpRequestHeaders(len(headers), false) // TODO: allow for specifying end_of_stream -} - -func (h *HttpFilterHost) PutResponseHeaders(contextID uint32, headers [][2]string) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - - cs.responseHeaders = headers - h.currentContextID = contextID - cs.action = cs.context.OnHttpResponseHeaders(len(headers), false) // TODO: allow for specifying end_of_stream -} - -func (h *HttpFilterHost) PutRequestTrailers(contextID uint32, headers [][2]string) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - - cs.requestTrailers = headers - h.currentContextID = contextID - cs.action = cs.context.OnHttpRequestTrailers(len(headers)) -} - -func (h *HttpFilterHost) PutResponseTrailers(contextID uint32, headers [][2]string) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - - cs.responseTrailers = headers - h.currentContextID = contextID - cs.action = cs.context.OnHttpResponseTrailers(len(headers)) -} - -func (h *HttpFilterHost) PutRequestBody(contextID uint32, body []byte) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) + LocalHttpResponse struct { + StatusCode uint32 + StatusCodeDetail string + Data []byte + Headers [][2]string + GRPCStatus int32 } +) - cs.requestBody = body - h.currentContextID = contextID - cs.action = cs.context.OnHttpRequestBody(len(body), false) // TODO: allow for specifying end_of_stream +func newHttpHostEmulator() *httpHostEmulator { + host := &httpHostEmulator{httpStreams: map[uint32]*httpStreamState{}} + return host } -func (h *HttpFilterHost) PutResponseBody(contextID uint32, body []byte) { - cs, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - - cs.responseBody = body - h.currentContextID = contextID - cs.action = cs.context.OnHttpResponseBody(len(body), false) // TODO: allow for specifying end_of_stream -} - -func (h *HttpFilterHost) ProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, +// impl host rawhostcall.ProxyWASMHost: delegated from hostEmulator +func (h *httpHostEmulator) httpHostEmulatorProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, returnBufferData **byte, returnBufferSize *int) types.Status { - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] var buf []byte switch bt { case types.BufferTypeHttpRequestBody: - buf = ctx.requestBody + buf = stream.requestBody case types.BufferTypeHttpResponseBody: - buf = ctx.requestBody + buf = stream.requestBody default: - // delegate to baseHost - return h.getBuffer(bt, start, maxSize, returnBufferData, returnBufferSize) + panic("unreachable: maybe a bug in this host emulation or SDK") } if start >= len(buf) { @@ -181,23 +76,25 @@ func (h *HttpFilterHost) ProxyGetBufferBytes(bt types.BufferType, start int, max return types.StatusOK } -func (h *HttpFilterHost) ProxyGetHeaderMapValue(mapType types.MapType, keyData *byte, +// impl host rawhostcall.ProxyWASMHost: delegated from hostEmulator +func (h *httpHostEmulator) httpHostEmulatorProxyGetHeaderMapValue(mapType types.MapType, keyData *byte, keySize int, returnValueData **byte, returnValueSize *int) types.Status { key := proxywasm.RawBytePtrToString(keyData, keySize) - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] var headers [][2]string switch mapType { case types.MapTypeHttpRequestHeaders: - headers = ctx.requestHeaders + headers = stream.requestHeaders case types.MapTypeHttpResponseHeaders: - headers = ctx.responseHeaders + headers = stream.responseHeaders case types.MapTypeHttpRequestTrailers: - headers = ctx.requestTrailers + headers = stream.requestTrailers case types.MapTypeHttpResponseTrailers: - headers = ctx.responseTrailers + headers = stream.responseTrailers default: - return h.getMapValue(mapType, keyData, keySize, returnValueData, returnValueSize) + panic("unreachable: maybe a bug in this host emulation or SDK") } for _, h := range headers { @@ -212,22 +109,24 @@ func (h *HttpFilterHost) ProxyGetHeaderMapValue(mapType types.MapType, keyData * return types.StatusNotFound } -func (h *HttpFilterHost) ProxyAddHeaderMapValue(mapType types.MapType, keyData *byte, +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxyAddHeaderMapValue(mapType types.MapType, keyData *byte, keySize int, valueData *byte, valueSize int) types.Status { key := proxywasm.RawBytePtrToString(keyData, keySize) value := proxywasm.RawBytePtrToString(valueData, valueSize) - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] switch mapType { case types.MapTypeHttpRequestHeaders: - ctx.requestHeaders = addMapValue(ctx.requestHeaders, key, value) + stream.requestHeaders = addMapValue(stream.requestHeaders, key, value) case types.MapTypeHttpResponseHeaders: - ctx.responseHeaders = addMapValue(ctx.responseHeaders, key, value) + stream.responseHeaders = addMapValue(stream.responseHeaders, key, value) case types.MapTypeHttpRequestTrailers: - ctx.requestTrailers = addMapValue(ctx.requestTrailers, key, value) + stream.requestTrailers = addMapValue(stream.requestTrailers, key, value) case types.MapTypeHttpResponseTrailers: - ctx.responseTrailers = addMapValue(ctx.responseTrailers, key, value) + stream.responseTrailers = addMapValue(stream.responseTrailers, key, value) default: panic("unimplemented") } @@ -246,27 +145,30 @@ func addMapValue(base [][2]string, key, value string) [][2]string { return append(base, [2]string{key, value}) } -func (h *HttpFilterHost) ProxyReplaceHeaderMapValue(mapType types.MapType, keyData *byte, +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxyReplaceHeaderMapValue(mapType types.MapType, keyData *byte, keySize int, valueData *byte, valueSize int) types.Status { key := proxywasm.RawBytePtrToString(keyData, keySize) value := proxywasm.RawBytePtrToString(valueData, valueSize) - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] switch mapType { case types.MapTypeHttpRequestHeaders: - ctx.requestHeaders = replaceMapValue(ctx.requestHeaders, key, value) + stream.requestHeaders = replaceMapValue(stream.requestHeaders, key, value) case types.MapTypeHttpResponseHeaders: - ctx.responseHeaders = replaceMapValue(ctx.responseHeaders, key, value) + stream.responseHeaders = replaceMapValue(stream.responseHeaders, key, value) case types.MapTypeHttpRequestTrailers: - ctx.requestTrailers = replaceMapValue(ctx.requestTrailers, key, value) + stream.requestTrailers = replaceMapValue(stream.requestTrailers, key, value) case types.MapTypeHttpResponseTrailers: - ctx.responseTrailers = replaceMapValue(ctx.responseTrailers, key, value) + stream.responseTrailers = replaceMapValue(stream.responseTrailers, key, value) default: panic("unimplemented") } return types.StatusOK } +// impl host rawhostcall.ProxyWASMHost func replaceMapValue(base [][2]string, key, value string) [][2]string { for i, h := range base { if h[0] == key { @@ -278,18 +180,21 @@ func replaceMapValue(base [][2]string, key, value string) [][2]string { return append(base, [2]string{key, value}) } -func (h *HttpFilterHost) ProxyRemoveHeaderMapValue(mapType types.MapType, keyData *byte, keySize int) types.Status { +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxyRemoveHeaderMapValue(mapType types.MapType, keyData *byte, keySize int) types.Status { key := proxywasm.RawBytePtrToString(keyData, keySize) - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] + switch mapType { case types.MapTypeHttpRequestHeaders: - ctx.requestHeaders = removeHeaderMapValue(ctx.requestHeaders, key) + stream.requestHeaders = removeHeaderMapValue(stream.requestHeaders, key) case types.MapTypeHttpResponseHeaders: - ctx.responseHeaders = removeHeaderMapValue(ctx.responseHeaders, key) + stream.responseHeaders = removeHeaderMapValue(stream.responseHeaders, key) case types.MapTypeHttpRequestTrailers: - ctx.requestTrailers = removeHeaderMapValue(ctx.requestTrailers, key) + stream.requestTrailers = removeHeaderMapValue(stream.requestTrailers, key) case types.MapTypeHttpResponseTrailers: - ctx.responseTrailers = removeHeaderMapValue(ctx.responseTrailers, key) + stream.responseTrailers = removeHeaderMapValue(stream.responseTrailers, key) default: panic("unimplemented") } @@ -309,22 +214,24 @@ func removeHeaderMapValue(base [][2]string, key string) [][2]string { return base } -func (h *HttpFilterHost) ProxyGetHeaderMapPairs(mapType types.MapType, returnValueData **byte, +// impl host rawhostcall.ProxyWASMHost: delegated from hostEmulator +func (h *httpHostEmulator) httpHostEmulatorProxyGetHeaderMapPairs(mapType types.MapType, returnValueData **byte, returnValueSize *int) types.Status { - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] var m []byte switch mapType { case types.MapTypeHttpRequestHeaders: - m = proxywasm.SerializeMap(ctx.requestHeaders) + m = proxywasm.SerializeMap(stream.requestHeaders) case types.MapTypeHttpResponseHeaders: - m = proxywasm.SerializeMap(ctx.responseHeaders) + m = proxywasm.SerializeMap(stream.responseHeaders) case types.MapTypeHttpRequestTrailers: - m = proxywasm.SerializeMap(ctx.requestTrailers) + m = proxywasm.SerializeMap(stream.requestTrailers) case types.MapTypeHttpResponseTrailers: - m = proxywasm.SerializeMap(ctx.responseTrailers) + m = proxywasm.SerializeMap(stream.responseTrailers) default: - return h.getMapPairs(mapType, returnValueData, returnValueSize) + panic("unreachable: maybe a bug in this host emulation or SDK") } *returnValueData = &m[0] @@ -332,43 +239,42 @@ func (h *HttpFilterHost) ProxyGetHeaderMapPairs(mapType types.MapType, returnVal return types.StatusOK } -func (h *HttpFilterHost) ProxySetHeaderMapPairs(mapType types.MapType, mapData *byte, mapSize int) types.Status { +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxySetHeaderMapPairs(mapType types.MapType, mapData *byte, mapSize int) types.Status { m := proxywasm.DeserializeMap(proxywasm.RawBytePtrToByteSlice(mapData, mapSize)) - ctx := h.contexts[h.currentContextID] + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] + switch mapType { case types.MapTypeHttpRequestHeaders: - ctx.requestHeaders = m + stream.requestHeaders = m case types.MapTypeHttpResponseHeaders: - ctx.responseHeaders = m + stream.responseHeaders = m case types.MapTypeHttpRequestTrailers: - ctx.requestTrailers = m + stream.requestTrailers = m case types.MapTypeHttpResponseTrailers: - ctx.responseTrailers = m + stream.responseTrailers = m default: panic("unimplemented") } return types.StatusOK } -func (h *HttpFilterHost) ProxyContinueStream(types.StreamType) types.Status { - ctx := h.contexts[h.currentContextID] - ctx.action = types.ActionContinue +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxyContinueStream(types.StreamType) types.Status { + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] + stream.action = types.ActionContinue return types.StatusOK } -func (h *HttpFilterHost) GetCurrentAction(contextID uint32) types.Action { - ctx, ok := h.contexts[contextID] - if !ok { - log.Fatalf("invalid context id: %d", contextID) - } - return ctx.action -} - -func (h *HttpFilterHost) ProxySendLocalResponse(statusCode uint32, +// impl host rawhostcall.ProxyWASMHost +func (h *httpHostEmulator) ProxySendLocalResponse(statusCode uint32, statusCodeDetailData *byte, statusCodeDetailsSize int, bodyData *byte, bodySize int, headersData *byte, headersSize int, grpcStatus int32) types.Status { - ctx := h.contexts[h.currentContextID] - ctx.sentLocalResponse = &LocalHttpResponse{ + active := proxywasm.VMStateGetActiveContextID() + stream := h.httpStreams[active] + stream.sentLocalResponse = &LocalHttpResponse{ StatusCode: statusCode, StatusCodeDetail: proxywasm.RawBytePtrToString(statusCodeDetailData, statusCodeDetailsSize), Data: proxywasm.RawBytePtrToByteSlice(bodyData, bodySize), @@ -378,10 +284,100 @@ func (h *HttpFilterHost) ProxySendLocalResponse(statusCode uint32, return types.StatusOK } -func (h *HttpFilterHost) GetSentLocalResponse(contextID uint32) *LocalHttpResponse { - return h.contexts[contextID].sentLocalResponse +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterInitContext() (contextID uint32) { + contextID = getNextContextID() + proxywasm.ProxyOnContextCreate(contextID, rootContextID) + h.httpStreams[contextID] = &httpStreamState{action: types.ActionContinue} + return +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutRequestHeaders(contextID uint32, headers [][2]string) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.requestHeaders = headers + cs.action = proxywasm.ProxyOnRequestHeaders(contextID, + len(headers), false) // TODO: allow for specifying end_of_stream +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutResponseHeaders(contextID uint32, headers [][2]string) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.responseHeaders = headers + + cs.action = proxywasm.ProxyOnResponseHeaders(contextID, + len(headers), false) // TODO: allow for specifying end_of_stream +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutRequestTrailers(contextID uint32, headers [][2]string) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.requestTrailers = headers + cs.action = proxywasm.ProxyOnRequestTrailers(contextID, len(headers)) +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutResponseTrailers(contextID uint32, headers [][2]string) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.responseTrailers = headers + cs.action = proxywasm.ProxyOnResponseTrailers(contextID, len(headers)) +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutRequestBody(contextID uint32, body []byte) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.requestBody = body + cs.action = proxywasm.ProxyOnRequestBody(contextID, + len(body), false) // TODO: allow for specifying end_of_stream +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterPutResponseBody(contextID uint32, body []byte) { + cs, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + + cs.responseBody = body + cs.action = proxywasm.ProxyOnResponseBody(contextID, + len(body), false) // TODO: allow for specifying end_of_stream +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterCompleteHttpStream(contextID uint32) { + proxywasm.ProxyOnDone(contextID) +} + +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterGetCurrentStreamAction(contextID uint32) types.Action { + stream, ok := h.httpStreams[contextID] + if !ok { + log.Fatalf("invalid context id: %d", contextID) + } + return stream.action } -func (h *HttpFilterHost) GetContext(contextID uint32) proxywasm.HttpContext { - return h.contexts[contextID].context +// impl host HostEmulator +func (h *httpHostEmulator) HttpFilterGetSentLocalResponse(contextID uint32) *LocalHttpResponse { + return h.httpStreams[contextID].sentLocalResponse } diff --git a/proxytest/http_test.go b/proxytest/http_test.go deleted file mode 100644 index cd0d75c4..00000000 --- a/proxytest/http_test.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2020 Tetrate -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxytest - -// TODO: diff --git a/proxytest/network.go b/proxytest/network.go index 92d00728..8a0ef658 100644 --- a/proxytest/network.go +++ b/proxytest/network.go @@ -18,43 +18,58 @@ import ( "log" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -type NetworkFilterHost struct { - *baseHost - newContext func(contextID uint32) proxywasm.StreamContext - streams map[uint32]*streamState +type networkHostEmulator struct { + streamStates map[uint32]*streamState } type streamState struct { upstream, downstream []byte - context proxywasm.StreamContext } -func NewNetworkFilterHost(f func(contextID uint32) proxywasm.StreamContext) (*NetworkFilterHost, func()) { - host := &NetworkFilterHost{ - newContext: f, - streams: map[uint32]*streamState{}, +func newNetworkHostEmulator() *networkHostEmulator { + host := &networkHostEmulator{ + streamStates: map[uint32]*streamState{}, } - host.baseHost = newBaseHost(func(contextID uint32, numHeaders, bodySize, numTrailers int) { - stream, ok := host.streams[contextID] - if !ok { - log.Fatalf("invalid context id for callback: %d", contextID) - } - stream.context.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - }) - hostMux.Lock() // acquire the lock of host emulation - rawhostcall.RegisterMockWASMHost(host) - return host, func() { - hostMux.Unlock() + return host +} + +// impl host rawhostcall.ProxyWASMHost: delegated from hostEmulator +func (n *networkHostEmulator) networkHostEmulatorProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, + returnBufferData **byte, returnBufferSize *int) types.Status { + + active := proxywasm.VMStateGetActiveContextID() + stream := n.streamStates[active] + var buf []byte + switch bt { + case types.BufferTypeUpstreamData: + buf = stream.upstream + case types.BufferTypeDownstreamData: + buf = stream.downstream + default: + panic("unreachable: maybe a bug in this host emulation or SDK") + } + + if start >= len(buf) { + log.Printf("start index out of range: %d (start) >= %d ", start, len(buf)) + return types.StatusBadArgument + } + + *returnBufferData = &buf[start] + if maxSize > len(buf)-start { + *returnBufferSize = len(buf) - start + } else { + *returnBufferSize = maxSize } + return types.StatusOK } -func (n *NetworkFilterHost) PutUpstreamData(contextID uint32, data []byte) { - stream, ok := n.streams[contextID] +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterPutUpstreamData(contextID uint32, data []byte) { + stream, ok := n.streamStates[contextID] if !ok { log.Fatalf("invalid context id: %d", contextID) } @@ -63,8 +78,7 @@ func (n *NetworkFilterHost) PutUpstreamData(contextID uint32, data []byte) { stream.upstream = append(stream.upstream, data...) } - n.currentContextID = contextID - action := stream.context.OnUpstreamData(len(stream.upstream), false) + action := proxywasm.ProxyOnUpstreamData(contextID, len(stream.upstream), false) switch action { case types.ActionPause: return @@ -76,8 +90,9 @@ func (n *NetworkFilterHost) PutUpstreamData(contextID uint32, data []byte) { } } -func (n *NetworkFilterHost) PutDownstreamData(contextID uint32, data []byte) { - stream, ok := n.streams[contextID] +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterPutDownstreamData(contextID uint32, data []byte) { + stream, ok := n.streamStates[contextID] if !ok { log.Fatalf("invalid context id: %d", contextID) } @@ -85,8 +100,7 @@ func (n *NetworkFilterHost) PutDownstreamData(contextID uint32, data []byte) { stream.downstream = append(stream.downstream, data...) } - n.currentContextID = contextID - action := stream.context.OnDownstreamData(len(stream.downstream), false) + action := proxywasm.ProxyOnDownstreamData(contextID, len(stream.downstream), false) switch action { case types.ActionPause: return @@ -98,67 +112,27 @@ func (n *NetworkFilterHost) PutDownstreamData(contextID uint32, data []byte) { } } -func (n *NetworkFilterHost) InitConnection() (contextID uint32) { - contextID = uint32(len(n.streams) + 1) - ctx := n.newContext(contextID) - n.streams[contextID] = &streamState{context: ctx} - - n.currentContextID = contextID - ctx.OnNewConnection() +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterInitConnection() (contextID uint32) { + contextID = getNextContextID() + proxywasm.ProxyOnContextCreate(contextID, rootContextID) + proxywasm.ProxyOnNewConnection(contextID) + n.streamStates[contextID] = &streamState{} return } -func (n *NetworkFilterHost) CloseUpstreamConnection(contextID uint32) { - n.streams[contextID].context.OnUpstreamClose(types.PeerTypeLocal) // peerType will be removed in the next ABI -} - -func (n *NetworkFilterHost) CloseDownstreamConnection(contextID uint32) { - n.streams[contextID].context.OnDownstreamClose(types.PeerTypeLocal) // peerType will be removed in the next ABI -} - -func (n *NetworkFilterHost) CompleteConnection(contextID uint32) { - n.streams[contextID].context.OnDone() - delete(n.streams, contextID) -} - -func (n *NetworkFilterHost) ProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, - returnBufferData **byte, returnBufferSize *int) types.Status { - stream := n.streams[n.currentContextID] - var buf []byte - switch bt { - case types.BufferTypeUpstreamData: - buf = stream.upstream - case types.BufferTypeDownstreamData: - buf = stream.downstream - default: - // delegate to baseHost - return n.getBuffer(bt, start, maxSize, returnBufferData, returnBufferSize) - } - - if start >= len(buf) { - log.Printf("start index out of range: %d (start) >= %d ", start, len(buf)) - return types.StatusBadArgument - } - - *returnBufferData = &buf[start] - if maxSize > len(buf)-start { - *returnBufferSize = len(buf) - start - } else { - *returnBufferSize = maxSize - } - return types.StatusOK -} - -func (n *NetworkFilterHost) ProxyGetHeaderMapValue(mapType types.MapType, keyData *byte, - keySize int, returnValueData **byte, returnValueSize *int) types.Status { - return n.getMapValue(mapType, keyData, keySize, returnValueData, returnValueSize) +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterCloseUpstreamConnection(contextID uint32) { + proxywasm.ProxyOnUpstreamConnectionClose(contextID, types.PeerTypeLocal) // peerType will be removed in the next ABI } -func (n *NetworkFilterHost) ProxyGetHeaderMapPairs(mapType types.MapType, returnValueData **byte, - returnValueSize *int) types.Status { - return n.getMapPairs(mapType, returnValueData, returnValueSize) +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterCloseDownstreamConnection(contextID uint32) { + proxywasm.ProxyOnDownstreamConnectionClose(contextID, types.PeerTypeLocal) // peerType will be removed in the next ABI } -func (n *NetworkFilterHost) GetContext(contextID uint32) proxywasm.StreamContext { - return n.streams[contextID].context +// impl host HostEmulator +func (n *networkHostEmulator) NetworkFilterCompleteConnection(contextID uint32) { + proxywasm.ProxyOnDone(contextID) + delete(n.streamStates, contextID) } diff --git a/proxytest/network_test.go b/proxytest/network_test.go deleted file mode 100644 index cd0d75c4..00000000 --- a/proxytest/network_test.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2020 Tetrate -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxytest - -// TODO: diff --git a/proxytest/option.go b/proxytest/option.go new file mode 100644 index 00000000..19090325 --- /dev/null +++ b/proxytest/option.go @@ -0,0 +1,39 @@ +package proxytest + +import "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" + +type EmulatorOption struct { + pluginConfiguration, vmConfiguration []byte + newRootContext func(uint32) proxywasm.RootContext + newStreamContext func(uint32) proxywasm.StreamContext + newHttpContext func(uint32) proxywasm.HttpContext +} + +func NewEmulatorOption() *EmulatorOption { + return &EmulatorOption{} +} + +func (o *EmulatorOption) WithNewRootContext(f func(uint32) proxywasm.RootContext) *EmulatorOption { + o.newRootContext = f + return o +} + +func (o *EmulatorOption) WithNewHttpContext(f func(uint32) proxywasm.HttpContext) *EmulatorOption { + o.newHttpContext = f + return o +} + +func (o *EmulatorOption) WithNewStreamContext(f func(uint32) proxywasm.StreamContext) *EmulatorOption { + o.newStreamContext = f + return o +} + +func (o *EmulatorOption) WithPluginConfiguration(data []byte) *EmulatorOption { + o.pluginConfiguration = data + return o +} + +func (o *EmulatorOption) WithVMConfiguration(data []byte) *EmulatorOption { + o.vmConfiguration = data + return o +} diff --git a/proxytest/proxytest.go b/proxytest/proxytest.go new file mode 100644 index 00000000..27d66f14 --- /dev/null +++ b/proxytest/proxytest.go @@ -0,0 +1,187 @@ +package proxytest + +import ( + "log" + "sync" + "time" + + "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" + "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" + "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" +) + +type HostEmulator interface { + Done() + + // Root + StartVM() + StartPlugin() + FinishVM() + + GetCalloutAttributesFromContext(contextID uint32) []HttpCalloutAttribute + PutCalloutResponse(contextID uint32, headers, trailers [][2]string, body []byte) + + GetLogs(level types.LogLevel) []string + GetTickPeriod() uint32 + GetQueueSize(queueID uint32) int + + // network + NetworkFilterInitConnection() (contextID uint32) + NetworkFilterPutUpstreamData(contextID uint32, data []byte) + NetworkFilterPutDownstreamData(contextID uint32, data []byte) + NetworkFilterCloseUpstreamConnection(contextID uint32) + NetworkFilterCloseDownstreamConnection(contextID uint32) + NetworkFilterCompleteConnection(contextID uint32) + + // http + HttpFilterInitContext() (contextID uint32) + HttpFilterPutRequestHeaders(contextID uint32, headers [][2]string) + HttpFilterPutResponseHeaders(contextID uint32, headers [][2]string) + HttpFilterPutRequestTrailers(contextID uint32, headers [][2]string) + HttpFilterPutResponseTrailers(contextID uint32, headers [][2]string) + HttpFilterPutRequestBody(contextID uint32, body []byte) + HttpFilterPutResponseBody(contextID uint32, body []byte) + HttpFilterCompleteHttpStream(contextID uint32) + HttpFilterGetCurrentStreamAction(contextID uint32) types.Action + HttpFilterGetSentLocalResponse(contextID uint32) *LocalHttpResponse +} + +const ( + rootContextID uint32 = 1 // TODO: support multiple rootContext +) + +var ( + hostMux = sync.Mutex{} + nextContextID = rootContextID + 1 +) + +func NewHostEmulator(opt *EmulatorOption) HostEmulator { + root := newRootHostEmulator(opt.pluginConfiguration, opt.vmConfiguration) + network := newNetworkHostEmulator() + http := newHttpHostEmulator() + emulator := &hostEmulator{ + root, + network, + http, + 0, + } + + hostMux.Lock() // acquire the lock of host emulation + rawhostcall.RegisterMockWASMHost(emulator) + + // set up state + proxywasm.SetNewRootContext(opt.newRootContext) + proxywasm.SetNewStreamContext(opt.newStreamContext) + proxywasm.SetNewHttpContext(opt.newHttpContext) + + // create root context: TODO: support multiple root contexts + proxywasm.ProxyOnContextCreate(rootContextID, 0) + + return emulator +} + +func getNextContextID() (ret uint32) { + ret = nextContextID + nextContextID++ + return +} + +type hostEmulator struct { + *rootHostEmulator + *networkHostEmulator + *httpHostEmulator + + effectiveContextID uint32 +} + +// impl host HostEmulator +func (*hostEmulator) Done() { + hostMux.Unlock() + proxywasm.VMStateReset() +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, + returnBufferData **byte, returnBufferSize *int) types.Status { + switch bt { + case types.BufferTypePluginConfiguration, types.BufferTypeVMConfiguration, types.BufferTypeHttpCallResponseBody: + return h.rootHostEmulatorProxyGetBufferBytes(bt, start, maxSize, returnBufferData, returnBufferSize) + case types.BufferTypeDownstreamData, types.BufferTypeUpstreamData: + return h.networkHostEmulatorProxyGetBufferBytes(bt, start, maxSize, returnBufferData, returnBufferSize) + case types.BufferTypeHttpRequestBody, types.BufferTypeHttpResponseBody: + return h.httpHostEmulatorProxyGetBufferBytes(bt, start, maxSize, returnBufferData, returnBufferSize) + default: + panic("unreachable: maybe a bug in this host emulation or SDK") + } +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyGetHeaderMapValue(mapType types.MapType, keyData *byte, + keySize int, returnValueData **byte, returnValueSize *int) types.Status { + switch mapType { + case types.MapTypeHttpRequestHeaders, types.MapTypeHttpResponseHeaders, + types.MapTypeHttpRequestTrailers, types.MapTypeHttpResponseTrailers: + return h.httpHostEmulatorProxyGetHeaderMapValue(mapType, keyData, + keySize, returnValueData, returnValueSize) + case types.MapTypeHttpCallResponseHeaders, types.MapTypeHttpCallResponseTrailers: + return h.rootHostEmulatorProxyGetMapValue(mapType, keyData, + keySize, returnValueData, returnValueSize) + default: + panic("unreachable: maybe a bug in this host emulation or SDK") + } +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyGetHeaderMapPairs(mapType types.MapType, returnValueData **byte, + returnValueSize *int) types.Status { + switch mapType { + case types.MapTypeHttpRequestHeaders, types.MapTypeHttpResponseHeaders, + types.MapTypeHttpRequestTrailers, types.MapTypeHttpResponseTrailers: + return h.httpHostEmulatorProxyGetHeaderMapPairs(mapType, returnValueData, returnValueSize) + case types.MapTypeHttpCallResponseHeaders, types.MapTypeHttpCallResponseTrailers: + return h.rootHostEmulatorProxyGetHeaderMapPairs(mapType, returnValueData, returnValueSize) + default: + panic("unreachable: maybe a bug in this host emulation or SDK") + } +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyGetCurrentTimeNanoseconds(returnTime *int64) types.Status { + *returnTime = time.Now().UnixNano() + return types.StatusOK +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxySetEffectiveContext(contextID uint32) types.Status { + h.effectiveContextID = contextID + return types.StatusOK +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxySetProperty(*byte, int, *byte, int) types.Status { + panic("unimplemented") +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyGetProperty(*byte, int, **byte, *int) types.Status { + log.Printf("ProxyGetProperty not implemented in the host emulator yet") + return 0 +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyResolveSharedQueue(vmIDData *byte, vmIDSize int, nameData *byte, nameSize int, returnID *uint32) types.Status { + log.Printf("ProxyResolveSharedQueue not implemented in the host emulator yet") + return 0 +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyCloseStream(streamType types.StreamType) types.Status { + log.Printf("ProxyCloseStream not implemented in the host emulator yet") + return 0 +} + +// impl host rawhostcall.ProxyWASMHost +func (h *hostEmulator) ProxyDone() types.Status { + log.Printf("ProxyDone not implemented in the host emulator yet") + return 0 +} diff --git a/proxytest/root.go b/proxytest/root.go index 1fc81537..74091904 100644 --- a/proxytest/root.go +++ b/proxytest/root.go @@ -16,60 +16,320 @@ package proxytest import ( "log" + "time" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -// TODO: simulate OnQueueReady, OnTick +type ( + rootHostEmulator struct { + logs [types.LogLevelMax][]string + tickPeriod uint32 -type RootFilterHost struct { - *baseHost - context proxywasm.RootContext + queues map[uint32][][]byte + queueNameID map[string]uint32 - pluginConfiguration, vmConfiguration []byte + sharedDataKVS map[string]*sharedData + + metricIDToValue map[uint32]uint64 + metricIDToType map[uint32]types.MetricType + metricNameToID map[string]uint32 + + httpContextIDToCalloutInfos map[uint32][]HttpCalloutAttribute // key: contextID + httpCalloutIDToContextID map[uint32]uint32 // key: calloutID + httpCalloutResponse map[uint32]struct { // key: calloutID + headers, trailers [][2]string + body []byte + } + + pluginConfiguration, vmConfiguration []byte + + activeCalloutID *uint32 + } + + HttpCalloutAttribute struct { + CalloutID uint32 + Upstream string + Headers, Trailers [][2]string + Body []byte + } +) + +type sharedData struct { + data []byte + cas uint32 } -func NewRootFilterHost(ctx proxywasm.RootContext, pluginConfiguration, vmConfiguration []byte, -) (*RootFilterHost, func()) { - host := &RootFilterHost{ - context: ctx, +func newRootHostEmulator(pluginConfiguration, vmConfiguration []byte) *rootHostEmulator { + host := &rootHostEmulator{ + queues: map[uint32][][]byte{}, + queueNameID: map[string]uint32{}, + sharedDataKVS: map[string]*sharedData{}, + metricIDToValue: map[uint32]uint64{}, + metricIDToType: map[uint32]types.MetricType{}, + metricNameToID: map[string]uint32{}, + httpContextIDToCalloutInfos: map[uint32][]HttpCalloutAttribute{}, + httpCalloutIDToContextID: map[uint32]uint32{}, + httpCalloutResponse: map[uint32]struct { + headers, trailers [][2]string + body []byte + }{}, + pluginConfiguration: pluginConfiguration, vmConfiguration: vmConfiguration, } + return host +} - host.baseHost = newBaseHost(func(contextID uint32, numHeaders, bodySize, numTrailers int) { - host.context.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - }) - hostMux.Lock() // acquire the lock of host emulation - rawhostcall.RegisterMockWASMHost(host) - return host, func() { - hostMux.Unlock() +func (r *rootHostEmulator) ProxyLog(logLevel types.LogLevel, messageData *byte, messageSize int) types.Status { + str := proxywasm.RawBytePtrToString(messageData, messageSize) + + log.Printf("proxy_%s_log: %s", logLevel, str) + r.logs[logLevel] = append(r.logs[logLevel], str) + return types.StatusOK +} + +func (r *rootHostEmulator) ProxySetTickPeriodMilliseconds(period uint32) types.Status { + r.tickPeriod = period + + now := time.Now() + go func() { + for { + time.Sleep(time.Millisecond * time.Duration(r.tickPeriod)) + log.Printf("proxy_on_tick called: %v\n", time.Since(now)) + now = time.Now() + proxywasm.ProxyOnTick(rootContextID) + } + }() + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyRegisterSharedQueue(nameData *byte, nameSize int, returnID *uint32) types.Status { + name := proxywasm.RawBytePtrToString(nameData, nameSize) + if id, ok := r.queueNameID[name]; ok { + *returnID = id + return types.StatusOK } + + id := uint32(len(r.queues)) + r.queues[id] = [][]byte{} + r.queueNameID[name] = id + *returnID = id + return types.StatusOK } -func (n *RootFilterHost) ConfigurePlugin() { - size := len(n.pluginConfiguration) - n.context.OnConfigure(size) +func (r *rootHostEmulator) ProxyDequeueSharedQueue(queueID uint32, returnValueData **byte, returnValueSize *int) types.Status { + queue, ok := r.queues[queueID] + if !ok { + log.Printf("queue %d is not found", queueID) + return types.StatusNotFound + } else if len(queue) == 0 { + log.Printf("queue %d is empty", queueID) + return types.StatusEmpty + } + + data := queue[0] + *returnValueData = &data[0] + *returnValueSize = len(data) + r.queues[queueID] = queue[1:] + return types.StatusOK } -func (n *RootFilterHost) StartVM() { - size := len(n.vmConfiguration) - n.context.OnVMStart(size) +func (r *rootHostEmulator) ProxyEnqueueSharedQueue(queueID uint32, valueData *byte, valueSize int) types.Status { + queue, ok := r.queues[queueID] + if !ok { + log.Printf("queue %d is not found", queueID) + return types.StatusNotFound + } + + r.queues[queueID] = append(queue, proxywasm.RawBytePtrToByteSlice(valueData, valueSize)) + + // note that this behavior is not accurate for some old host implementations: + // see: https://github.com/proxy-wasm/proxy-wasm-cpp-host/pull/36 + proxywasm.ProxyOnQueueReady(rootContextID, queueID) // Note that this behavior is not accurate on Istio before 1.8.x + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyGetSharedData(keyData *byte, keySize int, + returnValueData **byte, returnValueSize *int, returnCas *uint32) types.Status { + key := proxywasm.RawBytePtrToString(keyData, keySize) + + value, ok := r.sharedDataKVS[key] + if !ok { + return types.StatusNotFound + } + + *returnValueSize = len(value.data) + *returnValueData = &value.data[0] + *returnCas = value.cas + return types.StatusOK +} + +func (r *rootHostEmulator) ProxySetSharedData(keyData *byte, keySize int, + valueData *byte, valueSize int, cas uint32) types.Status { + key := proxywasm.RawBytePtrToString(keyData, keySize) + value := proxywasm.RawBytePtrToByteSlice(valueData, valueSize) + + prev, ok := r.sharedDataKVS[key] + if !ok { + r.sharedDataKVS[key] = &sharedData{ + data: value, + cas: cas + 1, + } + return types.StatusOK + } + + if prev.cas != cas { + return types.StatusCasMismatch + } + + r.sharedDataKVS[key].cas = cas + 1 + r.sharedDataKVS[key].data = value + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyDefineMetric(metricType types.MetricType, + metricNameData *byte, metricNameSize int, returnMetricIDPtr *uint32) types.Status { + name := proxywasm.RawBytePtrToString(metricNameData, metricNameSize) + id, ok := r.metricNameToID[name] + if !ok { + id = uint32(len(r.metricNameToID)) + r.metricNameToID[name] = id + r.metricIDToValue[id] = 0 + r.metricIDToType[id] = metricType + } + *returnMetricIDPtr = id + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyIncrementMetric(metricID uint32, offset int64) types.Status { + val, ok := r.metricIDToValue[metricID] + if !ok { + return types.StatusBadArgument + } + + r.metricIDToValue[metricID] = val + uint64(offset) + return types.StatusOK } -func (n *RootFilterHost) ProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, +func (r *rootHostEmulator) ProxyRecordMetric(metricID uint32, value uint64) types.Status { + _, ok := r.metricIDToValue[metricID] + if !ok { + return types.StatusBadArgument + } + r.metricIDToValue[metricID] = value + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyGetMetric(metricID uint32, returnMetricValue *uint64) types.Status { + value, ok := r.metricIDToValue[metricID] + if !ok { + return types.StatusBadArgument + } + *returnMetricValue = value + return types.StatusOK +} + +func (r *rootHostEmulator) ProxyHttpCall(upstreamData *byte, upstreamSize int, headerData *byte, headerSize int, bodyData *byte, + bodySize int, trailersData *byte, trailersSize int, timeout uint32, calloutIDPtr *uint32) types.Status { + upstream := proxywasm.RawBytePtrToString(upstreamData, upstreamSize) + body := proxywasm.RawBytePtrToString(bodyData, bodySize) + headers := proxywasm.DeserializeMap(proxywasm.RawBytePtrToByteSlice(headerData, headerSize)) + trailers := proxywasm.DeserializeMap(proxywasm.RawBytePtrToByteSlice(trailersData, trailersSize)) + + log.Printf("[http callout to %s] timeout: %d", upstream, timeout) + log.Printf("[http callout to %s] headers: %v", upstream, headers) + log.Printf("[http callout to %s] body: %s", upstream, body) + log.Printf("[http callout to %s] trailers: %v", upstream, trailers) + + calloutID := uint32(len(r.httpCalloutIDToContextID)) + contextID := proxywasm.VMStateGetActiveContextID() + r.httpCalloutIDToContextID[calloutID] = contextID + r.httpContextIDToCalloutInfos[contextID] = append(r.httpContextIDToCalloutInfos[contextID], HttpCalloutAttribute{ + CalloutID: calloutID, + Upstream: upstream, + Headers: headers, + Trailers: trailers, + }) + *calloutIDPtr = calloutID + return types.StatusOK +} + +// delegated from hostEmulator +func (r *rootHostEmulator) rootHostEmulatorProxyGetHeaderMapPairs(mapType types.MapType, returnValueData **byte, returnValueSize *int) types.Status { + activeID := proxywasm.VMStateGetActiveContextID() + res, ok := r.httpCalloutResponse[*r.activeCalloutID] + if !ok { + log.Fatalf("callout response unregistered for %d", activeID) + } + + var raw []byte + switch mapType { + case types.MapTypeHttpCallResponseHeaders: + raw = proxywasm.SerializeMap(res.headers) + case types.MapTypeHttpCallResponseTrailers: + raw = proxywasm.SerializeMap(res.trailers) + default: + panic("unreachable: maybe a bug in this host emulation or SDK") + } + + *returnValueData = &raw[0] + *returnValueSize = len(raw) + return types.StatusOK +} + +// delegated from hostEmulator +func (r *rootHostEmulator) rootHostEmulatorProxyGetMapValue(mapType types.MapType, keyData *byte, + keySize int, returnValueData **byte, returnValueSize *int) types.Status { + activeID := proxywasm.VMStateGetActiveContextID() + res, ok := r.httpCalloutResponse[*r.activeCalloutID] + if !ok { + log.Fatalf("callout response unregistered for %d", activeID) + } + + key := proxywasm.RawBytePtrToString(keyData, keySize) + + var hs [][2]string + switch mapType { + case types.MapTypeHttpCallResponseHeaders: + hs = res.headers + case types.MapTypeHttpCallResponseTrailers: + hs = res.trailers + default: + panic("unimplemented") + } + + for _, h := range hs { + if h[0] == key { + v := []byte(h[1]) + *returnValueData = &v[0] + *returnValueSize = len(v) + return types.StatusOK + } + } + + return types.StatusNotFound +} + +// delegated from hostEmulator +func (r *rootHostEmulator) rootHostEmulatorProxyGetBufferBytes(bt types.BufferType, start int, maxSize int, returnBufferData **byte, returnBufferSize *int) types.Status { var buf []byte switch bt { case types.BufferTypePluginConfiguration: - buf = n.pluginConfiguration + buf = r.pluginConfiguration case types.BufferTypeVMConfiguration: - buf = n.vmConfiguration + buf = r.vmConfiguration + case types.BufferTypeHttpCallResponseBody: + activeID := proxywasm.VMStateGetActiveContextID() + res, ok := r.httpCalloutResponse[*r.activeCalloutID] + if !ok { + log.Fatalf("callout response unregistered for %d", activeID) + } + buf = res.body default: - // delegate to baseHost - return n.getBuffer(bt, start, maxSize, returnBufferData, returnBufferSize) + panic("unreachable: maybe a bug in this host emulation or SDK") } if start >= len(buf) { @@ -85,3 +345,49 @@ func (n *RootFilterHost) ProxyGetBufferBytes(bt types.BufferType, start int, max } return types.StatusOK } + +func (r *rootHostEmulator) GetLogs(level types.LogLevel) []string { + if level >= types.LogLevelMax { + log.Fatalf("invalid log level: %d", level) + } + return r.logs[level] +} + +func (r *rootHostEmulator) GetTickPeriod() uint32 { + return r.tickPeriod +} + +func (r *rootHostEmulator) GetQueueSize(queueID uint32) int { + return len(r.queues[queueID]) +} + +func (r *rootHostEmulator) GetCalloutAttributesFromContext(contextID uint32) []HttpCalloutAttribute { + infos := r.httpContextIDToCalloutInfos[contextID] + return infos +} + +func (r *rootHostEmulator) StartVM() { + proxywasm.ProxyOnVMStart(rootContextID, len(r.vmConfiguration)) +} + +func (r *rootHostEmulator) StartPlugin() { + proxywasm.ProxyOnConfigure(rootContextID, len(r.pluginConfiguration)) +} + +func (r *rootHostEmulator) PutCalloutResponse(calloutID uint32, headers, trailers [][2]string, body []byte) { + r.httpCalloutResponse[calloutID] = struct { + headers, trailers [][2]string + body []byte + }{headers: headers, trailers: trailers, body: body} + + // rootContextID, calloutID uint32, numHeaders, bodySize, numTrailers in + r.activeCalloutID = &calloutID + proxywasm.ProxyOnHttpCallResponse(rootContextID, calloutID, len(headers), len(body), len(trailers)) + r.activeCalloutID = nil + delete(r.httpCalloutResponse, calloutID) + delete(r.httpCalloutIDToContextID, calloutID) +} + +func (r *rootHostEmulator) FinishVM() { + proxywasm.ProxyOnDone(rootContextID) +} diff --git a/proxywasm/abi_configuration.go b/proxywasm/abi_configuration.go index 063e803e..0765d19a 100644 --- a/proxywasm/abi_configuration.go +++ b/proxywasm/abi_configuration.go @@ -21,7 +21,7 @@ func proxyOnVMStart(rootContextID uint32, vmConfigurationSize int) bool { panic("invalid context on proxy_on_vm_start") } currentState.setActiveContextID(rootContextID) - return ctx.OnVMStart(vmConfigurationSize) + return ctx.context.OnVMStart(vmConfigurationSize) } //export proxy_on_configure @@ -31,5 +31,5 @@ func proxyOnConfigure(rootContextID uint32, pluginConfigurationSize int) bool { panic("invalid context on proxy_on_configure") } currentState.setActiveContextID(rootContextID) - return ctx.OnConfigure(pluginConfigurationSize) + return ctx.context.OnPluginStart(pluginConfigurationSize) } diff --git a/proxywasm/abi_configuration_test.go b/proxywasm/abi_configuration_test.go index 12fb3294..2dd06072 100644 --- a/proxywasm/abi_configuration_test.go +++ b/proxywasm/abi_configuration_test.go @@ -8,17 +8,17 @@ import ( ) type configurationContext struct { - DefaultContext - onVMStartCalled, onConfigureCalled bool + DefaultRootContext + onVMStartCalled, onPluginStartCalled bool } -func (c *configurationContext) OnVMStart(_ int) bool { +func (c *configurationContext) OnVMStart(int) bool { c.onVMStartCalled = true return true } -func (c *configurationContext) OnConfigure(_ int) bool { - c.onConfigureCalled = true +func (c *configurationContext) OnPluginStart(int) bool { + c.onPluginStartCalled = true return true } @@ -27,15 +27,15 @@ func Test_proxyOnVMStart(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - currentState = &state{rootContexts: map[uint32]RootContext{rID: &configurationContext{}}} + currentState = &state{rootContexts: map[uint32]*rootContextState{rID: {context: &configurationContext{}}}} proxyOnVMStart(rID, 0) - ctx, ok := currentState.rootContexts[rID].(*configurationContext) + ctx, ok := currentState.rootContexts[rID].context.(*configurationContext) require.True(t, ok) assert.True(t, ctx.onVMStartCalled) assert.Equal(t, rID, currentState.activeContextID) proxyOnConfigure(rID, 0) - assert.True(t, ctx.onConfigureCalled) + assert.True(t, ctx.onPluginStartCalled) assert.Equal(t, rID, currentState.activeContextID) } diff --git a/proxywasm/abi_l4.go b/proxywasm/abi_l4.go index 6c6ab1af..69d961c7 100644 --- a/proxywasm/abi_l4.go +++ b/proxywasm/abi_l4.go @@ -18,7 +18,7 @@ import "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" //export proxy_on_new_connection func proxyOnNewConnection(contextID uint32) types.Action { - ctx, ok := currentState.streamContexts[contextID] + ctx, ok := currentState.streams[contextID] if !ok { panic("invalid context") } @@ -28,7 +28,7 @@ func proxyOnNewConnection(contextID uint32) types.Action { //export proxy_on_downstream_data func proxyOnDownstreamData(contextID uint32, dataSize int, endOfStream bool) types.Action { - ctx, ok := currentState.streamContexts[contextID] + ctx, ok := currentState.streams[contextID] if !ok { panic("invalid context") } @@ -38,7 +38,7 @@ func proxyOnDownstreamData(contextID uint32, dataSize int, endOfStream bool) typ //export proxy_on_downstream_connection_close func proxyOnDownstreamConnectionClose(contextID uint32, pType types.PeerType) { - ctx, ok := currentState.streamContexts[contextID] + ctx, ok := currentState.streams[contextID] if !ok { panic("invalid context") } @@ -48,7 +48,7 @@ func proxyOnDownstreamConnectionClose(contextID uint32, pType types.PeerType) { //export proxy_on_upstream_data func proxyOnUpstreamData(contextID uint32, dataSize int, endOfStream bool) types.Action { - ctx, ok := currentState.streamContexts[contextID] + ctx, ok := currentState.streams[contextID] if !ok { panic("invalid context") } @@ -58,7 +58,7 @@ func proxyOnUpstreamData(contextID uint32, dataSize int, endOfStream bool) types //export proxy_on_upstream_connection_close func proxyOnUpstreamConnectionClose(contextID uint32, pType types.PeerType) { - ctx, ok := currentState.streamContexts[contextID] + ctx, ok := currentState.streams[contextID] if !ok { panic("invalid context") } diff --git a/proxywasm/abi_l4_test.go b/proxywasm/abi_l4_test.go index 4d434f19..80677c49 100644 --- a/proxywasm/abi_l4_test.go +++ b/proxywasm/abi_l4_test.go @@ -10,7 +10,7 @@ import ( ) type l4Context struct { - DefaultContext + DefaultStreamContext onDownstreamData, onDownStreamClose, onNewConnection, @@ -44,8 +44,8 @@ func Test_l4(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - currentState = &state{streamContexts: map[uint32]StreamContext{cID: &l4Context{}}} - ctx, ok := currentState.streamContexts[cID].(*l4Context) + currentState = &state{streams: map[uint32]StreamContext{cID: &l4Context{}}} + ctx, ok := currentState.streams[cID].(*l4Context) require.True(t, ok) proxyOnNewConnection(cID) diff --git a/proxywasm/abi_l7.go b/proxywasm/abi_l7.go index d1b2dd09..8e523f90 100644 --- a/proxywasm/abi_l7.go +++ b/proxywasm/abi_l7.go @@ -20,7 +20,7 @@ import ( //export proxy_on_request_headers func proxyOnRequestHeaders(contextID uint32, numHeaders int, endOfStream bool) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context on proxy_on_request_headers") } @@ -31,7 +31,7 @@ func proxyOnRequestHeaders(contextID uint32, numHeaders int, endOfStream bool) t //export proxy_on_request_body func proxyOnRequestBody(contextID uint32, bodySize int, endOfStream bool) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context on proxy_on_request_body") } @@ -41,7 +41,7 @@ func proxyOnRequestBody(contextID uint32, bodySize int, endOfStream bool) types. //export proxy_on_request_trailers func proxyOnRequestTrailers(contextID uint32, numTrailers int) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context on proxy_on_request_trailers") } @@ -51,7 +51,7 @@ func proxyOnRequestTrailers(contextID uint32, numTrailers int) types.Action { //export proxy_on_response_headers func proxyOnResponseHeaders(contextID uint32, numHeaders int, endOfStream bool) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context id on proxy_on_response_headers") } @@ -61,7 +61,7 @@ func proxyOnResponseHeaders(contextID uint32, numHeaders int, endOfStream bool) //export proxy_on_response_body func proxyOnResponseBody(contextID uint32, bodySize int, endOfStream bool) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context id on proxy_on_response_headers") } @@ -71,7 +71,7 @@ func proxyOnResponseBody(contextID uint32, bodySize int, endOfStream bool) types //export proxy_on_response_trailers func proxyOnResponseTrailers(contextID uint32, numTrailers int) types.Action { - ctx, ok := currentState.httpContexts[contextID] + ctx, ok := currentState.httpStreams[contextID] if !ok { panic("invalid context id on proxy_on_response_headers") } @@ -80,27 +80,19 @@ func proxyOnResponseTrailers(contextID uint32, numTrailers int) types.Action { } //export proxy_on_http_call_response -func proxyOnHttpCallResponse(_, calloutID uint32, numHeaders, bodySize, numTrailers int) { - ctxID, ok := currentState.callOuts[calloutID] +func proxyOnHttpCallResponse(rootContextID, calloutID uint32, numHeaders, bodySize, numTrailers int) { + root, ok := currentState.rootContexts[rootContextID] if !ok { - panic("invalid callout id") + panic("http_call_response on invalid root context") } - delete(currentState.callOuts, calloutID) - - if ctx, ok := currentState.streamContexts[ctxID]; ok { - currentState.setActiveContextID(ctxID) - hostCallSetEffectiveContext(ctxID) - ctx.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - } else if ctx, ok := currentState.httpContexts[ctxID]; ok { - currentState.setActiveContextID(ctxID) - hostCallSetEffectiveContext(ctxID) - ctx.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - } else if ctx, ok := currentState.rootContexts[ctxID]; ok { - currentState.setActiveContextID(ctxID) - hostCallSetEffectiveContext(ctxID) - ctx.OnHttpCallResponse(numHeaders, bodySize, numTrailers) - } else { - panic("invalid context on proxy_on_http_call_response") + cb := root.httpCallbacks[calloutID] + if cb == nil { + panic("invalid callout id") } + + SetEffectiveContext(cb.callerContextID) + currentState.setActiveContextID(cb.callerContextID) + delete(root.httpCallbacks, calloutID) + cb.callback(numHeaders, bodySize, numTrailers) } diff --git a/proxywasm/abi_l7_test.go b/proxywasm/abi_l7_test.go index 3b032550..730a3693 100644 --- a/proxywasm/abi_l7_test.go +++ b/proxywasm/abi_l7_test.go @@ -3,16 +3,15 @@ package proxywasm import ( "testing" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) type l7Context struct { - DefaultContext + DefaultHttpContext onHttpRequestHeaders, onHttpRequestBody, onHttpRequestTrailers, @@ -61,8 +60,8 @@ func Test_l7(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - currentState = &state{httpContexts: map[uint32]HttpContext{cID: &l7Context{}}} - ctx, ok := currentState.httpContexts[cID].(*l7Context) + currentState = &state{httpStreams: map[uint32]HttpContext{cID: &l7Context{}}} + ctx, ok := currentState.httpStreams[cID].(*l7Context) require.True(t, ok) proxyOnRequestHeaders(cID, 0, false) @@ -85,8 +84,8 @@ func Test_proxyOnHttpCallResponse(t *testing.T) { rawhostcall.RegisterMockWASMHost(rawhostcall.DefaultProxyWAMSHost{}) var ( - ctxID uint32 = 100 - callOutID uint32 = 10 + rootContextID uint32 = 1 + callOutID uint32 = 10 ) currentStateMux.Lock() @@ -94,34 +93,31 @@ func Test_proxyOnHttpCallResponse(t *testing.T) { ctx := &l7Context{} currentState = &state{ - rootContexts: map[uint32]RootContext{ctxID: ctx}, - callOuts: map[uint32]uint32{callOutID: ctxID}, - } - - proxyOnHttpCallResponse(0, callOutID, 0, 0, 0) - _, ok := currentState.callOuts[callOutID] - require.False(t, ok) - assert.True(t, ctx.onHttpCallResponse) - - ctx = &l7Context{} - currentState = &state{ - httpContexts: map[uint32]HttpContext{ctxID: ctx}, - callOuts: map[uint32]uint32{callOutID: ctxID}, + rootContexts: map[uint32]*rootContextState{rootContextID: { + httpCallbacks: map[uint32]*struct { + callback HttpCalloutCallBack + callerContextID uint32 + }{callOutID: {callback: ctx.OnHttpCallResponse}}, + }}, } - proxyOnHttpCallResponse(0, callOutID, 0, 0, 0) - _, ok = currentState.callOuts[callOutID] + proxyOnHttpCallResponse(rootContextID, callOutID, 0, 0, 0) + _, ok := currentState.rootContexts[rootContextID].httpCallbacks[callOutID] require.False(t, ok) assert.True(t, ctx.onHttpCallResponse) ctx = &l7Context{} currentState = &state{ - streamContexts: map[uint32]StreamContext{ctxID: ctx}, - callOuts: map[uint32]uint32{callOutID: ctxID}, + rootContexts: map[uint32]*rootContextState{rootContextID: { + httpCallbacks: map[uint32]*struct { + callback HttpCalloutCallBack + callerContextID uint32 + }{callOutID: {callback: ctx.OnHttpCallResponse}}, + }}, } - proxyOnHttpCallResponse(0, callOutID, 0, 0, 0) - _, ok = currentState.callOuts[callOutID] + proxyOnHttpCallResponse(rootContextID, callOutID, 0, 0, 0) + _, ok = currentState.rootContexts[rootContextID].httpCallbacks[callOutID] require.False(t, ok) assert.True(t, ctx.onHttpCallResponse) } diff --git a/proxywasm/abi_lifecycle.go b/proxywasm/abi_lifecycle.go index 579f0802..4e16643f 100644 --- a/proxywasm/abi_lifecycle.go +++ b/proxywasm/abi_lifecycle.go @@ -29,45 +29,18 @@ func proxyOnContextCreate(contextID uint32, rootContextID uint32) { //export proxy_on_done func proxyOnDone(contextID uint32) bool { - if ctx, ok := currentState.streamContexts[contextID]; ok { + if ctx, ok := currentState.streams[contextID]; ok { currentState.setActiveContextID(contextID) - return ctx.OnDone() - } else if ctx, ok := currentState.httpContexts[contextID]; ok { + ctx.OnStreamDone() + return true + } else if ctx, ok := currentState.httpStreams[contextID]; ok { currentState.setActiveContextID(contextID) - return ctx.OnDone() + ctx.OnHttpStreamDone() + return true } else if ctx, ok := currentState.rootContexts[contextID]; ok { currentState.setActiveContextID(contextID) - return ctx.OnDone() + return ctx.context.OnVMDone() } else { panic("invalid context on proxy_on_done") } } - -//export proxy_on_log -func proxyOnLog(contextID uint32) { - if ctx, ok := currentState.streamContexts[contextID]; ok { - currentState.setActiveContextID(contextID) - ctx.OnLog() - } else if ctx, ok := currentState.httpContexts[contextID]; ok { - currentState.setActiveContextID(contextID) - ctx.OnLog() - } else if ctx, ok := currentState.rootContexts[contextID]; ok { - currentState.setActiveContextID(contextID) - ctx.OnLog() - } else { - panic("invalid context on proxy_on_log") - } -} - -//export proxy_on_delete -func proxyOnDelete(contextID uint32) { - if _, ok := currentState.streamContexts[contextID]; ok { - delete(currentState.streamContexts, contextID) - } else if _, ok := currentState.httpContexts[contextID]; ok { - delete(currentState.httpContexts, contextID) - } else if _, ok := currentState.rootContexts[contextID]; ok { - delete(currentState.rootContexts, contextID) - } else { - panic("invalid context on proxy_on_delete") - } -} diff --git a/proxywasm/abi_lifecycle_test.go b/proxywasm/abi_lifecycle_test.go index 471cd508..cb818dd5 100644 --- a/proxywasm/abi_lifecycle_test.go +++ b/proxywasm/abi_lifecycle_test.go @@ -14,9 +14,10 @@ func Test_proxyOnContextCreate(t *testing.T) { var cnt int currentState = &state{ - rootContexts: map[uint32]RootContext{}, - httpContexts: map[uint32]HttpContext{}, - streamContexts: map[uint32]StreamContext{}, + rootContexts: map[uint32]*rootContextState{}, + httpStreams: map[uint32]HttpContext{}, + streams: map[uint32]StreamContext{}, + contextIDToRooID: map[uint32]uint32{}, } SetNewRootContext(func(contextID uint32) RootContext { @@ -42,42 +43,24 @@ func Test_proxyOnContextCreate(t *testing.T) { require.Equal(t, 1101, cnt) } -func Test_proxyOnDelete(t *testing.T) { - currentStateMux.Lock() - defer currentStateMux.Unlock() - - currentState = &state{ - rootContexts: map[uint32]RootContext{}, - httpContexts: map[uint32]HttpContext{}, - streamContexts: map[uint32]StreamContext{}, - } - - var id uint32 = 100 - var ctx = &DefaultContext{} - currentState.streamContexts[id] = ctx - proxyOnDelete(id) - assert.Nil(t, currentState.streamContexts[id]) - - currentState.httpContexts[id] = ctx - proxyOnDelete(id) - assert.Nil(t, currentState.httpContexts[id]) - - currentState.rootContexts[id] = ctx - proxyOnDelete(id) - assert.Nil(t, currentState.rootContexts[id]) +type lifecycleContext struct { + DefaultRootContext + DefaultHttpContext + DefaultStreamContext + onStreamDone, onHttpStreamDone, onVMDone bool } -type lifecycleContext struct { - DefaultContext - onDone, onLog bool +func (ctx *lifecycleContext) OnVMDone() bool { + ctx.onVMDone = true + return true } -func (ctx *lifecycleContext) OnLog() { - ctx.onLog = true +func (ctx *lifecycleContext) OnStreamDone() { + ctx.onStreamDone = true } -func (ctx *lifecycleContext) OnDone() bool { - ctx.onDone = true - return true + +func (ctx *lifecycleContext) OnHttpStreamDone() { + ctx.onHttpStreamDone = true } func Test_onDone(t *testing.T) { @@ -85,61 +68,22 @@ func Test_onDone(t *testing.T) { defer currentStateMux.Unlock() currentState = &state{ - rootContexts: map[uint32]RootContext{}, - httpContexts: map[uint32]HttpContext{}, - streamContexts: map[uint32]StreamContext{}, + rootContexts: map[uint32]*rootContextState{}, + httpStreams: map[uint32]HttpContext{}, + streams: map[uint32]StreamContext{}, } var id uint32 = 1 ctx := &lifecycleContext{} - currentState.rootContexts[id] = ctx + currentState.httpStreams[id] = ctx proxyOnDone(id) - assert.True(t, ctx.onDone) + assert.True(t, ctx.onHttpStreamDone) assert.Equal(t, id, currentState.activeContextID) id = 2 ctx = &lifecycleContext{} - currentState.httpContexts[id] = ctx + currentState.streams[id] = ctx proxyOnDone(id) - assert.True(t, ctx.onDone) - assert.Equal(t, id, currentState.activeContextID) - - id = 3 - ctx = &lifecycleContext{} - currentState.rootContexts[id] = ctx - proxyOnDone(id) - assert.True(t, ctx.onDone) - assert.Equal(t, id, currentState.activeContextID) -} - -func Test_onLog(t *testing.T) { - currentStateMux.Lock() - defer currentStateMux.Unlock() - - currentState = &state{ - rootContexts: map[uint32]RootContext{}, - httpContexts: map[uint32]HttpContext{}, - streamContexts: map[uint32]StreamContext{}, - } - - var id uint32 = 1 - ctx := &lifecycleContext{} - currentState.rootContexts[id] = ctx - proxyOnLog(id) - assert.True(t, ctx.onLog) - assert.Equal(t, id, currentState.activeContextID) - - id = 2 - ctx = &lifecycleContext{} - currentState.httpContexts[id] = ctx - proxyOnLog(id) - assert.True(t, ctx.onLog) - assert.Equal(t, id, currentState.activeContextID) - - id = 3 - ctx = &lifecycleContext{} - currentState.rootContexts[id] = ctx - proxyOnLog(id) - assert.True(t, ctx.onLog) + assert.True(t, ctx.onStreamDone) assert.Equal(t, id, currentState.activeContextID) } diff --git a/proxywasm/abi_queue.go b/proxywasm/abi_queue.go index 1ed551fe..f136bbac 100644 --- a/proxywasm/abi_queue.go +++ b/proxywasm/abi_queue.go @@ -22,6 +22,5 @@ func proxyOnQueueReady(contextID, queueID uint32) { } currentState.setActiveContextID(contextID) - hostCallSetEffectiveContext(contextID) - ctx.OnQueueReady(queueID) + ctx.context.OnQueueReady(queueID) } diff --git a/proxywasm/abi_queue_test.go b/proxywasm/abi_queue_test.go index 2124aff5..5231ae72 100644 --- a/proxywasm/abi_queue_test.go +++ b/proxywasm/abi_queue_test.go @@ -9,7 +9,7 @@ import ( ) type queueContext struct { - DefaultContext + DefaultRootContext onQueueReady bool } @@ -22,8 +22,8 @@ func Test_queueReady(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - currentState = &state{rootContexts: map[uint32]RootContext{id: &queueContext{}}} - ctx, ok := currentState.rootContexts[id].(*queueContext) + currentState = &state{rootContexts: map[uint32]*rootContextState{id: {context: &queueContext{}}}} + ctx, ok := currentState.rootContexts[id].context.(*queueContext) require.True(t, ok) proxyOnQueueReady(id, 10) assert.True(t, ctx.onQueueReady) diff --git a/proxywasm/abi_test_export.go b/proxywasm/abi_test_export.go new file mode 100644 index 00000000..264eab57 --- /dev/null +++ b/proxywasm/abi_test_export.go @@ -0,0 +1,93 @@ +// Copyright 2020 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build proxytest + +package proxywasm + +import "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" + +// this file exists only for proxytest package which is used with the `proxytest` build tag. +// Therefore, these functions are not included in a resulting WASM binary + +func ProxyOnVMStart(rootContextID uint32, vmConfigurationSize int) bool { + return proxyOnVMStart(rootContextID, vmConfigurationSize) +} + +func ProxyOnConfigure(rootContextID uint32, vmConfigurationSize int) bool { + return proxyOnConfigure(rootContextID, vmConfigurationSize) +} +func ProxyOnNewConnection(contextID uint32) types.Action { + return proxyOnNewConnection(contextID) +} + +func ProxyOnDownstreamData(contextID uint32, dataSize int, endOfStream bool) types.Action { + return proxyOnDownstreamData(contextID, dataSize, endOfStream) +} + +func ProxyOnDownstreamConnectionClose(contextID uint32, pType types.PeerType) { + proxyOnDownstreamConnectionClose(contextID, pType) +} + +func ProxyOnUpstreamData(contextID uint32, dataSize int, endOfStream bool) types.Action { + return proxyOnUpstreamData(contextID, dataSize, endOfStream) +} + +func ProxyOnUpstreamConnectionClose(contextID uint32, pType types.PeerType) { + proxyOnUpstreamConnectionClose(contextID, pType) +} + +func ProxyOnRequestHeaders(contextID uint32, numHeaders int, endOfStream bool) types.Action { + return proxyOnRequestHeaders(contextID, numHeaders, endOfStream) +} + +func ProxyOnRequestBody(contextID uint32, bodySize int, endOfStream bool) types.Action { + return proxyOnRequestBody(contextID, bodySize, endOfStream) +} + +func ProxyOnRequestTrailers(contextID uint32, numTrailers int) types.Action { + return proxyOnRequestTrailers(contextID, numTrailers) +} + +func ProxyOnResponseHeaders(contextID uint32, numHeaders int, endOfStream bool) types.Action { + return proxyOnResponseHeaders(contextID, numHeaders, endOfStream) +} + +func ProxyOnResponseBody(contextID uint32, bodySize int, endOfStream bool) types.Action { + return proxyOnResponseBody(contextID, bodySize, endOfStream) +} + +func ProxyOnResponseTrailers(contextID uint32, numTrailers int) types.Action { + return proxyOnResponseTrailers(contextID, numTrailers) +} + +func ProxyOnHttpCallResponse(rootContextID, calloutID uint32, numHeaders, bodySize, numTrailers int) { + proxyOnHttpCallResponse(rootContextID, calloutID, numHeaders, bodySize, numTrailers) +} + +func ProxyOnContextCreate(contextID uint32, rootContextID uint32) { + proxyOnContextCreate(contextID, rootContextID) +} + +func ProxyOnDone(contextID uint32) bool { + return proxyOnDone(contextID) +} + +func ProxyOnQueueReady(contextID, queueID uint32) { + proxyOnQueueReady(contextID, queueID) +} + +func ProxyOnTick(rootContextID uint32) { + proxyOnTick(rootContextID) +} diff --git a/proxywasm/abi_timers.go b/proxywasm/abi_timers.go index 5589c5cf..208b5173 100644 --- a/proxywasm/abi_timers.go +++ b/proxywasm/abi_timers.go @@ -20,6 +20,5 @@ func proxyOnTick(rootContextID uint32) { if !ok { panic("invalid root_context_id") } - currentState.setActiveContextID(rootContextID) - ctx.OnTick() + ctx.context.OnTick() } diff --git a/proxywasm/abi_timers_test.go b/proxywasm/abi_timers_test.go index 62840eb3..fc65c715 100644 --- a/proxywasm/abi_timers_test.go +++ b/proxywasm/abi_timers_test.go @@ -8,7 +8,7 @@ import ( ) type timerContext struct { - DefaultContext + DefaultRootContext onTick bool } @@ -21,8 +21,8 @@ func Test_onTick(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - currentState = &state{rootContexts: map[uint32]RootContext{id: &timerContext{}}} - ctx, ok := currentState.rootContexts[id].(*timerContext) + currentState = &state{rootContexts: map[uint32]*rootContextState{id: {context: &timerContext{}}}} + ctx, ok := currentState.rootContexts[id].context.(*timerContext) require.True(t, ok) proxyOnTick(id) assert.True(t, ctx.onTick) diff --git a/proxywasm/context.go b/proxywasm/context.go index f871bcc9..1005fbe2 100644 --- a/proxywasm/context.go +++ b/proxywasm/context.go @@ -18,70 +18,65 @@ import ( "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) -type Context interface { - OnDone() bool - OnHttpCallResponse(numHeaders, bodySize, numTrailers int) - OnLog() -} - type RootContext interface { - Context - OnConfigure(pluginConfigurationSize int) bool OnQueueReady(queueID uint32) OnTick() OnVMStart(vmConfigurationSize int) bool + OnPluginStart(pluginConfigurationSize int) bool + OnVMDone() bool } type StreamContext interface { - Context OnDownstreamData(dataSize int, endOfStream bool) types.Action OnDownstreamClose(peerType types.PeerType) OnNewConnection() types.Action OnUpstreamData(dataSize int, endOfStream bool) types.Action OnUpstreamClose(peerType types.PeerType) + OnStreamDone() } type HttpContext interface { - Context OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action OnHttpRequestBody(bodySize int, endOfStream bool) types.Action OnHttpRequestTrailers(numTrailers int) types.Action OnHttpResponseHeaders(numHeaders int, endOfStream bool) types.Action OnHttpResponseBody(bodySize int, endOfStream bool) types.Action OnHttpResponseTrailers(numTrailers int) types.Action + OnHttpStreamDone() } -type DefaultContext struct{} +type ( + DefaultRootContext struct{} + DefaultStreamContext struct{} + DefaultHttpContext struct{} +) var ( - _ Context = DefaultContext{} - _ RootContext = DefaultContext{} - _ StreamContext = DefaultContext{} - _ HttpContext = DefaultContext{} + _ RootContext = &DefaultRootContext{} + _ StreamContext = &DefaultStreamContext{} + _ HttpContext = &DefaultHttpContext{} ) -// impl Context -func (d DefaultContext) OnDone() bool { return true } -func (d DefaultContext) OnHttpCallResponse(int, int, int) {} -func (d DefaultContext) OnLog() {} - // impl RootContext -func (d DefaultContext) OnConfigure(int) bool { return true } -func (d DefaultContext) OnQueueReady(uint32) {} -func (d DefaultContext) OnTick() {} -func (d DefaultContext) OnVMStart(int) bool { return true } +func (*DefaultRootContext) OnQueueReady(uint32) {} +func (*DefaultRootContext) OnTick() {} +func (*DefaultRootContext) OnVMStart(int) bool { return true } +func (*DefaultRootContext) OnPluginStart(int) bool { return true } +func (*DefaultRootContext) OnVMDone() bool { return true } // impl StreamContext -func (d DefaultContext) OnDownstreamData(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnDownstreamClose(types.PeerType) {} -func (d DefaultContext) OnNewConnection() types.Action { return types.ActionContinue } -func (d DefaultContext) OnUpstreamData(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnUpstreamClose(types.PeerType) {} +func (*DefaultStreamContext) OnDownstreamData(int, bool) types.Action { return types.ActionContinue } +func (*DefaultStreamContext) OnDownstreamClose(types.PeerType) {} +func (*DefaultStreamContext) OnNewConnection() types.Action { return types.ActionContinue } +func (*DefaultStreamContext) OnUpstreamData(int, bool) types.Action { return types.ActionContinue } +func (*DefaultStreamContext) OnUpstreamClose(types.PeerType) {} +func (*DefaultStreamContext) OnStreamDone() {} // impl HttpContext -func (d DefaultContext) OnHttpRequestHeaders(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnHttpRequestBody(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnHttpRequestTrailers(int) types.Action { return types.ActionContinue } -func (d DefaultContext) OnHttpResponseHeaders(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnHttpResponseBody(int, bool) types.Action { return types.ActionContinue } -func (d DefaultContext) OnHttpResponseTrailers(int) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpRequestHeaders(int, bool) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpRequestBody(int, bool) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpRequestTrailers(int) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpResponseHeaders(int, bool) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpResponseBody(int, bool) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpResponseTrailers(int) types.Action { return types.ActionContinue } +func (*DefaultHttpContext) OnHttpStreamDone() {} diff --git a/proxywasm/hostcall.go b/proxywasm/hostcall.go index 91b2fcd6..ed48f2dd 100644 --- a/proxywasm/hostcall.go +++ b/proxywasm/hostcall.go @@ -15,25 +15,23 @@ package proxywasm import ( - "strconv" - "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) // wrappers on the rawhostcall package -func HostCallGetPluginConfiguration(dataSize int) ([]byte, error) { +func GetPluginConfiguration(dataSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypePluginConfiguration, 0, dataSize) return ret, types.StatusToError(st) } -func HostCallGetVMConfiguration(dataSize int) ([]byte, error) { +func GetVMConfiguration(dataSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeVMConfiguration, 0, dataSize) return ret, types.StatusToError(st) } -func HostCallSendHttpResponse(statusCode uint32, headers [][2]string, body string) types.Status { +func SendHttpResponse(statusCode uint32, headers [][2]string, body string) types.Status { shs := SerializeMap(headers) hp := &shs[0] hl := len(shs) @@ -42,22 +40,19 @@ func HostCallSendHttpResponse(statusCode uint32, headers [][2]string, body strin ) } -func hostCallSetEffectiveContext(contextID uint32) types.Status { - return rawhostcall.ProxySetEffectiveContext(contextID) -} - -func HostCallSetTickPeriodMilliSeconds(millSec uint32) error { +func SetTickPeriodMilliSeconds(millSec uint32) error { return types.StatusToError(rawhostcall.ProxySetTickPeriodMilliseconds(millSec)) } -func HostCallGetCurrentTime() int64 { +func GetCurrentTime() int64 { var t int64 rawhostcall.ProxyGetCurrentTimeNanoseconds(&t) return t } -func HostCallDispatchHttpCall(upstream string, - headers [][2]string, body string, trailers [][2]string, timeoutMillisecond uint32) (uint32, error) { +func DispatchHttpCall(upstream string, + headers [][2]string, body string, trailers [][2]string, + timeoutMillisecond uint32, callBack HttpCalloutCallBack) (calloutID uint32, err error) { shs := SerializeMap(headers) hp := &shs[0] hl := len(shs) @@ -66,176 +61,165 @@ func HostCallDispatchHttpCall(upstream string, tp := &sts[0] tl := len(sts) - var calloutID uint32 - u := stringBytePtr(upstream) switch st := rawhostcall.ProxyHttpCall(u, len(upstream), hp, hl, stringBytePtr(body), len(body), tp, tl, timeoutMillisecond, &calloutID); st { case types.StatusOK: - currentState.registerCallout(calloutID) + currentState.registerHttpCallOut(calloutID, callBack) return calloutID, nil default: return 0, types.StatusToError(st) } } -func HostCallGetHttpCallResponseHeaders() ([][2]string, error) { +func GetHttpCallResponseHeaders() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpCallResponseHeaders) return ret, types.StatusToError(st) } -func HostCallGetHttpCallResponseBody(start, maxSize int) ([]byte, error) { +func GetHttpCallResponseBody(start, maxSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeHttpCallResponseBody, start, maxSize) return ret, types.StatusToError(st) } -func HostCallGetHttpCallResponseTrailers() ([][2]string, error) { +func GetHttpCallResponseTrailers() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpCallResponseTrailers) return ret, types.StatusToError(st) } -func HostCallDone() { - switch st := rawhostcall.ProxyDone(); st { - case types.StatusOK: - return - default: - panic("unexpected status on proxy_done: " + strconv.FormatUint(uint64(st), 10)) - } -} - -func HostCallGetDownStreamData(start, maxSize int) ([]byte, error) { +func GetDownStreamData(start, maxSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeDownstreamData, start, maxSize) return ret, types.StatusToError(st) } -func HostCallGetUpstreamData(start, maxSize int) ([]byte, error) { +func GetUpstreamData(start, maxSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeUpstreamData, start, maxSize) return ret, types.StatusToError(st) } -func HostCallGetHttpRequestHeaders() ([][2]string, error) { +func GetHttpRequestHeaders() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpRequestHeaders) return ret, types.StatusToError(st) } -func HostCallSetHttpRequestHeaders(headers [][2]string) error { +func SetHttpRequestHeaders(headers [][2]string) error { return types.StatusToError(setMap(types.MapTypeHttpRequestHeaders, headers)) } -func HostCallGetHttpRequestHeader(key string) (string, error) { +func GetHttpRequestHeader(key string) (string, error) { ret, st := getMapValue(types.MapTypeHttpRequestHeaders, key) return ret, types.StatusToError(st) } -func HostCallRemoveHttpRequestHeader(key string) error { +func RemoveHttpRequestHeader(key string) error { return types.StatusToError(removeMapValue(types.MapTypeHttpRequestHeaders, key)) } -func HostCallSetHttpRequestHeader(key, value string) error { +func SetHttpRequestHeader(key, value string) error { return types.StatusToError(setMapValue(types.MapTypeHttpRequestHeaders, key, value)) } -func HostCallAddHttpRequestHeader(key, value string) error { +func AddHttpRequestHeader(key, value string) error { return types.StatusToError(addMapValue(types.MapTypeHttpRequestHeaders, key, value)) } -func HostCallGetHttpRequestBody(start, maxSize int) ([]byte, error) { +func GetHttpRequestBody(start, maxSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeHttpRequestBody, start, maxSize) return ret, types.StatusToError(st) } -func HostCallGetHttpRequestTrailers() ([][2]string, error) { +func GetHttpRequestTrailers() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpRequestTrailers) return ret, types.StatusToError(st) } -func HostCallSetHttpRequestTrailers(headers [][2]string) error { +func SetHttpRequestTrailers(headers [][2]string) error { return types.StatusToError(setMap(types.MapTypeHttpRequestTrailers, headers)) } -func HostCallGetHttpRequestTrailer(key string) (string, error) { +func GetHttpRequestTrailer(key string) (string, error) { ret, st := getMapValue(types.MapTypeHttpRequestTrailers, key) return ret, types.StatusToError(st) } -func HostCallRemoveHttpRequestTrailer(key string) error { +func RemoveHttpRequestTrailer(key string) error { return types.StatusToError(removeMapValue(types.MapTypeHttpRequestTrailers, key)) } -func HostCallSetHttpRequestTrailer(key, value string) error { +func SetHttpRequestTrailer(key, value string) error { return types.StatusToError(setMapValue(types.MapTypeHttpRequestTrailers, key, value)) } -func HostCallAddHttpRequestTrailer(key, value string) error { +func AddHttpRequestTrailer(key, value string) error { return types.StatusToError(addMapValue(types.MapTypeHttpRequestTrailers, key, value)) } -func HostCallResumeHttpRequest() error { +func ResumeHttpRequest() error { return types.StatusToError(rawhostcall.ProxyContinueStream(types.StreamTypeRequest)) } -func HostCallGetHttpResponseHeaders() ([][2]string, error) { +func GetHttpResponseHeaders() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpResponseHeaders) return ret, types.StatusToError(st) } -func HostCallSetHttpResponseHeaders(headers [][2]string) error { +func SetHttpResponseHeaders(headers [][2]string) error { return types.StatusToError(setMap(types.MapTypeHttpResponseHeaders, headers)) } -func HostCallGetHttpResponseHeader(key string) (string, error) { +func GetHttpResponseHeader(key string) (string, error) { ret, st := getMapValue(types.MapTypeHttpResponseHeaders, key) return ret, types.StatusToError(st) } -func HostCallRemoveHttpResponseHeader(key string) error { +func RemoveHttpResponseHeader(key string) error { return types.StatusToError(removeMapValue(types.MapTypeHttpResponseHeaders, key)) } -func HostCallSetHttpResponseHeader(key, value string) error { +func SetHttpResponseHeader(key, value string) error { return types.StatusToError(setMapValue(types.MapTypeHttpResponseHeaders, key, value)) } -func HostCallAddHttpResponseHeader(key, value string) error { +func AddHttpResponseHeader(key, value string) error { return types.StatusToError(addMapValue(types.MapTypeHttpResponseHeaders, key, value)) } -func HostCallGetHttpResponseBody(start, maxSize int) ([]byte, error) { +func GetHttpResponseBody(start, maxSize int) ([]byte, error) { ret, st := getBuffer(types.BufferTypeHttpResponseBody, start, maxSize) return ret, types.StatusToError(st) } -func HostCallGetHttpResponseTrailers() ([][2]string, error) { +func GetHttpResponseTrailers() ([][2]string, error) { ret, st := getMap(types.MapTypeHttpResponseTrailers) return ret, types.StatusToError(st) } -func HostCallSetHttpResponseTrailers(headers [][2]string) error { +func SetHttpResponseTrailers(headers [][2]string) error { return types.StatusToError(setMap(types.MapTypeHttpResponseTrailers, headers)) } -func HostCallGetHttpResponseTrailer(key string) (string, error) { +func GetHttpResponseTrailer(key string) (string, error) { ret, st := getMapValue(types.MapTypeHttpResponseTrailers, key) return ret, types.StatusToError(st) } -func HostCallRemoveHttpResponseTrailer(key string) error { +func RemoveHttpResponseTrailer(key string) error { return types.StatusToError(removeMapValue(types.MapTypeHttpResponseTrailers, key)) } -func HostCallSetHttpResponseTrailer(key, value string) error { +func SetHttpResponseTrailer(key, value string) error { return types.StatusToError(setMapValue(types.MapTypeHttpResponseTrailers, key, value)) } -func HostCallAddHttpResponseTrailer(key, value string) error { +func AddHttpResponseTrailer(key, value string) error { return types.StatusToError(addMapValue(types.MapTypeHttpResponseTrailers, key, value)) } -func HostCallResumeHttpResponse() error { +func ResumeHttpResponse() error { return types.StatusToError(rawhostcall.ProxyContinueStream(types.StreamTypeResponse)) } -func HostCallRegisterSharedQueue(name string) (uint32, error) { +func RegisterSharedQueue(name string) (uint32, error) { var queueID uint32 ptr := stringBytePtr(name) st := rawhostcall.ProxyRegisterSharedQueue(ptr, len(name), &queueID) @@ -243,14 +227,14 @@ func HostCallRegisterSharedQueue(name string) (uint32, error) { } // TODO: not sure if the ABI is correct -func HostCallResolveSharedQueue(vmID, queueName string) (uint32, error) { +func ResolveSharedQueue(vmID, queueName string) (uint32, error) { var ret uint32 st := rawhostcall.ProxyResolveSharedQueue(stringBytePtr(vmID), len(vmID), stringBytePtr(queueName), len(queueName), &ret) return ret, types.StatusToError(st) } -func HostCallDequeueSharedQueue(queueID uint32) ([]byte, error) { +func DequeueSharedQueue(queueID uint32) ([]byte, error) { var raw *byte var size int st := rawhostcall.ProxyDequeueSharedQueue(queueID, &raw, &size) @@ -260,11 +244,11 @@ func HostCallDequeueSharedQueue(queueID uint32) ([]byte, error) { return RawBytePtrToByteSlice(raw, size), nil } -func HostCallEnqueueSharedQueue(queueID uint32, data []byte) error { +func EnqueueSharedQueue(queueID uint32, data []byte) error { return types.StatusToError(rawhostcall.ProxyEnqueueSharedQueue(queueID, &data[0], len(data))) } -func HostCallGetSharedData(key string) (value []byte, cas uint32, err error) { +func GetSharedData(key string) (value []byte, cas uint32, err error) { var raw *byte var size int @@ -275,13 +259,13 @@ func HostCallGetSharedData(key string) (value []byte, cas uint32, err error) { return RawBytePtrToByteSlice(raw, size), cas, nil } -func HostCallSetSharedData(key string, data []byte, cas uint32) error { +func SetSharedData(key string, data []byte, cas uint32) error { st := rawhostcall.ProxySetSharedData(stringBytePtr(key), len(key), &data[0], len(data), cas) return types.StatusToError(st) } -func HostCallGetProperty(path []string) ([]byte, error) { +func GetProperty(path []string) ([]byte, error) { var ret *byte var retSize int raw := SerializePropertyPath(path) @@ -295,7 +279,7 @@ func HostCallGetProperty(path []string) ([]byte, error) { } -func HostCallSetProperty(path string, data []byte) error { +func SetProperty(path string, data []byte) error { return types.StatusToError(rawhostcall.ProxySetProperty( stringBytePtr(path), len(path), &data[0], len(data), )) diff --git a/proxywasm/hostcall_lifecycle.go b/proxywasm/hostcall_lifecycle.go new file mode 100644 index 00000000..ac06b516 --- /dev/null +++ b/proxywasm/hostcall_lifecycle.go @@ -0,0 +1,11 @@ +package proxywasm + +import "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" + +func SetEffectiveContext(contextID uint32) { + rawhostcall.ProxySetEffectiveContext(contextID) +} + +func FinishContext() { + rawhostcall.ProxyDone() +} diff --git a/proxywasm/hostcall_metric.go b/proxywasm/hostcall_metric.go index a7c0b7a4..4da35bb1 100644 --- a/proxywasm/hostcall_metric.go +++ b/proxywasm/hostcall_metric.go @@ -27,69 +27,100 @@ type ( // counter -func DefineCounterMetric(name string) (MetricCounter, error) { +func DefineCounterMetric(name string) MetricCounter { var id uint32 ptr := stringBytePtr(name) st := rawhostcall.ProxyDefineMetric(types.MetricTypeCounter, ptr, len(name), &id) - return MetricCounter(id), types.StatusToError(st) + if err := types.StatusToError(st); err != nil { + LogCriticalf("define metric of name %s: %v", name, types.StatusToError(st)) + } + return MetricCounter(id) } func (m MetricCounter) ID() uint32 { return uint32(m) } -func (m MetricCounter) Get() (uint64, error) { +func (m MetricCounter) Get() uint64 { var val uint64 st := rawhostcall.ProxyGetMetric(m.ID(), &val) - return val, types.StatusToError(st) + if err := types.StatusToError(st); err != nil { + LogCriticalf("get metric of %d: %v", m.ID(), types.StatusToError(st)) + panic("") // abort + } + return val } -func (m MetricCounter) Increment(offset uint64) error { - return types.StatusToError(rawhostcall.ProxyIncrementMetric(m.ID(), int64(offset))) +func (m MetricCounter) Increment(offset uint64) { + if err := types.StatusToError(rawhostcall.ProxyIncrementMetric(m.ID(), int64(offset))); err != nil { + LogCriticalf("increment %d by %d: %v", m.ID(), offset, err) + panic("") // abort + } } // gauge -func DefineGaugeMetric(name string) (MetricGauge, error) { +func DefineGaugeMetric(name string) MetricGauge { var id uint32 ptr := stringBytePtr(name) st := rawhostcall.ProxyDefineMetric(types.MetricTypeGauge, ptr, len(name), &id) - return MetricGauge(id), types.StatusToError(st) + if err := types.StatusToError(st); err != nil { + LogCriticalf("error define metric of name %s: %v", name, types.StatusToError(st)) + panic("") // abort + } + return MetricGauge(id) } func (m MetricGauge) ID() uint32 { return uint32(m) } -func (m MetricGauge) Get() (int64, error) { +func (m MetricGauge) Get() int64 { var val uint64 - st := rawhostcall.ProxyGetMetric(m.ID(), &val) - return int64(val), types.StatusToError(st) + if err := types.StatusToError(rawhostcall.ProxyGetMetric(m.ID(), &val)); err != nil { + LogCriticalf("get metric of %d: %v", m.ID(), err) + panic("") // abort + } + return int64(val) } -func (m MetricGauge) Add(offset int64) error { - return types.StatusToError(rawhostcall.ProxyIncrementMetric(m.ID(), offset)) +func (m MetricGauge) Add(offset int64) { + if err := types.StatusToError(rawhostcall.ProxyIncrementMetric(m.ID(), offset)); err != nil { + LogCriticalf("error adding %d by %d: %v", m.ID(), offset, err) + panic("") // abort + } } // histogram -func DefineHistogramMetric(name string) (MetricHistogram, error) { +func DefineHistogramMetric(name string) MetricHistogram { var id uint32 ptr := stringBytePtr(name) st := rawhostcall.ProxyDefineMetric(types.MetricTypeHistogram, ptr, len(name), &id) - return MetricHistogram(id), types.StatusToError(st) + if err := types.StatusToError(st); err != nil { + LogCriticalf("error define metric of name %s: %v", name, types.StatusToError(st)) + panic("") // abort + } + return MetricHistogram(id) } func (m MetricHistogram) ID() uint32 { return uint32(m) } -func (m MetricHistogram) Get() (uint64, error) { +func (m MetricHistogram) Get() uint64 { var val uint64 st := rawhostcall.ProxyGetMetric(m.ID(), &val) - return val, types.StatusToError(st) + if err := types.StatusToError(st); err != nil { + LogCriticalf("get metric of %d: %v", m.ID(), types.StatusToError(st)) + panic("") // abort + } + return val } -func (m MetricHistogram) Record(value uint64) error { - return types.StatusToError(rawhostcall.ProxyRecordMetric(m.ID(), value)) +func (m MetricHistogram) Record(value uint64) { + if err := types.StatusToError(rawhostcall.ProxyRecordMetric(m.ID(), value)); err != nil { + LogCriticalf("error adding %d: %v", m.ID(), err) + panic("") // abort + } } diff --git a/proxywasm/hostcall_metric_test.go b/proxywasm/hostcall_metric_test.go index 789d0c85..e6a4c7e4 100644 --- a/proxywasm/hostcall_metric_test.go +++ b/proxywasm/hostcall_metric_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/rawhostcall" "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types" ) @@ -78,16 +77,13 @@ func TestHostCall_Metric(t *testing.T) { } { t.Run(c.name, func(t *testing.T) { // define metric - m, err := DefineCounterMetric(c.name) - require.NoError(t, err) + m := DefineCounterMetric(c.name) // increment - require.NoError(t, m.Increment(c.offset)) + m.Increment(c.offset) // get - value, err := m.Get() - require.NoError(t, err) - assert.Equal(t, c.offset, value) + assert.Equal(t, c.offset, m.Get()) }) } }) @@ -101,16 +97,13 @@ func TestHostCall_Metric(t *testing.T) { } { t.Run(c.name, func(t *testing.T) { // define metric - m, err := DefineGaugeMetric(c.name) - require.NoError(t, err) + m := DefineGaugeMetric(c.name) // increment - require.NoError(t, m.Add(c.offset)) + m.Add(c.offset) // get - value, err := m.Get() - require.NoError(t, err) - assert.Equal(t, c.offset, value) + assert.Equal(t, c.offset, m.Get()) }) } }) @@ -124,16 +117,13 @@ func TestHostCall_Metric(t *testing.T) { } { t.Run(c.name, func(t *testing.T) { // define metric - m, err := DefineHistogramMetric(c.name) - require.NoError(t, err) + m := DefineHistogramMetric(c.name) // record - require.NoError(t, m.Record(c.value)) + m.Record(c.value) // get - value, err := m.Get() - require.NoError(t, err) - assert.Equal(t, c.value, value) + assert.Equal(t, c.value, m.Get()) }) } }) diff --git a/proxywasm/types/types.go b/proxywasm/types/types.go index f07325ae..478e2544 100644 --- a/proxywasm/types/types.go +++ b/proxywasm/types/types.go @@ -41,6 +41,25 @@ const ( LogLevelMax LogLevel = 6 ) +func (l LogLevel) String() string { + switch l { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelCritical: + return "critical" + default: + panic("invalid log level") + } +} + type Status uint32 const ( diff --git a/proxywasm/vmstate.go b/proxywasm/vmstate.go index f80ad46c..f4af67ec 100644 --- a/proxywasm/vmstate.go +++ b/proxywasm/vmstate.go @@ -14,22 +14,35 @@ package proxywasm -var currentState = &state{ - rootContexts: make(map[uint32]RootContext), - httpContexts: make(map[uint32]HttpContext), - streamContexts: make(map[uint32]StreamContext), - callOuts: make(map[uint32]uint32), -} +type ( + HttpCalloutCallBack = func(numHeaders, bodySize, numTrailers int) + + rootContextState struct { + context RootContext + httpCallbacks map[uint32]*struct { + callback HttpCalloutCallBack + callerContextID uint32 + } + } +) type state struct { newRootContext func(contextID uint32) RootContext + rootContexts map[uint32]*rootContextState newStreamContext func(contextID uint32) StreamContext + streams map[uint32]StreamContext newHttpContext func(contextID uint32) HttpContext - rootContexts map[uint32]RootContext - httpContexts map[uint32]HttpContext - streamContexts map[uint32]StreamContext + httpStreams map[uint32]HttpContext + + contextIDToRooID map[uint32]uint32 activeContextID uint32 - callOuts map[uint32]uint32 +} + +var currentState = &state{ + rootContexts: make(map[uint32]*rootContextState), + httpStreams: make(map[uint32]HttpContext), + streams: make(map[uint32]StreamContext), + contextIDToRooID: make(map[uint32]uint32), } func SetNewRootContext(f func(contextID uint32) RootContext) { @@ -44,15 +57,22 @@ func SetNewStreamContext(f func(contextID uint32) StreamContext) { currentState.newStreamContext = f } +//go:inline func (s *state) createRootContext(contextID uint32) { var ctx RootContext if s.newRootContext == nil { - ctx = &DefaultContext{} + ctx = &DefaultRootContext{} } else { ctx = s.newRootContext(contextID) } - s.rootContexts[contextID] = ctx + s.rootContexts[contextID] = &rootContextState{ + context: ctx, + httpCallbacks: map[uint32]*struct { + callback HttpCalloutCallBack + callerContextID uint32 + }{}, + } } func (s *state) createStreamContext(contextID uint32, rootContextID uint32) { @@ -60,11 +80,13 @@ func (s *state) createStreamContext(contextID uint32, rootContextID uint32) { panic("invalid root context id") } - if _, ok := s.streamContexts[contextID]; ok { + if _, ok := s.streams[contextID]; ok { panic("context id duplicated") } - s.streamContexts[contextID] = s.newStreamContext(contextID) + ctx := s.newStreamContext(contextID) + s.contextIDToRooID[contextID] = rootContextID + s.streams[contextID] = ctx } func (s *state) createHttpContext(contextID uint32, rootContextID uint32) { @@ -72,22 +94,24 @@ func (s *state) createHttpContext(contextID uint32, rootContextID uint32) { panic("invalid root context id") } - if _, ok := s.httpContexts[contextID]; ok { + if _, ok := s.httpStreams[contextID]; ok { panic("context id duplicated") } - s.httpContexts[contextID] = s.newHttpContext(contextID) + ctx := s.newHttpContext(contextID) + s.contextIDToRooID[contextID] = rootContextID + s.httpStreams[contextID] = ctx } -func (s *state) registerCallout(calloutID uint32) { - if _, ok := s.callOuts[calloutID]; ok { - panic("duplicated calloutID") - } - - s.callOuts[calloutID] = s.activeContextID +func (s *state) registerHttpCallOut(calloutID uint32, callback HttpCalloutCallBack) { + r := s.rootContexts[s.contextIDToRooID[s.activeContextID]] + r.httpCallbacks[calloutID] = &struct { + callback HttpCalloutCallBack + callerContextID uint32 + }{callback: callback, callerContextID: s.activeContextID} } +//go:inline func (s *state) setActiveContextID(contextID uint32) { - // TODO: should we do this inline (possibly for performance)? s.activeContextID = contextID } diff --git a/proxywasm/vmstate_test.go b/proxywasm/vmstate_test.go index cd42d8e1..8645be81 100644 --- a/proxywasm/vmstate_test.go +++ b/proxywasm/vmstate_test.go @@ -54,9 +54,9 @@ func TestSetNewStreamContext(t *testing.T) { func TestState_createRootContext(t *testing.T) { t.Run("newRootContext exists", func(t *testing.T) { - type rc struct{ DefaultContext } + type rc struct{ DefaultRootContext } s := &state{ - rootContexts: map[uint32]RootContext{}, + rootContexts: map[uint32]*rootContextState{}, newRootContext: func(contextID uint32) RootContext { return &rc{} }, } @@ -66,67 +66,55 @@ func TestState_createRootContext(t *testing.T) { }) t.Run("non exists", func(t *testing.T) { - s := &state{rootContexts: map[uint32]RootContext{}} + s := &state{rootContexts: map[uint32]*rootContextState{}} var cid uint32 = 100 s.createRootContext(cid) c, ok := s.rootContexts[cid] require.True(t, ok) - _, ok = c.(*DefaultContext) + _, ok = c.context.(*DefaultRootContext) assert.True(t, ok) }) } func TestState_createStreamContext(t *testing.T) { - type sc struct{ DefaultContext } + type sc struct{ DefaultStreamContext } var ( cid uint32 = 100 rid uint32 = 10 ) s := &state{ - rootContexts: map[uint32]RootContext{rid: nil}, - streamContexts: map[uint32]StreamContext{}, + rootContexts: map[uint32]*rootContextState{rid: nil}, + streams: map[uint32]StreamContext{}, newStreamContext: func(contextID uint32) StreamContext { return &sc{} }, + contextIDToRooID: map[uint32]uint32{}, } s.createStreamContext(cid, rid) - c, ok := s.streamContexts[cid] + c, ok := s.streams[cid] require.True(t, ok) _, ok = c.(*sc) assert.True(t, ok) } func TestState_createHttpContext(t *testing.T) { - type hc struct{ DefaultContext } + type hc struct{ DefaultHttpContext } var ( cid uint32 = 100 rid uint32 = 10 ) s := &state{ - rootContexts: map[uint32]RootContext{rid: nil}, - httpContexts: map[uint32]HttpContext{}, - newHttpContext: func(contextID uint32) HttpContext { return &hc{} }, + rootContexts: map[uint32]*rootContextState{rid: nil}, + httpStreams: map[uint32]HttpContext{}, + newHttpContext: func(contextID uint32) HttpContext { return &hc{} }, + contextIDToRooID: map[uint32]uint32{}, } s.createHttpContext(cid, rid) - c, ok := s.httpContexts[cid] + c, ok := s.httpStreams[cid] require.True(t, ok) _, ok = c.(*hc) assert.True(t, ok) } - -func TestState_registerCallout(t *testing.T) { - var calloutID uint32 = 100 - s := &state{callOuts: map[uint32]uint32{}, activeContextID: 200} - s.registerCallout(calloutID) - assert.Equal(t, s.callOuts[calloutID], s.activeContextID) -} - -func TestState_setActiveContextID(t *testing.T) { - s := state{} - var cID uint32 = 100 - s.setActiveContextID(cID) - assert.Equal(t, s.activeContextID, cID) -} diff --git a/proxytest/root_test.go b/proxywasm/vmstate_test_export.go similarity index 56% rename from proxytest/root_test.go rename to proxywasm/vmstate_test_export.go index cd0d75c4..4cc0aed0 100644 --- a/proxytest/root_test.go +++ b/proxywasm/vmstate_test_export.go @@ -12,6 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package proxytest +// +build proxytest -// TODO: +package proxywasm + +func VMStateReset() { + // (@mathetake) I assume that the currentState be protected by lock on hostMux + currentState = &state{ + rootContexts: make(map[uint32]*rootContextState), + httpStreams: make(map[uint32]HttpContext), + streams: make(map[uint32]StreamContext), + contextIDToRooID: make(map[uint32]uint32), + } +} + +func VMStateGetActiveContextID() uint32 { + return currentState.activeContextID +}