Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions context-aware-encryption/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
### Steps to run this sample:
1) Run a [Temporal service](https://github.com/temporalio/samples-go/tree/main/#how-to-use).
2) Run the following command to start the remote codec server
```
go run ./codec-server
```
3) Run the following command to start the worker
```
go run worker/main.go
```
4) Run the following command to start the example
```
go run starter/main.go
```
5) Run the following command and see the payloads cannot be decoded
```
tctl workflow show --wid encryption_workflowID
```
6) Run the following command and see the decoded payloads
```
tctl --codec_endpoint 'http://localhost:8081/' workflow show --wid encryption_workflowID
```

Note: The codec server provided in this sample does not support decoding payloads for the Temporal Web UI, only tctl.
Please see the [codec-server](../codec-server/) sample for a more complete example of a codec server which provides UI decoding and oauth.
146 changes: 146 additions & 0 deletions context-aware-encryption/app/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package main

import (
"context"
"fmt"
contextawareencryption "github.com/temporalio/samples-go/context-aware-encryption"
sdkclient "go.temporal.io/sdk/client"
"go.temporal.io/sdk/worker"
"golang.org/x/sync/errgroup"
"log"
"os"
"os/signal"
"syscall"
"time"
)

type startable interface {
Start(context.Context) error
Shutdown(context.Context)
}

func main() {
ctx, done := context.WithCancel(context.Background())

c := contextawareencryption.MustGetDefaultTemporalClient(ctx, nil)
defer c.Close()
g, ctx := errgroup.WithContext(ctx)

// set up signal listener
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(quit)

startables := []startable{
&Worker{tclient: c},
&App{tclient: c, maxCount: len(contextawareencryption.TenantKeysByOrganization)},
}

for _, s := range startables {
var current = s
g.Go(func() error {
if err := current.Start(ctx); err != nil {
return err
}
return nil
})
}

select {
case <-quit:
break
case <-ctx.Done():
break
}

// shutdown the things
done()
// limit how long we'll wait for
timeoutCtx, timeoutCancel := context.WithTimeout(
context.Background(),
10*time.Second,
)
defer timeoutCancel()

for _, s := range startables {
s.Shutdown(timeoutCtx)
}

// wait for shutdown
if err := g.Wait(); err != nil {
panic("shutdown was not clean" + err.Error())
}
}

type Worker struct {
tclient sdkclient.Client
worker worker.Worker
}

func (w *Worker) Start(ctx context.Context) error {
w.worker = worker.New(w.tclient, "encryption", worker.Options{})

w.worker.RegisterWorkflow(contextawareencryption.TenantWorkflow)
w.worker.RegisterActivity(contextawareencryption.TenantActivity)

return w.worker.Run(worker.InterruptCh())
}
func (w *Worker) Shutdown(ctx context.Context) {
w.worker.Stop()
}

type App struct {
tclient sdkclient.Client
maxCount int
}

func (a *App) Shutdown(ctx context.Context) {

}
func (a *App) Start(ctx context.Context) error {
if a.maxCount == 0 {
return fmt.Errorf("You must at least one run Workflow")

Check failure on line 102 in context-aware-encryption/app/main.go

View workflow job for this annotation

GitHub Actions / build-and-test

error strings should not be capitalized (ST1005)
}
dt := time.Now().UTC().String()
count := 0
for tenant, keyId := range contextawareencryption.TenantKeysByOrganization {
wid := fmt.Sprintf("tenant_%s-%s", tenant, dt)
workflowOptions := sdkclient.StartWorkflowOptions{
ID: wid,
TaskQueue: "encryption",
}

// If you are using a ContextPropagator and varying keys per workflow you need to set
// the KeyID to use for this workflow in the context:
fmt.Println(fmt.Sprintf("Setting encryption key for '%s' with value '%s'", tenant, keyId))

Check failure on line 115 in context-aware-encryption/app/main.go

View workflow job for this annotation

GitHub Actions / build-and-test

should use fmt.Printf instead of fmt.Println(fmt.Sprintf(...)) (but don't forget the newline) (S1038)
ctx = context.WithValue(ctx,
contextawareencryption.PropagateKey,
contextawareencryption.CryptContext{KeyID: keyId})

// The workflow input tenant will be encrypted by the DataConverter before being sent to Temporal
we, err := a.tclient.ExecuteWorkflow(
ctx,
workflowOptions,
contextawareencryption.TenantWorkflow,
"workflowargument for "+tenant,
)
if err != nil {
log.Fatalln("Unable to execute workflow", err)
}

log.Println("Started workflow", "WorkflowID", we.GetID(), "RunID", we.GetRunID())

// Synchronously wait for the workflow completion.
var result string
err = we.Get(context.Background(), &result)
if err != nil {
log.Fatalln("Unable get workflow result", err)
}
log.Println("TenantWorkflow result:", result)
count++
if count >= a.maxCount {
break
}
}
return nil
}
47 changes: 47 additions & 0 deletions context-aware-encryption/codec-server/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"flag"
contextawareencryption "github.com/temporalio/samples-go/context-aware-encryption"
"log"
"net/http"
"os"
"os/signal"
"strconv"

"go.temporal.io/sdk/converter"
)

var portFlag int

func init() {
flag.IntVar(&portFlag, "port", 8081, "Port to listen on")
}

func main() {
flag.Parse()

// This example codec server does not support varying config per namespace,
// decoding for the Temporal Web UI or oauth.
// For a more complete example of a codec server please see the codec-server sample at:
// ../../codec-server.
handler := converter.NewPayloadCodecHTTPHandler(&contextawareencryption.Codec{}, converter.NewZlibCodec(converter.ZlibCodecOptions{AlwaysEncode: true}))

srv := &http.Server{
Addr: "0.0.0.0:" + strconv.Itoa(portFlag),
Handler: handler,
}

errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt)

select {
case <-sigCh:
_ = srv.Close()
case err := <-errCh:
log.Fatal(err)
}
}
97 changes: 97 additions & 0 deletions context-aware-encryption/codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package contextawareencryption

import (
"fmt"
commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/sdk/converter"
)

// Codec implements PayloadCodec using AES Crypt.
type Codec struct {
KeyID string
Tenant string
}

func (e *Codec) getKey(keyID string) (key []byte) {
// Key must be fetched from secure storage in production (such as a KMS).
// For testing here we just hard code a key.
result := keyID + "test-key-test-key-test-key-test!"
return []byte(result[0:32])
}

// Encode implements converter.PayloadCodec.Encode.
func (e *Codec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
result := make([]*commonpb.Payload, len(payloads))
for i, p := range payloads {
origBytes, err := p.Marshal()
if err != nil {
return payloads, err
}

key := e.getKey(e.KeyID)
fmt.Println("codec.Encode using tenant/key:", e.KeyID)

b, err := encrypt(origBytes, key)
if err != nil {
return payloads, err
}

result[i] = &commonpb.Payload{
Metadata: map[string][]byte{
converter.MetadataEncoding: []byte(MetadataEncodingEncrypted),
MetadataEncryptionKeyID: []byte(e.KeyID),
MetadataTenant: []byte(e.Tenant),
},
Data: b,
}
}

return result, nil
}

// Decode implements converter.PayloadCodec.Decode.
func (e *Codec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
result := make([]*commonpb.Payload, len(payloads))
for i, p := range payloads {
// Only if it's encrypted
if string(p.Metadata[converter.MetadataEncoding]) != MetadataEncodingEncrypted {
result[i] = p
continue
}

keyID, ok := p.Metadata[MetadataEncryptionKeyID]
if !ok {
return payloads, fmt.Errorf("no encryption key id")
}
tenant, ok := p.Metadata[MetadataTenant]
if !ok {
return payloads, fmt.Errorf("no tenant id")
}
key := e.getKey(string(keyID))
fmt.Println("codec.Decode using tenant/key:", string(tenant), string(key))

b, err := decrypt(p.Data, key)
if err != nil {
return payloads, err
}

result[i] = &commonpb.Payload{}
err = result[i].Unmarshal(b)
if err != nil {
return payloads, err
}
}

return result, nil
}

func GetTenantCodec(key string) *Codec {
availableCodecs := map[string]*Codec{}
for tenant, keyId := range TenantKeysByOrganization {
availableCodecs[keyId] = &Codec{Tenant: tenant, KeyID: keyId}
}
if c, exists := availableCodecs[key]; exists {
return c
}
return &Codec{Tenant: "UNKNOWN TENANT", KeyID: key}
}
48 changes: 48 additions & 0 deletions context-aware-encryption/crypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package contextawareencryption

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
)

func encrypt(plainData []byte, key []byte) ([]byte, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}

gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}

nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}

return gcm.Seal(nonce, nonce, plainData, nil), nil
}

func decrypt(encryptedData []byte, key []byte) ([]byte, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}

gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}

nonceSize := gcm.NonceSize()
if len(encryptedData) < nonceSize {
return nil, fmt.Errorf("ciphertext too short: %v", encryptedData)
}

nonce, encryptedData := encryptedData[:nonceSize], encryptedData[nonceSize:]
return gcm.Open(nil, nonce, encryptedData, nil)
}
Loading
Loading