Skip to content

Commit 7618550

Browse files
authored
Auth rotate only overwrite connection (#2142)
1 parent ff7ba2e commit 7618550

File tree

4 files changed

+66
-148
lines changed

4 files changed

+66
-148
lines changed

src/snowflake/cli/_plugins/auth/keypair/commands.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def setup(
100100

101101
@app.command("rotate", requires_connection=True)
102102
def rotate(
103-
new_connection: bool = _new_connection_option,
104-
connection_name: str = _connection_name_option,
105103
key_length: int = _key_length_option,
106104
output_path: Path = _output_path_option,
107105
private_key_passphrase: SecretType = _private_key_passphrase_option,
@@ -111,7 +109,6 @@ def rotate(
111109
Rotates the key for the connection. Generates the key pair, sets the public key for the user in Snowflake and creates or updates the connection.
112110
"""
113111
AuthManager().rotate(
114-
connection_name=connection_name,
115112
key_length=key_length,
116113
output_path=SecurePath(output_path),
117114
private_key_passphrase=private_key_passphrase,

src/snowflake/cli/_plugins/auth/keypair/manager.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,39 +48,31 @@ def setup(
4848
if not connection_name:
4949
connection_name = cli_context.connection_context.connection_name
5050

51+
key_name = AuthManager._get_free_key_name(output_path, connection_name) # type: ignore[arg-type]
5152
self._generate_key_pair_and_set_public_key(
5253
user=cli_context.connection.user,
5354
key_length=key_length,
5455
output_path=output_path,
55-
connection_name=connection_name, # type: ignore[arg-type]
56+
key_name=key_name, # type: ignore[arg-type]
5657
private_key_passphrase=private_key_passphrase,
5758
)
5859

5960
self._create_or_update_connection(
6061
current_connection=cli_context.connection_context.connection_name,
6162
connection_name=connection_name, # type: ignore[arg-type]
6263
private_key_path=self._get_private_key_path(
63-
output_path=output_path, key_name=connection_name # type: ignore[arg-type]
64+
output_path=output_path, key_name=key_name # type: ignore[arg-type]
6465
),
6566
)
6667

6768
def rotate(
6869
self,
69-
connection_name: str,
7070
key_length: int,
7171
output_path: SecurePath,
7272
private_key_passphrase: SecretType,
7373
):
74-
# When the user provide new connection name
75-
if connection_name and connection_exists(connection_name):
76-
raise ClickException(
77-
f"Connection with name {connection_name} already exists."
78-
)
79-
8074
cli_context = get_cli_context()
81-
# When the use not provide connection name, so we overwrite the current connection
82-
if not connection_name:
83-
connection_name = cli_context.connection_context.connection_name
75+
connection_name = cli_context.connection_context.connection_name
8476

8577
self._ensure_connection_has_private_key(
8678
cli_context.connection_context.connection_name
@@ -98,10 +90,11 @@ def rotate(
9890
public_key_2,
9991
)
10092

93+
key_name = AuthManager._get_free_key_name(output_path, connection_name)
10194
public_key = self._generate_keys_and_return_public_key(
10295
key_length=key_length,
10396
output_path=output_path,
104-
key_name=connection_name,
97+
key_name=key_name,
10598
private_key_passphrase=private_key_passphrase,
10699
)
107100
self.set_public_key(
@@ -111,7 +104,7 @@ def rotate(
111104
current_connection=cli_context.connection_context.connection_name,
112105
connection_name=connection_name,
113106
private_key_path=self._get_private_key_path(
114-
output_path=output_path, key_name=connection_name
107+
output_path=output_path, key_name=key_name
115108
),
116109
)
117110

@@ -120,7 +113,7 @@ def _generate_key_pair_and_set_public_key(
120113
user: str,
121114
key_length: int,
122115
output_path: SecurePath,
123-
connection_name: str,
116+
key_name: str,
124117
private_key_passphrase: SecretType,
125118
):
126119
public_key_exists, public_key_2_exists = self._get_public_keys()
@@ -136,7 +129,7 @@ def _generate_key_pair_and_set_public_key(
136129
public_key = self._generate_keys_and_return_public_key(
137130
key_length=key_length,
138131
output_path=output_path,
139-
key_name=connection_name, # type: ignore[arg-type]
132+
key_name=key_name, # type: ignore[arg-type]
140133
private_key_passphrase=private_key_passphrase,
141134
)
142135
self.set_public_key(user, PublicKeyProperty.RSA_PUBLIC_KEY, public_key)
@@ -192,19 +185,19 @@ def extend_connection_add(
192185
output_path: SecurePath,
193186
private_key_passphrase: SecretType,
194187
) -> Dict:
188+
key_name = AuthManager._get_free_key_name(output_path, connection_name)
189+
195190
self._generate_key_pair_and_set_public_key(
196191
user=connection_options["user"],
197192
key_length=key_length,
198193
output_path=output_path,
199-
connection_name=connection_name,
194+
key_name=key_name,
200195
private_key_passphrase=private_key_passphrase,
201196
)
202197

203198
connection_options["authenticator"] = "SNOWFLAKE_JWT"
204199
connection_options["private_key_file"] = str(
205-
self._get_private_key_path(
206-
output_path=output_path, key_name=connection_name
207-
).path
200+
self._get_private_key_path(output_path=output_path, key_name=key_name).path
208201
)
209202
if connection_options.get("password"):
210203
del connection_options["password"]
@@ -293,6 +286,30 @@ def _generate_keys_and_return_public_key(
293286

294287
return public_pem.decode("utf-8")
295288

289+
@staticmethod
290+
def _get_free_key_name(output_path: SecurePath, key_name: str) -> str:
291+
new_private_key = f"{key_name}.p8"
292+
new_public_key = f"{key_name}.pub"
293+
new_key_name = key_name
294+
counter = 1
295+
296+
while (
297+
(output_path / new_private_key).exists()
298+
and (output_path / new_public_key).exists()
299+
and counter <= 100
300+
):
301+
new_key_name = f"{key_name}_{counter}"
302+
new_private_key = f"{new_key_name}.p8"
303+
new_public_key = f"{new_key_name}.pub"
304+
counter += 1
305+
306+
if counter == 100:
307+
raise ClickException(
308+
"Too many key pairs with the same name in the output directory."
309+
)
310+
311+
return new_key_name
312+
296313
@staticmethod
297314
def _get_private_key_path(output_path: SecurePath, key_name: str) -> SecurePath:
298315
return (output_path / f"{key_name}.p8").resolve()

tests/auth/__snapshots__/test_auth.ambr

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,8 @@
4242

4343
'''
4444
# ---
45-
# name: test_rotate_connection_already_exists
46-
'''
47-
Create a new connection? [Y/n]:
48-
Enter connection name: default
49-
Enter key length [2048]:
50-
Enter private key passphrase []:
51-
+- Error ----------------------------------------------------------------------+
52-
| Connection with name default already exists. |
53-
+------------------------------------------------------------------------------+
54-
55-
'''
56-
# ---
5745
# name: test_rotate_create_output_directory_with_proper_privileges
5846
'''
59-
Create a new connection? [Y/n]: Y
60-
Enter connection name: keypairconnection
6147
Enter key length [2048]: 4096
6248
Enter private key passphrase []:
6349
Rotate completed.
@@ -66,29 +52,23 @@
6652
# ---
6753
# name: test_rotate_no_prompts
6854
'''
69-
Create a new connection? [Y/n]:
70-
Enter connection name: keypairconnection
7155
Set the `PRIVATE_KEY_PASSPHRASE` environment variable before using the connection.
7256
Rotate completed.
7357

7458
'''
7559
# ---
7660
# name: test_rotate_no_public_key_set
7761
'''
78-
Create a new connection? [Y/n]:
79-
Enter connection name: keypairconnection
8062
Enter key length [2048]:
8163
Enter private key passphrase []:
8264
+- Error ----------------------------------------------------------------------+
83-
| Connection with name keypairconnection already exists. |
65+
| No public key found. Use the setup command first. |
8466
+------------------------------------------------------------------------------+
8567

8668
'''
8769
# ---
8870
# name: test_rotate_only_public_key_set
8971
'''
90-
Create a new connection? [Y/n]: Y
91-
Enter connection name: keypairconnection
9272
Enter key length [2048]: 4096
9373
Enter private key passphrase []:
9474
Rotate completed.
@@ -97,8 +77,6 @@
9777
# ---
9878
# name: test_rotate_other_public_key_set_options[KEY-KEY]
9979
'''
100-
Create a new connection? [Y/n]: Y
101-
Enter connection name: keypairconnection
10280
Enter key length [2048]: 4096
10381
Enter private key passphrase []:
10482
Rotate completed.
@@ -107,27 +85,14 @@
10785
# ---
10886
# name: test_rotate_other_public_key_set_options[None-KEY]
10987
'''
110-
Create a new connection? [Y/n]: Y
111-
Enter connection name: keypairconnection
11288
Enter key length [2048]: 4096
11389
Enter private key passphrase []:
11490
Rotate completed.
11591

11692
'''
11793
# ---
118-
# name: test_rotate_overwrite_connection
119-
'''
120-
Create a new connection? [Y/n]: n
121-
Enter key length [2048]:
122-
Enter private key passphrase []:
123-
Rotate completed.
124-
125-
'''
126-
# ---
12794
# name: test_rotate_with_password
12895
'''
129-
Create a new connection? [Y/n]: Y
130-
Enter connection name: keypairconnection
13196
Enter key length [2048]: 4096
13297
Enter private key passphrase []:
13398
Set the `PRIVATE_KEY_PASSPHRASE` environment variable before using the connection.

0 commit comments

Comments
 (0)