Skip to content

Commit 5dc8817

Browse files
committed
Add OAuth support for CLI extensions
1 parent 959f910 commit 5dc8817

File tree

17 files changed

+1492
-136
lines changed

17 files changed

+1492
-136
lines changed

cliext/config.go

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
package cliext
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"sort"
9+
"time"
10+
11+
"github.com/BurntSushi/toml"
12+
"go.temporal.io/sdk/contrib/envconfig"
13+
)
14+
15+
type ClientConfig struct {
16+
Profiles map[string]*Profile
17+
}
18+
19+
type LoadConfigOptions struct {
20+
// Override the file path to use to load the TOML file for config. Defaults to TEMPORAL_CONFIG_FILE environment
21+
// variable or if that is unset/empty, defaults to [os.UserConfigDir]/temporal/temporal.toml. If ConfigFileData is
22+
// set, this cannot be set and no file loading from disk occurs. Ignored if DisableFile is true.
23+
ConfigFilePath string
24+
25+
// Override the environment variable lookup. If nil, defaults to [EnvLookupOS].
26+
EnvLookup envconfig.EnvLookup
27+
}
28+
29+
type LoadConfigResult struct {
30+
// Config is the loaded configuration with its profiles.
31+
Config ClientConfig
32+
33+
// ConfigFilePath is the resolved path to the configuration file that was loaded.
34+
// This may differ from the input if TEMPORAL_CONFIG_FILE env var was used.
35+
ConfigFilePath string
36+
}
37+
38+
// oauthConfigTOML is the TOML representation of OAuthConfig.
39+
// We use a separate struct to control TOML field names and handle time.Time.
40+
type oauthConfigTOML struct {
41+
ClientID string `toml:"client_id,omitempty"`
42+
ClientSecret string `toml:"client_secret,omitempty"`
43+
TokenURL string `toml:"token_url,omitempty"`
44+
AuthURL string `toml:"auth_url,omitempty"`
45+
RedirectURL string `toml:"redirect_url,omitempty"`
46+
AccessToken string `toml:"access_token,omitempty"`
47+
RefreshToken string `toml:"refresh_token,omitempty"`
48+
TokenType string `toml:"token_type,omitempty"`
49+
ExpiresAt string `toml:"expires_at,omitempty"`
50+
Scopes []string `toml:"scopes,omitempty"`
51+
RequestParams map[string]string `toml:"request_params,omitempty"`
52+
}
53+
54+
type rawProfileWithOAuth struct {
55+
OAuth *oauthConfigTOML `toml:"oauth"`
56+
}
57+
58+
type rawConfigWithOAuth struct {
59+
Profile map[string]*rawProfileWithOAuth `toml:"profile"`
60+
}
61+
62+
// LoadConfig loads the client configuration from the specified file or default location.
63+
// If ConfigFilePath is empty, the TEMPORAL_CONFIG_FILE environment variable is checked.
64+
func LoadConfig(options LoadConfigOptions) (LoadConfigResult, error) {
65+
envLookup := options.EnvLookup
66+
if envLookup == nil {
67+
envLookup = envconfig.EnvLookupOS
68+
}
69+
70+
configFilePath := options.ConfigFilePath
71+
if configFilePath == "" {
72+
configFilePath, _ = envLookup.LookupEnv("TEMPORAL_CONFIG_FILE")
73+
}
74+
75+
clientConfig, err := envconfig.LoadClientConfig(envconfig.LoadClientConfigOptions{
76+
ConfigFilePath: configFilePath,
77+
EnvLookup: envLookup,
78+
})
79+
if err != nil {
80+
return LoadConfigResult{}, err
81+
}
82+
83+
// Load OAuth for all profiles by parsing the config file directly
84+
var oauthByProfile map[string]*OAuthConfig
85+
configFilePathForOAuth := configFilePath
86+
if configFilePathForOAuth == "" {
87+
configFilePathForOAuth, _ = envconfig.DefaultConfigFilePath()
88+
}
89+
if configFilePathForOAuth != "" {
90+
data, err := os.ReadFile(configFilePathForOAuth)
91+
if err == nil {
92+
var raw rawConfigWithOAuth
93+
if _, err := toml.Decode(string(data), &raw); err == nil {
94+
oauthByProfile = make(map[string]*OAuthConfig)
95+
for profileName, profile := range raw.Profile {
96+
if profile == nil || profile.OAuth == nil {
97+
continue
98+
}
99+
cfg := profile.OAuth
100+
oauth := &OAuthConfig{
101+
OAuthClientConfig: OAuthClientConfig{
102+
ClientID: cfg.ClientID,
103+
ClientSecret: cfg.ClientSecret,
104+
TokenURL: cfg.TokenURL,
105+
AuthURL: cfg.AuthURL,
106+
RedirectURL: cfg.RedirectURL,
107+
RequestParams: cfg.RequestParams,
108+
Scopes: cfg.Scopes,
109+
},
110+
OAuthToken: OAuthToken{
111+
AccessToken: cfg.AccessToken,
112+
RefreshToken: cfg.RefreshToken,
113+
TokenType: cfg.TokenType,
114+
},
115+
}
116+
if cfg.ExpiresAt != "" {
117+
if t, err := time.Parse(time.RFC3339, cfg.ExpiresAt); err == nil {
118+
oauth.ExpiresAt = t
119+
}
120+
}
121+
oauthByProfile[profileName] = oauth
122+
}
123+
}
124+
}
125+
}
126+
127+
// Build profiles map combining base config and OAuth
128+
profiles := make(map[string]*Profile)
129+
for name, baseProfile := range clientConfig.Profiles {
130+
profiles[name] = &Profile{
131+
ClientConfigProfile: *baseProfile,
132+
OAuth: oauthByProfile[name],
133+
}
134+
}
135+
136+
// Add any profiles that only have OAuth config
137+
for name, oauth := range oauthByProfile {
138+
if _, exists := profiles[name]; !exists {
139+
profiles[name] = &Profile{
140+
OAuth: oauth,
141+
}
142+
}
143+
}
144+
145+
return LoadConfigResult{
146+
Config: ClientConfig{Profiles: profiles},
147+
ConfigFilePath: configFilePath,
148+
}, nil
149+
}
150+
151+
// WriteConfigOptions contains options for writing configuration.
152+
type WriteConfigOptions struct {
153+
// Config is the configuration to write.
154+
Config ClientConfig
155+
156+
// ConfigFilePath is the path to write the configuration file to.
157+
// If empty, TEMPORAL_CONFIG_FILE env var is checked, then the default path is used.
158+
ConfigFilePath string
159+
160+
// EnvLookup is used for environment variable lookups.
161+
// If nil, envconfig.EnvLookupOS is used.
162+
EnvLookup envconfig.EnvLookup
163+
}
164+
165+
// WriteConfigToBytes serializes the configuration to TOML bytes.
166+
func WriteConfigToBytes(config *ClientConfig) ([]byte, error) {
167+
profiles := config.Profiles
168+
if profiles == nil {
169+
profiles = make(map[string]*Profile)
170+
}
171+
172+
// Build envconfig.ClientConfig from profiles
173+
envConfig := &envconfig.ClientConfig{
174+
Profiles: make(map[string]*envconfig.ClientConfigProfile),
175+
}
176+
for name, p := range profiles {
177+
if p != nil {
178+
envConfig.Profiles[name] = &p.ClientConfigProfile
179+
}
180+
}
181+
182+
// Convert base config to TOML
183+
b, err := envConfig.ToTOML(envconfig.ClientConfigToTOMLOptions{})
184+
if err != nil {
185+
return nil, fmt.Errorf("failed building TOML: %w", err)
186+
}
187+
188+
// Append OAuth sections per profile
189+
var buf bytes.Buffer
190+
buf.Write(b)
191+
192+
// Sort keys for deterministic output
193+
profileNames := make([]string, 0, len(profiles))
194+
for name := range profiles {
195+
profileNames = append(profileNames, name)
196+
}
197+
sort.Strings(profileNames)
198+
199+
for _, profileName := range profileNames {
200+
profile := profiles[profileName]
201+
if profile == nil || profile.OAuth == nil {
202+
continue
203+
}
204+
205+
// Convert OAuthConfig to oauthConfigTOML (without RequestParams - handled separately)
206+
oauthTOML := &oauthConfigTOML{
207+
ClientID: profile.OAuth.ClientID,
208+
ClientSecret: profile.OAuth.ClientSecret,
209+
TokenURL: profile.OAuth.TokenURL,
210+
AuthURL: profile.OAuth.AuthURL,
211+
RedirectURL: profile.OAuth.RedirectURL,
212+
AccessToken: profile.OAuth.AccessToken,
213+
RefreshToken: profile.OAuth.RefreshToken,
214+
TokenType: profile.OAuth.TokenType,
215+
Scopes: profile.OAuth.Scopes,
216+
// RequestParams omitted - toml.Marshal creates a separate table section which breaks nesting
217+
}
218+
if !profile.OAuth.ExpiresAt.IsZero() {
219+
oauthTOML.ExpiresAt = profile.OAuth.ExpiresAt.Format(time.RFC3339)
220+
}
221+
222+
// Marshal to TOML
223+
oauthBytes, err := toml.Marshal(oauthTOML)
224+
if err != nil {
225+
return nil, fmt.Errorf("failed marshaling OAuth config: %w", err)
226+
}
227+
228+
// Write the section header and content
229+
fmt.Fprintf(&buf, "\n[profile.%s.oauth]\n", profileName)
230+
buf.Write(oauthBytes)
231+
232+
// Manually write request_params as inline table (toml.Marshal would create a separate [request_params] section)
233+
if len(profile.OAuth.RequestParams) > 0 {
234+
keys := make([]string, 0, len(profile.OAuth.RequestParams))
235+
for k := range profile.OAuth.RequestParams {
236+
keys = append(keys, k)
237+
}
238+
sort.Strings(keys)
239+
buf.WriteString("request_params = { ")
240+
for i, k := range keys {
241+
if i > 0 {
242+
buf.WriteString(", ")
243+
}
244+
fmt.Fprintf(&buf, "%s = %q", k, profile.OAuth.RequestParams[k])
245+
}
246+
buf.WriteString(" }\n")
247+
}
248+
}
249+
return buf.Bytes(), nil
250+
}
251+
252+
// WriteConfig writes the configuration to the specified file or default location.
253+
//
254+
// Example:
255+
//
256+
// result, _ := cliext.LoadProfile(cliext.LoadProfileOptions{CreateIfMissing: true})
257+
// result.Profile.OAuth = oauthConfig
258+
// err := cliext.WriteConfig(cliext.WriteConfigOptions{Config: result.Config})
259+
func WriteConfig(opts WriteConfigOptions) error {
260+
envLookup := opts.EnvLookup
261+
if envLookup == nil {
262+
envLookup = envconfig.EnvLookupOS
263+
}
264+
265+
configFilePath := opts.ConfigFilePath
266+
if configFilePath == "" {
267+
configFilePath, _ = envLookup.LookupEnv("TEMPORAL_CONFIG_FILE")
268+
if configFilePath == "" {
269+
var err error
270+
if configFilePath, err = envconfig.DefaultConfigFilePath(); err != nil {
271+
return err
272+
}
273+
}
274+
}
275+
276+
b, err := WriteConfigToBytes(&opts.Config)
277+
if err != nil {
278+
return err
279+
}
280+
281+
// Write to file, making dirs as needed
282+
if err := os.MkdirAll(filepath.Dir(configFilePath), 0700); err != nil {
283+
return fmt.Errorf("failed making config file parent dirs: %w", err)
284+
}
285+
if err := os.WriteFile(configFilePath, b, 0600); err != nil {
286+
return fmt.Errorf("failed writing config file: %w", err)
287+
}
288+
return nil
289+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package devserver
1+
package cliext
22

33
import (
44
"fmt"
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
package devserver_test
1+
package cliext
22

33
import (
44
"fmt"
55
"net"
66
"testing"
7-
8-
"github.com/temporalio/cli/internal/devserver"
97
)
108

119
func TestFreePort_NoDouble(t *testing.T) {
1210
host := "127.0.0.1"
1311
portSet := make(map[int]bool)
1412
for i := 0; i < 2000; i++ {
15-
p, err := devserver.GetFreePort(host)
13+
p, err := GetFreePort(host)
1614
if err != nil {
1715
t.Fatalf("Error: %s", err)
1816
break
@@ -30,7 +28,7 @@ func TestFreePort_NoDouble(t *testing.T) {
3028
func TestFreePort_CanBindImmediatelySameProcess(t *testing.T) {
3129
host := "127.0.0.1"
3230
for i := 0; i < 500; i++ {
33-
p, err := devserver.GetFreePort(host)
31+
p, err := GetFreePort(host)
3432
if err != nil {
3533
t.Fatalf("Error: %s", err)
3634
break
@@ -45,7 +43,7 @@ func TestFreePort_CanBindImmediatelySameProcess(t *testing.T) {
4543

4644
func TestFreePort_IPv4Unspecified(t *testing.T) {
4745
host := "0.0.0.0"
48-
p, err := devserver.GetFreePort(host)
46+
p, err := GetFreePort(host)
4947
if err != nil {
5048
t.Fatalf("Error: %s", err)
5149
return
@@ -59,7 +57,7 @@ func TestFreePort_IPv4Unspecified(t *testing.T) {
5957

6058
func TestFreePort_IPv6Unspecified(t *testing.T) {
6159
host := "::"
62-
p, err := devserver.GetFreePort(host)
60+
p, err := GetFreePort(host)
6361
if err != nil {
6462
t.Fatalf("Error: %s", err)
6563
return
@@ -73,7 +71,7 @@ func TestFreePort_IPv6Unspecified(t *testing.T) {
7371

7472
// This function is used as part of unit tests, to ensure that the port
7573
func tryListenAndDialOn(host string, port int) error {
76-
host = devserver.MaybeEscapeIPv6(host)
74+
host = MaybeEscapeIPv6(host)
7775
l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
7876
if err != nil {
7977
return err

cliext/go.mod

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
module github.com/temporalio/cli/cliext
2+
3+
go 1.24.0
4+
5+
require (
6+
github.com/spf13/pflag v1.0.9
7+
github.com/stretchr/testify v1.10.0
8+
go.temporal.io/api v1.44.1
9+
go.temporal.io/sdk v1.32.1
10+
go.temporal.io/sdk/contrib/envconfig v0.1.0
11+
google.golang.org/grpc v1.66.0
12+
)
13+
14+
require (
15+
github.com/BurntSushi/toml v1.4.0 // indirect
16+
github.com/davecgh/go-spew v1.1.1 // indirect
17+
github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect
18+
github.com/gogo/protobuf v1.3.2 // indirect
19+
github.com/golang/mock v1.6.0 // indirect
20+
github.com/google/uuid v1.6.0 // indirect
21+
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
22+
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
23+
github.com/inconshreveable/mousetrap v1.1.0 // indirect
24+
github.com/mattn/go-isatty v0.0.20 // indirect
25+
github.com/nexus-rpc/sdk-go v0.3.0 // indirect
26+
github.com/pborman/uuid v1.2.1 // indirect
27+
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
28+
github.com/pmezard/go-difflib v1.0.0 // indirect
29+
github.com/robfig/cron v1.2.0 // indirect
30+
github.com/spf13/cobra v1.10.2 // indirect
31+
github.com/stretchr/objx v0.5.2 // indirect
32+
golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect
33+
golang.org/x/net v0.28.0 // indirect
34+
golang.org/x/oauth2 v0.33.0 // indirect
35+
golang.org/x/sync v0.8.0 // indirect
36+
golang.org/x/sys v0.24.0 // indirect
37+
golang.org/x/text v0.17.0 // indirect
38+
golang.org/x/time v0.3.0 // indirect
39+
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
40+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect
41+
google.golang.org/protobuf v1.34.2 // indirect
42+
gopkg.in/yaml.v3 v3.0.1 // indirect
43+
)

0 commit comments

Comments
 (0)