Skip to content

Commit 0e74bbf

Browse files
committed
fix tmds creds response
1 parent 966c44d commit 0e74bbf

File tree

13 files changed

+218
-8
lines changed

13 files changed

+218
-8
lines changed

agent/api/task/task.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ type Task struct {
245245
// perform some action at the task level, such as pulling image from ECR
246246
ExecutionCredentialsID string `json:"executionCredentialsID"`
247247

248-
// credentialsID is used to set the CredentialsId field for the
248+
// CredentialsID is used to set the CredentialsId field for the
249249
// IAMRoleCredentials object associated with the task. This id can be
250250
// used to look up the credentials for task in the credentials manager
251-
credentialsID string
251+
CredentialsID string `json:"credentialsID"`
252252
credentialsRelativeURIUnsafe string
253253

254254
// ENIs is the list of Elastic Network Interfaces assigned to this task. The
@@ -2805,15 +2805,15 @@ func (task *Task) SetCredentialsID(id string) {
28052805
task.lock.Lock()
28062806
defer task.lock.Unlock()
28072807

2808-
task.credentialsID = id
2808+
task.CredentialsID = id
28092809
}
28102810

28112811
// GetCredentialsID gets the credentials ID for the task
28122812
func (task *Task) GetCredentialsID() string {
28132813
task.lock.RLock()
28142814
defer task.lock.RUnlock()
28152815

2816-
return task.credentialsID
2816+
return task.CredentialsID
28172817
}
28182818

28192819
// SetCredentialsRelativeURI sets the credentials relative uri for the task

agent/engine/data.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ func (engine *DockerTaskEngine) loadTasks() error {
5656
for _, task := range tasks {
5757
engine.state.AddTask(task)
5858

59+
// Register the task role's ID in credentials manager as soon as possible.
60+
// This ensures TMDS can distinguish between invalid credentials requests (400) and
61+
// known credentials that aren't available yet (503) after agent restart.
62+
if credentialsID := task.GetCredentialsID(); credentialsID != "" {
63+
engine.credentialsManager.AddKnownCredentialsID(credentialsID)
64+
}
65+
5966
// TODO: Will need to clean up all of the STOPPED managed daemon tasks
6067
md, ok := task.IsManagedDaemonTask()
6168
if ok {

agent/engine/data_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import (
2727
"github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment"
2828
apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status"
2929
apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status"
30+
mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks"
3031
ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface"
3132

33+
"github.com/golang/mock/gomock"
3234
"github.com/stretchr/testify/assert"
3335
"github.com/stretchr/testify/require"
3436
)
@@ -376,3 +378,40 @@ func TestRemoveENIAttachmentData(t *testing.T) {
376378
require.NoError(t, err)
377379
assert.Len(t, res, 0)
378380
}
381+
382+
func TestLoadTasksRegistersCredentialsID(t *testing.T) {
383+
ctrl := gomock.NewController(t)
384+
defer ctrl.Finish()
385+
386+
dataClient := newTestDataClient(t)
387+
mockCredentialsManager := mock_credentials.NewMockManager(ctrl)
388+
389+
// Create a task with credentials ID
390+
testCredentialsID := "test-credentials-id"
391+
taskWithCredentials := &apitask.Task{
392+
Arn: testTaskARN,
393+
Containers: []*apicontainer.Container{testContainer},
394+
}
395+
// Set the credentials ID after task creation
396+
taskWithCredentials.SetCredentialsID(testCredentialsID)
397+
398+
// Save task to data client
399+
require.NoError(t, dataClient.SaveTask(taskWithCredentials))
400+
401+
engine := &DockerTaskEngine{
402+
state: dockerstate.NewTaskEngineState(),
403+
dataClient: dataClient,
404+
credentialsManager: mockCredentialsManager,
405+
}
406+
407+
// Expect credentials ID to be added to known credentials during loadTasks
408+
mockCredentialsManager.EXPECT().AddKnownCredentialsID(testCredentialsID)
409+
410+
// Call loadTasks and verify credentials ID is registered
411+
require.NoError(t, engine.loadTasks())
412+
413+
// Verify task was loaded into state
414+
task, ok := engine.state.TaskByArn(testTaskARN)
415+
assert.True(t, ok)
416+
assert.Equal(t, testCredentialsID, task.GetCredentialsID())
417+
}

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/credentials/interface.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/credentials/manager.go

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks/credentials_mocks.go

Lines changed: 26 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1/credentials_handler.go

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecs-agent/credentials/interface.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ type Manager interface {
2020
SetTaskCredentials(*TaskIAMRoleCredentials) error
2121
GetTaskCredentials(string) (TaskIAMRoleCredentials, bool)
2222
RemoveCredentials(string)
23+
IsCredentialsPending(string) bool
24+
AddKnownCredentialsID(string)
2325
}

ecs-agent/credentials/manager.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func (roleCredentials *IAMRoleCredentials) GenerateCredentialsEndpointRelativeUR
8686
// the credentials endpoint
8787
type credentialsManager struct {
8888
// idToTaskCredentials maps credentials id to its corresponding TaskIAMRoleCredentials object
89+
// this map contains credentials id for which agent received the actual credentials from ACS already
8990
idToTaskCredentials map[string]TaskIAMRoleCredentials
91+
// knownCredentialsIDs tracks all credentials IDs we know about
92+
knownCredentialsIDs map[string]bool
9093
taskCredentialsLock sync.RWMutex
9194
}
9295

@@ -108,6 +111,7 @@ func IAMRoleCredentialsFromACS(roleCredentials *ecsacs.IAMRoleCredentials, roleT
108111
func NewManager() Manager {
109112
return &credentialsManager{
110113
idToTaskCredentials: make(map[string]TaskIAMRoleCredentials),
114+
knownCredentialsIDs: make(map[string]bool),
111115
}
112116
}
113117

@@ -132,6 +136,8 @@ func (manager *credentialsManager) SetTaskCredentials(taskCredentials *TaskIAMRo
132136
IAMRoleCredentials: taskCredentials.GetIAMRoleCredentials(),
133137
}
134138

139+
manager.knownCredentialsIDs[credentials.CredentialsID] = true
140+
135141
return nil
136142
}
137143

@@ -157,4 +163,26 @@ func (manager *credentialsManager) RemoveCredentials(id string) {
157163
defer manager.taskCredentialsLock.Unlock()
158164

159165
delete(manager.idToTaskCredentials, id)
166+
delete(manager.knownCredentialsIDs, id)
167+
}
168+
169+
// IsCredentialsPending returns true if credentials ID is known but has not yet arrived from ACS.
170+
func (manager *credentialsManager) IsCredentialsPending(id string) bool {
171+
manager.taskCredentialsLock.RLock()
172+
defer manager.taskCredentialsLock.RUnlock()
173+
174+
_, isKnown := manager.knownCredentialsIDs[id]
175+
_, hasCredentials := manager.idToTaskCredentials[id]
176+
177+
return isKnown && !hasCredentials
178+
}
179+
180+
// AddKnownCredentialsID adds a credentials ID to the known set.
181+
// This is useful when agent needs to track known credentials IDs
182+
// for which the credentials themselves have not arrived from ACS.
183+
func (manager *credentialsManager) AddKnownCredentialsID(id string) {
184+
manager.taskCredentialsLock.Lock()
185+
defer manager.taskCredentialsLock.Unlock()
186+
187+
manager.knownCredentialsIDs[id] = true
160188
}

ecs-agent/credentials/manager_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,54 @@ func TestRemoveExistingCredentials(t *testing.T) {
172172
t.Error("Expected GetTaskCredentials to return false for removed credentials")
173173
}
174174
}
175+
176+
// TestAddKnownCredentialsID tests that AddKnownCredentialsID properly tracks credentials IDs
177+
func TestAddKnownCredentialsID(t *testing.T) {
178+
manager := NewManager()
179+
credentialsID := "test-creds-id"
180+
181+
// Initially, credentials should not be pending
182+
assert.False(t, manager.IsCredentialsPending(credentialsID))
183+
184+
// Add known credentials ID
185+
manager.AddKnownCredentialsID(credentialsID)
186+
187+
// Now it should be pending (known but no actual credentials)
188+
assert.True(t, manager.IsCredentialsPending(credentialsID))
189+
190+
// Verify no actual credentials exist
191+
_, ok := manager.GetTaskCredentials(credentialsID)
192+
assert.False(t, ok)
193+
}
194+
195+
// TestIsCredentialsPending tests the IsCredentialsPending method behavior
196+
func TestIsCredentialsPending(t *testing.T) {
197+
manager := NewManager()
198+
credentialsID := "test-creds-id"
199+
200+
// Case 1: Unknown credentials ID - should return false
201+
assert.False(t, manager.IsCredentialsPending(credentialsID))
202+
203+
// Case 2: Known but no actual credentials - should return true
204+
manager.AddKnownCredentialsID(credentialsID)
205+
assert.True(t, manager.IsCredentialsPending(credentialsID))
206+
207+
// Case 3: Known and has actual credentials - should return false
208+
credentials := TaskIAMRoleCredentials{
209+
ARN: "arn:aws:ecs:us-east-1:123456789012:task/test-task",
210+
IAMRoleCredentials: IAMRoleCredentials{
211+
AccessKeyID: "akid1",
212+
SecretAccessKey: "skid1",
213+
SessionToken: "stkn",
214+
Expiration: "ts",
215+
CredentialsID: credentialsID,
216+
},
217+
}
218+
err := manager.SetTaskCredentials(&credentials)
219+
assert.NoError(t, err)
220+
assert.False(t, manager.IsCredentialsPending(credentialsID))
221+
222+
// Case 4: After removal - should return false
223+
manager.RemoveCredentials(credentialsID)
224+
assert.False(t, manager.IsCredentialsPending(credentialsID))
225+
}

0 commit comments

Comments
 (0)