@@ -48,39 +48,31 @@ def setup(
48
48
if not connection_name :
49
49
connection_name = cli_context .connection_context .connection_name
50
50
51
+ key_name = AuthManager ._get_free_key_name (output_path , connection_name ) # type: ignore[arg-type]
51
52
self ._generate_key_pair_and_set_public_key (
52
53
user = cli_context .connection .user ,
53
54
key_length = key_length ,
54
55
output_path = output_path ,
55
- connection_name = connection_name , # type: ignore[arg-type]
56
+ key_name = key_name , # type: ignore[arg-type]
56
57
private_key_passphrase = private_key_passphrase ,
57
58
)
58
59
59
60
self ._create_or_update_connection (
60
61
current_connection = cli_context .connection_context .connection_name ,
61
62
connection_name = connection_name , # type: ignore[arg-type]
62
63
private_key_path = self ._get_private_key_path (
63
- output_path = output_path , key_name = connection_name # type: ignore[arg-type]
64
+ output_path = output_path , key_name = key_name # type: ignore[arg-type]
64
65
),
65
66
)
66
67
67
68
def rotate (
68
69
self ,
69
- connection_name : str ,
70
70
key_length : int ,
71
71
output_path : SecurePath ,
72
72
private_key_passphrase : SecretType ,
73
73
):
74
- # When the user provide new connection name
75
- if connection_name and connection_exists (connection_name ):
76
- raise ClickException (
77
- f"Connection with name { connection_name } already exists."
78
- )
79
-
80
74
cli_context = get_cli_context ()
81
- # When the use not provide connection name, so we overwrite the current connection
82
- if not connection_name :
83
- connection_name = cli_context .connection_context .connection_name
75
+ connection_name = cli_context .connection_context .connection_name
84
76
85
77
self ._ensure_connection_has_private_key (
86
78
cli_context .connection_context .connection_name
@@ -98,10 +90,11 @@ def rotate(
98
90
public_key_2 ,
99
91
)
100
92
93
+ key_name = AuthManager ._get_free_key_name (output_path , connection_name )
101
94
public_key = self ._generate_keys_and_return_public_key (
102
95
key_length = key_length ,
103
96
output_path = output_path ,
104
- key_name = connection_name ,
97
+ key_name = key_name ,
105
98
private_key_passphrase = private_key_passphrase ,
106
99
)
107
100
self .set_public_key (
@@ -111,7 +104,7 @@ def rotate(
111
104
current_connection = cli_context .connection_context .connection_name ,
112
105
connection_name = connection_name ,
113
106
private_key_path = self ._get_private_key_path (
114
- output_path = output_path , key_name = connection_name
107
+ output_path = output_path , key_name = key_name
115
108
),
116
109
)
117
110
@@ -120,7 +113,7 @@ def _generate_key_pair_and_set_public_key(
120
113
user : str ,
121
114
key_length : int ,
122
115
output_path : SecurePath ,
123
- connection_name : str ,
116
+ key_name : str ,
124
117
private_key_passphrase : SecretType ,
125
118
):
126
119
public_key_exists , public_key_2_exists = self ._get_public_keys ()
@@ -136,7 +129,7 @@ def _generate_key_pair_and_set_public_key(
136
129
public_key = self ._generate_keys_and_return_public_key (
137
130
key_length = key_length ,
138
131
output_path = output_path ,
139
- key_name = connection_name , # type: ignore[arg-type]
132
+ key_name = key_name , # type: ignore[arg-type]
140
133
private_key_passphrase = private_key_passphrase ,
141
134
)
142
135
self .set_public_key (user , PublicKeyProperty .RSA_PUBLIC_KEY , public_key )
@@ -192,19 +185,19 @@ def extend_connection_add(
192
185
output_path : SecurePath ,
193
186
private_key_passphrase : SecretType ,
194
187
) -> Dict :
188
+ key_name = AuthManager ._get_free_key_name (output_path , connection_name )
189
+
195
190
self ._generate_key_pair_and_set_public_key (
196
191
user = connection_options ["user" ],
197
192
key_length = key_length ,
198
193
output_path = output_path ,
199
- connection_name = connection_name ,
194
+ key_name = key_name ,
200
195
private_key_passphrase = private_key_passphrase ,
201
196
)
202
197
203
198
connection_options ["authenticator" ] = "SNOWFLAKE_JWT"
204
199
connection_options ["private_key_file" ] = str (
205
- self ._get_private_key_path (
206
- output_path = output_path , key_name = connection_name
207
- ).path
200
+ self ._get_private_key_path (output_path = output_path , key_name = key_name ).path
208
201
)
209
202
if connection_options .get ("password" ):
210
203
del connection_options ["password" ]
@@ -293,6 +286,30 @@ def _generate_keys_and_return_public_key(
293
286
294
287
return public_pem .decode ("utf-8" )
295
288
289
+ @staticmethod
290
+ def _get_free_key_name (output_path : SecurePath , key_name : str ) -> str :
291
+ new_private_key = f"{ key_name } .p8"
292
+ new_public_key = f"{ key_name } .pub"
293
+ new_key_name = key_name
294
+ counter = 1
295
+
296
+ while (
297
+ (output_path / new_private_key ).exists ()
298
+ and (output_path / new_public_key ).exists ()
299
+ and counter <= 100
300
+ ):
301
+ new_key_name = f"{ key_name } _{ counter } "
302
+ new_private_key = f"{ new_key_name } .p8"
303
+ new_public_key = f"{ new_key_name } .pub"
304
+ counter += 1
305
+
306
+ if counter == 100 :
307
+ raise ClickException (
308
+ "Too many key pairs with the same name in the output directory."
309
+ )
310
+
311
+ return new_key_name
312
+
296
313
@staticmethod
297
314
def _get_private_key_path (output_path : SecurePath , key_name : str ) -> SecurePath :
298
315
return (output_path / f"{ key_name } .p8" ).resolve ()
0 commit comments