diff --git a/context-aware-encryption/README.md b/context-aware-encryption/README.md new file mode 100644 index 00000000..43d5fa4d --- /dev/null +++ b/context-aware-encryption/README.md @@ -0,0 +1,25 @@ +### Steps to run this sample: +1) Run a [Temporal service](https://github.com/temporalio/samples-go/tree/main/#how-to-use). +2) Run the following command to start the remote codec server +``` +go run ./codec-server +``` +3) Run the following command to start the worker +``` +go run worker/main.go +``` +4) Run the following command to start the example +``` +go run starter/main.go +``` +5) Run the following command and see the payloads cannot be decoded +``` +tctl workflow show --wid encryption_workflowID +``` +6) Run the following command and see the decoded payloads +``` +tctl --codec_endpoint 'http://localhost:8081/' workflow show --wid encryption_workflowID +``` + +Note: The codec server provided in this sample does not support decoding payloads for the Temporal Web UI, only tctl. +Please see the [codec-server](../codec-server/) sample for a more complete example of a codec server which provides UI decoding and oauth. diff --git a/context-aware-encryption/app/main.go b/context-aware-encryption/app/main.go new file mode 100644 index 00000000..f012daca --- /dev/null +++ b/context-aware-encryption/app/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + contextawareencryption "github.com/temporalio/samples-go/context-aware-encryption" + sdkclient "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "golang.org/x/sync/errgroup" + "log" + "os" + "os/signal" + "syscall" + "time" +) + +type startable interface { + Start(context.Context) error + Shutdown(context.Context) +} + +func main() { + ctx, done := context.WithCancel(context.Background()) + + c := contextawareencryption.MustGetDefaultTemporalClient(ctx, nil) + defer c.Close() + g, ctx := errgroup.WithContext(ctx) + + // set up signal listener + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(quit) + + startables := []startable{ + &Worker{tclient: c}, + &App{tclient: c, maxCount: len(contextawareencryption.TenantKeysByOrganization)}, + } + + for _, s := range startables { + var current = s + g.Go(func() error { + if err := current.Start(ctx); err != nil { + return err + } + return nil + }) + } + + select { + case <-quit: + break + case <-ctx.Done(): + break + } + + // shutdown the things + done() + // limit how long we'll wait for + timeoutCtx, timeoutCancel := context.WithTimeout( + context.Background(), + 10*time.Second, + ) + defer timeoutCancel() + + for _, s := range startables { + s.Shutdown(timeoutCtx) + } + + // wait for shutdown + if err := g.Wait(); err != nil { + panic("shutdown was not clean" + err.Error()) + } +} + +type Worker struct { + tclient sdkclient.Client + worker worker.Worker +} + +func (w *Worker) Start(ctx context.Context) error { + w.worker = worker.New(w.tclient, "encryption", worker.Options{}) + + w.worker.RegisterWorkflow(contextawareencryption.TenantWorkflow) + w.worker.RegisterActivity(contextawareencryption.TenantActivity) + + return w.worker.Run(worker.InterruptCh()) +} +func (w *Worker) Shutdown(ctx context.Context) { + w.worker.Stop() +} + +type App struct { + tclient sdkclient.Client + maxCount int +} + +func (a *App) Shutdown(ctx context.Context) { + +} +func (a *App) Start(ctx context.Context) error { + if a.maxCount == 0 { + return fmt.Errorf("You must at least one run Workflow") + } + dt := time.Now().UTC().String() + count := 0 + for tenant, keyId := range contextawareencryption.TenantKeysByOrganization { + wid := fmt.Sprintf("tenant_%s-%s", tenant, dt) + workflowOptions := sdkclient.StartWorkflowOptions{ + ID: wid, + TaskQueue: "encryption", + } + + // If you are using a ContextPropagator and varying keys per workflow you need to set + // the KeyID to use for this workflow in the context: + fmt.Println(fmt.Sprintf("Setting encryption key for '%s' with value '%s'", tenant, keyId)) + ctx = context.WithValue(ctx, + contextawareencryption.PropagateKey, + contextawareencryption.CryptContext{KeyID: keyId}) + + // The workflow input tenant will be encrypted by the DataConverter before being sent to Temporal + we, err := a.tclient.ExecuteWorkflow( + ctx, + workflowOptions, + contextawareencryption.TenantWorkflow, + "workflowargument for "+tenant, + ) + if err != nil { + log.Fatalln("Unable to execute workflow", err) + } + + log.Println("Started workflow", "WorkflowID", we.GetID(), "RunID", we.GetRunID()) + + // Synchronously wait for the workflow completion. + var result string + err = we.Get(context.Background(), &result) + if err != nil { + log.Fatalln("Unable get workflow result", err) + } + log.Println("TenantWorkflow result:", result) + count++ + if count >= a.maxCount { + break + } + } + return nil +} diff --git a/context-aware-encryption/codec-server/main.go b/context-aware-encryption/codec-server/main.go new file mode 100644 index 00000000..58ad0a4c --- /dev/null +++ b/context-aware-encryption/codec-server/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "flag" + contextawareencryption "github.com/temporalio/samples-go/context-aware-encryption" + "log" + "net/http" + "os" + "os/signal" + "strconv" + + "go.temporal.io/sdk/converter" +) + +var portFlag int + +func init() { + flag.IntVar(&portFlag, "port", 8081, "Port to listen on") +} + +func main() { + flag.Parse() + + // This example codec server does not support varying config per namespace, + // decoding for the Temporal Web UI or oauth. + // For a more complete example of a codec server please see the codec-server sample at: + // ../../codec-server. + handler := converter.NewPayloadCodecHTTPHandler(&contextawareencryption.Codec{}, converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true})) + + srv := &http.Server{ + Addr: "0.0.0.0:" + strconv.Itoa(portFlag), + Handler: handler, + } + + errCh := make(chan error, 1) + go func() { errCh <- srv.ListenAndServe() }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + + select { + case <-sigCh: + _ = srv.Close() + case err := <-errCh: + log.Fatal(err) + } +} diff --git a/context-aware-encryption/codec.go b/context-aware-encryption/codec.go new file mode 100644 index 00000000..3f88a6d0 --- /dev/null +++ b/context-aware-encryption/codec.go @@ -0,0 +1,97 @@ +package contextawareencryption + +import ( + "fmt" + commonpb "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/converter" +) + +// Codec implements PayloadCodec using AES Crypt. +type Codec struct { + KeyID string + Tenant string +} + +func (e *Codec) getKey(keyID string) (key []byte) { + // Key must be fetched from secure storage in production (such as a KMS). + // For testing here we just hard code a key. + result := keyID + "test-key-test-key-test-key-test!" + return []byte(result[0:32]) +} + +// Encode implements converter.PayloadCodec.Encode. +func (e *Codec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { + result := make([]*commonpb.Payload, len(payloads)) + for i, p := range payloads { + origBytes, err := p.Marshal() + if err != nil { + return payloads, err + } + + key := e.getKey(e.KeyID) + fmt.Println("codec.Encode using tenant/key:", e.KeyID) + + b, err := encrypt(origBytes, key) + if err != nil { + return payloads, err + } + + result[i] = &commonpb.Payload{ + Metadata: map[string][]byte{ + converter.MetadataEncoding: []byte(MetadataEncodingEncrypted), + MetadataEncryptionKeyID: []byte(e.KeyID), + MetadataTenant: []byte(e.Tenant), + }, + Data: b, + } + } + + return result, nil +} + +// Decode implements converter.PayloadCodec.Decode. +func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) { + result := make([]*commonpb.Payload, len(payloads)) + for i, p := range payloads { + // Only if it's encrypted + if string(p.Metadata[converter.MetadataEncoding]) != MetadataEncodingEncrypted { + result[i] = p + continue + } + + keyID, ok := p.Metadata[MetadataEncryptionKeyID] + if !ok { + return payloads, fmt.Errorf("no encryption key id") + } + tenant, ok := p.Metadata[MetadataTenant] + if !ok { + return payloads, fmt.Errorf("no tenant id") + } + key := e.getKey(string(keyID)) + fmt.Println("codec.Decode using tenant/key:", string(tenant), string(key)) + + b, err := decrypt(p.Data, key) + if err != nil { + return payloads, err + } + + result[i] = &commonpb.Payload{} + err = result[i].Unmarshal(b) + if err != nil { + return payloads, err + } + } + + return result, nil +} + +func GetTenantCodec(key string) *Codec { + availableCodecs := map[string]*Codec{} + for tenant, keyId := range TenantKeysByOrganization { + availableCodecs[keyId] = &Codec{Tenant: tenant, KeyID: keyId} + } + if c, exists := availableCodecs[key]; exists { + return c + } + return &Codec{Tenant: "UNKNOWN TENANT", KeyID: key} +} diff --git a/context-aware-encryption/crypt.go b/context-aware-encryption/crypt.go new file mode 100644 index 00000000..65575d49 --- /dev/null +++ b/context-aware-encryption/crypt.go @@ -0,0 +1,48 @@ +package contextawareencryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +func encrypt(plainData []byte, key []byte) ([]byte, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, plainData, nil), nil +} + +func decrypt(encryptedData []byte, key []byte) ([]byte, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(encryptedData) < nonceSize { + return nil, fmt.Errorf("ciphertext too short: %v", encryptedData) + } + + nonce, encryptedData := encryptedData[:nonceSize], encryptedData[nonceSize:] + return gcm.Open(nil, nonce, encryptedData, nil) +} diff --git a/context-aware-encryption/data_converter.go b/context-aware-encryption/data_converter.go new file mode 100644 index 00000000..8bc1a163 --- /dev/null +++ b/context-aware-encryption/data_converter.go @@ -0,0 +1,85 @@ +package contextawareencryption + +import ( + "context" + "fmt" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/workflow" +) + +const ( + // MetadataEncodingEncrypted is "binary/encrypted" + MetadataEncodingEncrypted = "binary/encrypted" + + // MetadataEncryptionKeyID is "encryption-key-id" + MetadataEncryptionKeyID = "encryption-key-id" + MetadataTenant = "tenant" +) + +type DataConverterOptions struct { + KeyID string + // Enable ZLib compression before encryption. + Compress bool +} + +type DataConverter struct { + // Until EncodingDataConverter supports workflow.ContextAware we'll store parent here. + parent converter.DataConverter + converter.DataConverter + options DataConverterOptions +} + +// TODO: Implement workflow.ContextAware in CodecDataConverter +// Note that you only need to implement this function if you need to vary the encryption KeyID per workflow. +func (dc *DataConverter) WithWorkflowContext(ctx workflow.Context) converter.DataConverter { + if val, ok := ctx.Value(PropagateKey).(CryptContext); ok { + parent := dc.parent + if parentWithContext, ok := parent.(workflow.ContextAware); ok { + parent = parentWithContext.WithWorkflowContext(ctx) + } + + options := dc.options + options.KeyID = val.KeyID + fmt.Println("dataConverter.WithWorkflowContext forwarding key:", val.KeyID) + return NewEncryptionDataConverter(parent, options) + } + + return dc +} + +// TODO: Implement workflow.ContextAware in EncodingDataConverter +// Note that you only need to implement this function if you need to vary the encryption KeyID per workflow. +func (dc *DataConverter) WithContext(ctx context.Context) converter.DataConverter { + if val, ok := ctx.Value(PropagateKey).(CryptContext); ok { + parent := dc.parent + if parentWithContext, ok := parent.(workflow.ContextAware); ok { + parent = parentWithContext.WithContext(ctx) + } + + options := dc.options + options.KeyID = val.KeyID + fmt.Println("dataConverter.WithContext forwarding key:", val.KeyID) + + return NewEncryptionDataConverter(parent, options) + } + + return dc +} + +// NewEncryptionDataConverter creates a new instance of EncryptionDataConverter wrapping a DataConverter +func NewEncryptionDataConverter(dataConverter converter.DataConverter, options DataConverterOptions) *DataConverter { + codec := GetTenantCodec(options.KeyID) + codecs := []converter.PayloadCodec{codec} + // Enable compression if requested. + // Note that this must be done before encryption to provide any value. Encrypted data should by design not compress very well. + // This means the compression codec must come after the encryption codec here as codecs are applied last -> first. + if options.Compress { + codecs = append(codecs, converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true})) + } + + return &DataConverter{ + parent: dataConverter, + DataConverter: converter.NewCodecDataConverter(dataConverter, codecs...), + options: options, + } +} diff --git a/context-aware-encryption/data_converter_test.go b/context-aware-encryption/data_converter_test.go new file mode 100644 index 00000000..b4a3abd2 --- /dev/null +++ b/context-aware-encryption/data_converter_test.go @@ -0,0 +1,36 @@ +package contextawareencryption + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/converter" +) + +func Test_DataConverter(t *testing.T) { + defaultDc := converter.GetDefaultDataConverter() + + ctx := context.Background() + ctx = context.WithValue(ctx, PropagateKey, CryptContext{KeyID: "test"}) + + cryptDc := NewEncryptionDataConverter( + converter.GetDefaultDataConverter(), + DataConverterOptions{}, + ) + cryptDcWc := cryptDc.WithContext(ctx) + + defaultPayloads, err := defaultDc.ToPayloads("Testing") + require.NoError(t, err) + + encryptedPayloads, err := cryptDcWc.ToPayloads("Testing") + require.NoError(t, err) + + require.NotEqual(t, defaultPayloads.Payloads[0].GetData(), encryptedPayloads.Payloads[0].GetData()) + + var result string + err = cryptDc.FromPayloads(encryptedPayloads, &result) + require.NoError(t, err) + + require.Equal(t, "Testing", result) +} diff --git a/context-aware-encryption/propagator.go b/context-aware-encryption/propagator.go new file mode 100644 index 00000000..e99ea1f0 --- /dev/null +++ b/context-aware-encryption/propagator.go @@ -0,0 +1,90 @@ +package contextawareencryption + +import ( + "context" + "fmt" + + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/workflow" +) + +type ( + // contextKey is an unexported type used as key for items stored in the + // Context object + contextKey struct{} + + // propagator implements the custom context propagator + propagator struct{} + + // CryptConfig is a struct holding values + CryptContext struct { + KeyID string `json:"keyId"` + } +) + +// PropagateKey is the key used to store the value in the Context object +var PropagateKey = contextKey{} + +// propagationKey is the key used by the propagator to pass values through the +// Temporal server headers +const propagationKey = "encryption" + +// NewContextPropagator returns a context propagator that propagates a set of +// string key-value pairs across a workflow +func NewContextPropagator() workflow.ContextPropagator { + return &propagator{} +} + +// Inject injects values from context into headers for propagation +func (s *propagator) Inject(ctx context.Context, writer workflow.HeaderWriter) error { + value := ctx.Value(PropagateKey) + payload, err := converter.GetDefaultDataConverter().ToPayload(value) + if err != nil { + return err + } + fmt.Println("propagator.Inject writing to header", fmt.Sprintf("%v", value)) + writer.Set(propagationKey, payload) + return nil +} + +// InjectFromWorkflow injects values from context into headers for propagation +func (s *propagator) InjectFromWorkflow(ctx workflow.Context, writer workflow.HeaderWriter) error { + value := ctx.Value(PropagateKey) + payload, err := converter.GetDefaultDataConverter().ToPayload(value) + if err != nil { + return err + } + fmt.Println("propagator.InjectFromWorkflow writing to header", fmt.Sprintf("%v", value)) + writer.Set(propagationKey, payload) + return nil +} + +// Extract extracts values from headers and puts them into context +func (s *propagator) Extract(ctx context.Context, reader workflow.HeaderReader) (context.Context, error) { + if value, ok := reader.Get(propagationKey); ok { + var cryptContext CryptContext + if err := converter.GetDefaultDataConverter().FromPayload(value, &cryptContext); err != nil { + return ctx, nil + } + fmt.Println("propagator.Extract setting on context", fmt.Sprintf("%v", cryptContext)) + + ctx = context.WithValue(ctx, PropagateKey, cryptContext) + } + + return ctx, nil +} + +// ExtractToWorkflow extracts values from headers and puts them into context +func (s *propagator) ExtractToWorkflow(ctx workflow.Context, reader workflow.HeaderReader) (workflow.Context, error) { + if value, ok := reader.Get(propagationKey); ok { + var cryptContext CryptContext + if err := converter.GetDefaultDataConverter().FromPayload(value, &cryptContext); err != nil { + return ctx, nil + } + fmt.Println("propagator.ExtractToWorkflow setting on context", fmt.Sprintf("%v", cryptContext)) + + ctx = workflow.WithValue(ctx, PropagateKey, cryptContext) + } + + return ctx, nil +} diff --git a/context-aware-encryption/tclient.go b/context-aware-encryption/tclient.go new file mode 100644 index 00000000..e9edf899 --- /dev/null +++ b/context-aware-encryption/tclient.go @@ -0,0 +1,66 @@ +package contextawareencryption + +import ( + "context" + sdkclient "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" + "go.temporal.io/sdk/workflow" + "sync" +) + +var once sync.Once +var defaultTemporalClient sdkclient.Client + +func GetDefaultOptions() sdkclient.Options { + // If you intend to let the dataConverter to decide encryption key for all workflows + // you can set the KeyID for the encryption encoder like so: + // + // DataConverter: encryption.NewEncryptionDataConverter( + // converter.GetDefaultDataConverter(), + // encryption.DataConverterOptions{KeyID: "test", Compress: true}, + // ), + // + // In this case you do not need to use a ContextPropagator. + // You also can implement the dataConverter to decide the encryption key + // dynamically so that it's not always the same key. + // + // If you need to let the workflow starter to decide the encryption key per workflow, + // you can instead leave the KeyID unset for the encoder and supply it via the workflow + // context as shown below. For this use case you will also need to use a + // ContextPropagator so that KeyID is also available in the context for activities. + // + // Set DataConverter to ensure that workflow inputs and results are + // encrypted/decrypted as required. + dataConverter := NewEncryptionDataConverter( + converter.GetDefaultDataConverter(), + DataConverterOptions{Compress: true}, + ) + + // Use a ContextPropagator so that the KeyID value set in the workflow context is + // also availble in the context for activities. + ctxProp := NewContextPropagator() + options := sdkclient.Options{ + DataConverter: dataConverter, + ContextPropagators: []workflow.ContextPropagator{ctxProp}, + } + return options +} +func GetTemporalClient(ctx context.Context, opts sdkclient.Options) (sdkclient.Client, error) { + c, err := sdkclient.Dial(opts) + return c, err +} +func MustGetDefaultTemporalClient(ctx context.Context, opts *sdkclient.Options) sdkclient.Client { + once.Do(func() { + if opts == nil { + o := GetDefaultOptions() + opts = &o + } + + var err error + defaultTemporalClient, err = GetTemporalClient(ctx, *opts) + if err != nil { + panic("failed to create default temporal client: " + err.Error()) + } + }) + return defaultTemporalClient +} diff --git a/context-aware-encryption/tenant.go b/context-aware-encryption/tenant.go new file mode 100644 index 00000000..7eaf559e --- /dev/null +++ b/context-aware-encryption/tenant.go @@ -0,0 +1,12 @@ +package contextawareencryption + +var ( + TenantKeyOrgTenant1 = "org/tenant-1" + TenantKeyOrgTenant2 = "org/tenant-2" + TenantKeyOrgTenant3 = "org/tenant-3" + TenantKeysByOrganization = map[string]string{ + TenantKeyOrgTenant1: "tenant-1-key", + TenantKeyOrgTenant2: "tenant-2-key", + TenantKeyOrgTenant3: "tenant-3-key", + } +) diff --git a/context-aware-encryption/workflow.go b/context-aware-encryption/workflow.go new file mode 100644 index 00000000..7f396b6b --- /dev/null +++ b/context-aware-encryption/workflow.go @@ -0,0 +1,60 @@ +package contextawareencryption + +import ( + "context" + "fmt" + "time" + + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/workflow" +) + +// TenantWorkflow is a standard workflow definition. +// Note that the TenantWorkflow and TenantActivity don't need to care that +// their inputs/results are being encrypted/decrypted. +func TenantWorkflow(ctx workflow.Context, name string) (string, error) { + ao := workflow.ActivityOptions{ + StartToCloseTimeout: 10 * time.Second, + } + ctx = workflow.WithActivityOptions(ctx, ao) + + logger := workflow.GetLogger(ctx) + //logger.Info("Encrypted Payloads workflow started", "name", name) + value, ok := workflow.Context.Value(ctx, PropagateKey).(CryptContext) + if !ok { + logger.Error("Unable to retrieve context") + } + logger.Info("Context KeyID", value) + info := map[string]string{ + "name": name, + } + + var result string + err := workflow.ExecuteActivity(ctx, TenantActivity, info).Get(ctx, &result) + if err != nil { + logger.Error("TenantActivity failed.", "Error", err) + return "", err + } + + //logger.Info("TenantWorkflow.", "result", result) + + return result, nil +} + +func TenantActivity(ctx context.Context, info map[string]string) (string, error) { + logger := activity.GetLogger(ctx) + value, ok := context.Context.Value(ctx, PropagateKey).(CryptContext) + if !ok { + logger.Error("Activity Unable to retrieve context") + return "", fmt.Errorf("Unable to retrieve context") + } + fmt.Println("Activity Context Value:", value) + //logger.Info("TenantActivity", "info", info) + + name, ok := info["name"] + if !ok { + name = "someone" + } + + return "Hello " + name + "!", nil +} diff --git a/context-aware-encryption/workflow_test.go b/context-aware-encryption/workflow_test.go new file mode 100644 index 00000000..138d295d --- /dev/null +++ b/context-aware-encryption/workflow_test.go @@ -0,0 +1,25 @@ +package contextawareencryption + +import ( + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/testsuite" +) + +func Test_Workflow(t *testing.T) { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + + // Mock activity implementation + env.OnActivity(TenantActivity, mock.Anything, mock.Anything).Return("Hello Temporal!", nil) + + env.ExecuteWorkflow(TenantWorkflow, "Temporal") + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var result string + require.NoError(t, env.GetWorkflowResult(&result)) + require.Equal(t, "Hello Temporal!", result) +}