Skip to content

Commit e7b814c

Browse files
committed
Add OAuth support for CLI extensions
1 parent 32e30b1 commit e7b814c

File tree

17 files changed

+1495
-142
lines changed

17 files changed

+1495
-142
lines changed

cliext/config.go

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
package cliext
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"path/filepath"
9+
"sort"
10+
"time"
11+
12+
"github.com/BurntSushi/toml"
13+
"go.temporal.io/sdk/contrib/envconfig"
14+
)
15+
16+
type ClientConfig struct {
17+
Profiles map[string]*Profile
18+
}
19+
20+
type LoadConfigOptions struct {
21+
// Override the file path to use to load the TOML file for config. Defaults to TEMPORAL_CONFIG_FILE environment
22+
// variable or if that is unset/empty, defaults to [os.UserConfigDir]/temporal/temporal.toml. If ConfigFileData is
23+
// set, this cannot be set and no file loading from disk occurs. Ignored if DisableFile is true.
24+
ConfigFilePath string
25+
// Override the environment variable lookup. If nil, defaults to [EnvLookupOS].
26+
EnvLookup envconfig.EnvLookup
27+
}
28+
29+
type LoadConfigResult struct {
30+
// Config is the loaded configuration with its profiles.
31+
Config ClientConfig
32+
// ConfigFilePath is the resolved path to the configuration file that was loaded.
33+
// This may differ from the input if TEMPORAL_CONFIG_FILE env var was used.
34+
ConfigFilePath string
35+
}
36+
37+
// oauthConfigTOML is the TOML representation of OAuthConfig.
38+
type oauthConfigTOML struct {
39+
ClientID string `toml:"client_id,omitempty"`
40+
ClientSecret string `toml:"client_secret,omitempty"`
41+
TokenURL string `toml:"token_url,omitempty"`
42+
AuthURL string `toml:"auth_url,omitempty"`
43+
AccessToken string `toml:"access_token,omitempty"`
44+
RefreshToken string `toml:"refresh_token,omitempty"`
45+
TokenType string `toml:"token_type,omitempty"`
46+
ExpiresAt string `toml:"expires_at,omitempty"`
47+
Scopes []string `toml:"scopes,omitempty"`
48+
RequestParams inlineStringMap `toml:"request_params,omitempty"`
49+
}
50+
51+
type rawProfileWithOAuth struct {
52+
OAuth *oauthConfigTOML `toml:"oauth"`
53+
}
54+
55+
type rawConfigWithOAuth struct {
56+
Profile map[string]*rawProfileWithOAuth `toml:"profile"`
57+
}
58+
59+
// LoadConfig loads the client configuration from the specified file or default location.
60+
// If ConfigFilePath is empty, the TEMPORAL_CONFIG_FILE environment variable is checked,
61+
// then the default path is used.
62+
func LoadConfig(options LoadConfigOptions) (LoadConfigResult, error) {
63+
envLookup := options.EnvLookup
64+
if envLookup == nil {
65+
envLookup = envconfig.EnvLookupOS
66+
}
67+
68+
// Resolve the config file path.
69+
resolvedPath := options.ConfigFilePath
70+
if resolvedPath == "" {
71+
resolvedPath, _ = envLookup.LookupEnv("TEMPORAL_CONFIG_FILE")
72+
}
73+
if resolvedPath == "" {
74+
var err error
75+
resolvedPath, err = envconfig.DefaultConfigFilePath()
76+
if err != nil {
77+
return LoadConfigResult{}, fmt.Errorf("failed to get default config path: %w", err)
78+
}
79+
}
80+
81+
clientConfig, err := envconfig.LoadClientConfig(envconfig.LoadClientConfigOptions{
82+
ConfigFilePath: resolvedPath,
83+
EnvLookup: envLookup,
84+
})
85+
if err != nil {
86+
return LoadConfigResult{}, err
87+
}
88+
89+
// Load OAuth for all profiles by parsing the config file directly.
90+
oauthByProfile, err := loadOAuthConfigFromFile(resolvedPath)
91+
if err != nil {
92+
return LoadConfigResult{}, err
93+
}
94+
95+
// Merge profiles and their OAuth configurations.
96+
profiles := make(map[string]*Profile)
97+
for name, baseProfile := range clientConfig.Profiles {
98+
profiles[name] = &Profile{
99+
ClientConfigProfile: *baseProfile,
100+
OAuth: oauthByProfile[name],
101+
}
102+
}
103+
104+
return LoadConfigResult{
105+
Config: ClientConfig{Profiles: profiles},
106+
ConfigFilePath: resolvedPath,
107+
}, nil
108+
}
109+
110+
func loadOAuthConfigFromFile(path string) (map[string]*OAuthConfig, error) {
111+
data, err := os.ReadFile(path)
112+
if err != nil {
113+
if errors.Is(err, os.ErrNotExist) {
114+
return nil, nil
115+
}
116+
return nil, fmt.Errorf("failed to read config file: %w", err)
117+
}
118+
119+
var raw rawConfigWithOAuth
120+
if _, err := toml.Decode(string(data), &raw); err != nil {
121+
return nil, fmt.Errorf("failed to parse config file: %w", err)
122+
}
123+
124+
oauthByProfile := make(map[string]*OAuthConfig)
125+
for profileName, profile := range raw.Profile {
126+
if profile == nil || profile.OAuth == nil {
127+
oauthByProfile[profileName] = nil
128+
continue
129+
}
130+
cfg := profile.OAuth
131+
oauth := &OAuthConfig{
132+
OAuthClientConfig: OAuthClientConfig{
133+
ClientID: cfg.ClientID,
134+
ClientSecret: cfg.ClientSecret,
135+
TokenURL: cfg.TokenURL,
136+
AuthURL: cfg.AuthURL,
137+
RequestParams: cfg.RequestParams,
138+
Scopes: cfg.Scopes,
139+
},
140+
OAuthToken: OAuthToken{
141+
AccessToken: cfg.AccessToken,
142+
RefreshToken: cfg.RefreshToken,
143+
TokenType: cfg.TokenType,
144+
},
145+
}
146+
if cfg.ExpiresAt != "" {
147+
t, err := time.Parse(time.RFC3339, cfg.ExpiresAt)
148+
if err != nil {
149+
return nil, fmt.Errorf("failed to parse expires_at for profile %q: %w", profileName, err)
150+
}
151+
oauth.AccessTokenExpiresAt = t
152+
}
153+
oauthByProfile[profileName] = oauth
154+
}
155+
return oauthByProfile, nil
156+
}
157+
158+
type WriteConfigOptions struct {
159+
// Config is the configuration to write.
160+
Config ClientConfig
161+
// ConfigFilePath is the path to write the configuration file to.
162+
// If empty, TEMPORAL_CONFIG_FILE env var is checked, then the default path is used.
163+
ConfigFilePath string
164+
// Override the environment variable lookup. If nil, defaults to [EnvLookupOS].
165+
EnvLookup envconfig.EnvLookup
166+
}
167+
168+
// ConfigToTOML serializes the configuration to TOML bytes.
169+
func ConfigToTOML(config *ClientConfig) ([]byte, error) {
170+
// Build envconfig.ClientConfig from profiles.
171+
envConfig := &envconfig.ClientConfig{
172+
Profiles: make(map[string]*envconfig.ClientConfigProfile),
173+
}
174+
for name, p := range config.Profiles {
175+
if p != nil {
176+
envConfig.Profiles[name] = &p.ClientConfigProfile
177+
}
178+
}
179+
180+
// Convert base config to TOML.
181+
b, err := envConfig.ToTOML(envconfig.ClientConfigToTOMLOptions{})
182+
if err != nil {
183+
return nil, fmt.Errorf("failed building TOML: %w", err)
184+
}
185+
186+
// Append OAuth sections per profile.
187+
var buf bytes.Buffer
188+
buf.Write(b)
189+
190+
for name, profile := range config.Profiles {
191+
if profile == nil || profile.OAuth == nil {
192+
continue
193+
}
194+
oauthTOML := &oauthConfigTOML{
195+
ClientID: profile.OAuth.ClientID,
196+
ClientSecret: profile.OAuth.ClientSecret,
197+
TokenURL: profile.OAuth.TokenURL,
198+
AuthURL: profile.OAuth.AuthURL,
199+
AccessToken: profile.OAuth.AccessToken,
200+
RefreshToken: profile.OAuth.RefreshToken,
201+
TokenType: profile.OAuth.TokenType,
202+
Scopes: profile.OAuth.Scopes,
203+
RequestParams: profile.OAuth.RequestParams,
204+
}
205+
if !profile.OAuth.AccessTokenExpiresAt.IsZero() {
206+
oauthTOML.ExpiresAt = profile.OAuth.AccessTokenExpiresAt.Format(time.RFC3339)
207+
}
208+
209+
oauthBytes, err := toml.Marshal(oauthTOML)
210+
if err != nil {
211+
return nil, fmt.Errorf("failed marshaling OAuth config: %w", err)
212+
}
213+
fmt.Fprintf(&buf, "\n[profile.%s.oauth]\n", name)
214+
buf.Write(oauthBytes)
215+
}
216+
return buf.Bytes(), nil
217+
}
218+
219+
// WriteConfig writes the environment configuration to disk.
220+
func WriteConfig(opts WriteConfigOptions) error {
221+
envLookup := opts.EnvLookup
222+
if envLookup == nil {
223+
envLookup = envconfig.EnvLookupOS
224+
}
225+
226+
configFilePath := opts.ConfigFilePath
227+
if configFilePath == "" {
228+
configFilePath, _ = envLookup.LookupEnv("TEMPORAL_CONFIG_FILE")
229+
if configFilePath == "" {
230+
var err error
231+
if configFilePath, err = envconfig.DefaultConfigFilePath(); err != nil {
232+
return err
233+
}
234+
}
235+
}
236+
237+
b, err := ConfigToTOML(&opts.Config)
238+
if err != nil {
239+
return err
240+
}
241+
242+
// Write to file, making dirs as needed.
243+
if err := os.MkdirAll(filepath.Dir(configFilePath), 0700); err != nil {
244+
return fmt.Errorf("failed making config file parent dirs: %w", err)
245+
}
246+
if err := os.WriteFile(configFilePath, b, 0600); err != nil {
247+
return fmt.Errorf("failed writing config file: %w", err)
248+
}
249+
return nil
250+
}
251+
252+
// inlineStringMap wraps a map to marshal as an inline TOML table.
253+
type inlineStringMap map[string]string
254+
255+
func (m inlineStringMap) MarshalTOML() ([]byte, error) {
256+
if len(m) == 0 {
257+
return nil, nil
258+
}
259+
260+
// Sort keys for deterministic output.
261+
keys := make([]string, 0, len(m))
262+
for k := range m {
263+
keys = append(keys, k)
264+
}
265+
sort.Strings(keys)
266+
267+
var buf bytes.Buffer
268+
buf.WriteString("{ ")
269+
for i, k := range keys {
270+
if i > 0 {
271+
buf.WriteString(", ")
272+
}
273+
fmt.Fprintf(&buf, "%s = %q", k, m[k])
274+
}
275+
buf.WriteString(" }")
276+
return buf.Bytes(), nil
277+
}
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package devserver
1+
package cliext
22

33
import (
44
"fmt"
@@ -24,13 +24,13 @@ func GetFreePort(host string) (int, error) {
2424
host = MaybeEscapeIPv6(host)
2525
l, err := net.Listen("tcp", host+":0")
2626
if err != nil {
27-
return 0, fmt.Errorf("failed to assign a free port: %v", err)
27+
return 0, fmt.Errorf("failed to assign a free port: %w", err)
2828
}
2929
defer l.Close()
3030
port := l.Addr().(*net.TCPAddr).Port
3131

3232
// On Linux and some BSD variants, ephemeral ports are randomized, and may
33-
// consequently repeat within a short time frame after the listenning end
33+
// consequently repeat within a short time frame after the listening end
3434
// has been closed. To avoid this, we make a connection to the port, then
3535
// close that connection from the server's side (this is very important),
3636
// which puts the connection in TIME_WAIT state for some time (by default,
@@ -50,17 +50,17 @@ func GetFreePort(host string) (int, error) {
5050
// to ::1). For safety, rebuild address form the original host instead.
5151
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", host, port))
5252
if err != nil {
53-
return 0, fmt.Errorf("error resolving address: %v", err)
53+
return 0, fmt.Errorf("error resolving address: %w", err)
5454
}
5555
r, err := net.DialTCP("tcp", nil, tcpAddr)
5656
if err != nil {
57-
return 0, fmt.Errorf("failed to assign a free port: %v", err)
57+
return 0, fmt.Errorf("failed to assign a free port: %w", err)
5858
}
5959
c, err := l.Accept()
6060
if err != nil {
61-
return 0, fmt.Errorf("failed to assign a free port: %v", err)
61+
return 0, fmt.Errorf("failed to assign a free port: %w", err)
6262
}
63-
// Closing the socket from the server side
63+
// Closing the socket from the server side.
6464
c.Close()
6565
defer r.Close()
6666
}
Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
package devserver_test
1+
package cliext_test
22

33
import (
44
"fmt"
55
"net"
66
"testing"
77

8-
"github.com/temporalio/cli/internal/devserver"
8+
"github.com/temporalio/cli/cliext"
99
)
1010

1111
func TestFreePort_NoDouble(t *testing.T) {
1212
host := "127.0.0.1"
1313
portSet := make(map[int]bool)
1414
for i := 0; i < 2000; i++ {
15-
p, err := devserver.GetFreePort(host)
15+
p, err := cliext.GetFreePort(host)
1616
if err != nil {
1717
t.Fatalf("Error: %s", err)
1818
break
@@ -30,7 +30,7 @@ func TestFreePort_NoDouble(t *testing.T) {
3030
func TestFreePort_CanBindImmediatelySameProcess(t *testing.T) {
3131
host := "127.0.0.1"
3232
for i := 0; i < 500; i++ {
33-
p, err := devserver.GetFreePort(host)
33+
p, err := cliext.GetFreePort(host)
3434
if err != nil {
3535
t.Fatalf("Error: %s", err)
3636
break
@@ -45,7 +45,7 @@ func TestFreePort_CanBindImmediatelySameProcess(t *testing.T) {
4545

4646
func TestFreePort_IPv4Unspecified(t *testing.T) {
4747
host := "0.0.0.0"
48-
p, err := devserver.GetFreePort(host)
48+
p, err := cliext.GetFreePort(host)
4949
if err != nil {
5050
t.Fatalf("Error: %s", err)
5151
return
@@ -59,7 +59,7 @@ func TestFreePort_IPv4Unspecified(t *testing.T) {
5959

6060
func TestFreePort_IPv6Unspecified(t *testing.T) {
6161
host := "::"
62-
p, err := devserver.GetFreePort(host)
62+
p, err := cliext.GetFreePort(host)
6363
if err != nil {
6464
t.Fatalf("Error: %s", err)
6565
return
@@ -72,8 +72,9 @@ func TestFreePort_IPv6Unspecified(t *testing.T) {
7272
}
7373

7474
// This function is used as part of unit tests, to ensure that the port
75+
// is available for listening and dialing.
7576
func tryListenAndDialOn(host string, port int) error {
76-
host = devserver.MaybeEscapeIPv6(host)
77+
host = cliext.MaybeEscapeIPv6(host)
7778
l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
7879
if err != nil {
7980
return err

0 commit comments

Comments
 (0)