55
66import {
77 bearerCredentialsUpdateRequestType ,
8+ iamCredentialsUpdateRequestType ,
89 ConnectionMetadata ,
910 NotificationType ,
1011 RequestType ,
@@ -17,8 +18,8 @@ import { LanguageClient } from 'vscode-languageclient'
1718import { AuthUtil } from 'aws-core-vscode/codewhisperer'
1819import { Writable } from 'stream'
1920import { onceChanged } from 'aws-core-vscode/utils'
20- import { getLogger , oneMinute } from 'aws-core-vscode/shared'
21- import { isSsoConnection } from 'aws-core-vscode/auth'
21+ import { getLogger , oneMinute , isSageMaker } from 'aws-core-vscode/shared'
22+ import { isSsoConnection , isIamConnection } from 'aws-core-vscode/auth'
2223
2324export const encryptionKey = crypto . randomBytes ( 32 )
2425
@@ -78,10 +79,16 @@ export class AmazonQLspAuth {
7879 */
7980 async refreshConnection ( force : boolean = false ) {
8081 const activeConnection = this . authUtil . conn
81- if ( this . authUtil . isConnectionValid ( ) && isSsoConnection ( activeConnection ) ) {
82- // send the token to the language server
83- const token = await this . authUtil . getBearerToken ( )
84- await ( force ? this . _updateBearerToken ( token ) : this . updateBearerToken ( token ) )
82+ if ( this . authUtil . isConnectionValid ( ) ) {
83+ if ( isSsoConnection ( activeConnection ) ) {
84+ // Existing SSO path
85+ const token = await this . authUtil . getBearerToken ( )
86+ await ( force ? this . _updateBearerToken ( token ) : this . updateBearerToken ( token ) )
87+ } else if ( isSageMaker ( ) && isIamConnection ( activeConnection ) ) {
88+ // New SageMaker IAM path
89+ const credentials = await this . authUtil . getCredentials ( )
90+ await ( force ? this . _updateIamCredentials ( credentials ) : this . updateIamCredentials ( credentials ) )
91+ }
8592 }
8693 }
8794
@@ -92,9 +99,7 @@ export class AmazonQLspAuth {
9299
93100 public updateBearerToken = onceChanged ( this . _updateBearerToken . bind ( this ) )
94101 private async _updateBearerToken ( token : string ) {
95- const request = await this . createUpdateCredentialsRequest ( {
96- token,
97- } )
102+ const request = await this . createUpdateBearerCredentialsRequest ( token )
98103
99104 // "aws/credentials/token/update"
100105 // https://github.com/aws/language-servers/blob/44d81f0b5754747d77bda60b40cc70950413a737/core/aws-lsp-core/src/credentials/credentialsProvider.ts#L27
@@ -103,15 +108,36 @@ export class AmazonQLspAuth {
103108 this . client . info ( `UpdateBearerToken: ${ JSON . stringify ( request ) } ` )
104109 }
105110
111+ public updateIamCredentials = onceChanged ( this . _updateIamCredentials . bind ( this ) )
112+ private async _updateIamCredentials ( credentials : any ) {
113+ getLogger ( ) . info (
114+ `[SageMaker Debug] Updating IAM credentials - credentials received: ${ credentials ? 'YES' : 'NO' } `
115+ )
116+ if ( credentials ) {
117+ getLogger ( ) . info (
118+ `[SageMaker Debug] IAM credentials structure: accessKeyId=${ credentials . accessKeyId ? 'present' : 'missing' } , secretAccessKey=${ credentials . secretAccessKey ? 'present' : 'missing' } , sessionToken=${ credentials . sessionToken ? 'present' : 'missing' } `
119+ )
120+ }
121+
122+ const request = await this . createUpdateIamCredentialsRequest ( credentials )
123+
124+ // "aws/credentials/iam/update"
125+ await this . client . sendRequest ( iamCredentialsUpdateRequestType . method , request )
126+
127+ this . client . info ( `UpdateIamCredentials: ${ JSON . stringify ( request ) } ` )
128+ getLogger ( ) . info ( `[SageMaker Debug] IAM credentials update request sent successfully` )
129+ }
130+
106131 public startTokenRefreshInterval ( pollingTime : number = oneMinute / 2 ) {
107132 const interval = setInterval ( async ( ) => {
108133 await this . refreshConnection ( ) . catch ( ( e ) => this . logRefreshError ( e ) )
109134 } , pollingTime )
110135 return interval
111136 }
112137
113- private async createUpdateCredentialsRequest ( data : any ) : Promise < UpdateCredentialsParams > {
114- const payload = new TextEncoder ( ) . encode ( JSON . stringify ( { data } ) )
138+ private async createUpdateBearerCredentialsRequest ( token : string ) : Promise < UpdateCredentialsParams > {
139+ const bearerCredentials = { token }
140+ const payload = new TextEncoder ( ) . encode ( JSON . stringify ( { data : bearerCredentials } ) )
115141
116142 const jwt = await new jose . CompactEncrypt ( payload )
117143 . setProtectedHeader ( { alg : 'dir' , enc : 'A256GCM' } )
@@ -127,4 +153,24 @@ export class AmazonQLspAuth {
127153 encrypted : true ,
128154 }
129155 }
156+
157+ private async createUpdateIamCredentialsRequest ( credentials : any ) : Promise < UpdateCredentialsParams > {
158+ // Extract IAM credentials structure
159+ const iamCredentials = {
160+ accessKeyId : credentials . accessKeyId ,
161+ secretAccessKey : credentials . secretAccessKey ,
162+ sessionToken : credentials . sessionToken ,
163+ }
164+ const payload = new TextEncoder ( ) . encode ( JSON . stringify ( { data : iamCredentials } ) )
165+
166+ const jwt = await new jose . CompactEncrypt ( payload )
167+ . setProtectedHeader ( { alg : 'dir' , enc : 'A256GCM' } )
168+ . encrypt ( encryptionKey )
169+
170+ return {
171+ data : jwt ,
172+ // Omit metadata for IAM credentials since startUrl is undefined for non-SSO connections
173+ encrypted : true ,
174+ }
175+ }
130176}
0 commit comments