diff --git a/api/routes/device.go b/api/routes/device.go index a364583f097..7eeeb0a0da0 100644 --- a/api/routes/device.go +++ b/api/routes/device.go @@ -89,7 +89,12 @@ func (h *Handler) GetDeviceList(c gateway.Context) error { return err } - res, count, err := h.service.ListDevices(c.Ctx(), req) + var tenant string + if c.Tenant() != nil { + tenant = c.Tenant().ID + } + + res, count, err := h.service.ListDevices(c.Ctx(), tenant, req) c.Response().Header().Set("X-Total-Count", strconv.Itoa(count)) if err != nil { diff --git a/api/services/device.go b/api/services/device.go index 3e0dfe4a35e..94cd57c346c 100644 --- a/api/services/device.go +++ b/api/services/device.go @@ -16,7 +16,7 @@ import ( const StatusAccepted = "accepted" type DeviceService interface { - ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) + ListDevices(ctx context.Context, tenant string, req *requests.DeviceList) ([]models.Device, int, error) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) // ResolveDevice attempts to resolve a device by searching for either its UID or hostname. When both are provided, @@ -34,10 +34,13 @@ type DeviceService interface { UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) error } -func (s *service) ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) { +func (s *service) ListDevices(ctx context.Context, tenant string, req *requests.DeviceList) ([]models.Device, int, error) { + if tenant == "" { + return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, false) + } + if req.DeviceStatus == models.DeviceStatusRemoved { - // TODO: unique DeviceList - removed, count, err := s.store.DeviceRemovedList(ctx, req.TenantID, req.Paginator, req.Filters, req.Sorter) + removed, count, err := s.store.DeviceRemovedList(ctx, tenant, req.Paginator, req.Filters, req.Sorter) if err != nil { return nil, 0, err } @@ -50,34 +53,32 @@ func (s *service) ListDevices(ctx context.Context, req *requests.DeviceList) ([] return devices, count, nil } - if req.TenantID != "" { - ns, err := s.store.NamespaceGet(ctx, req.TenantID) - if err != nil { - return nil, 0, NewErrNamespaceNotFound(req.TenantID, err) - } + ns, err := s.store.NamespaceGet(ctx, tenant) + if err != nil { + return nil, 0, NewErrNamespaceNotFound(tenant, err) + } - if ns.HasMaxDevices() { - switch { - case envs.IsCloud(): - removed, err := s.store.DeviceRemovedCount(ctx, ns.TenantID) - if err != nil { - return nil, 0, NewErrDeviceRemovedCount(err) - } + var limitReached bool - if ns.HasLimitDevicesReached(removed) { - return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableFromRemoved) - } - case envs.IsEnterprise(): - fallthrough - case envs.IsCommunity(): - if ns.HasMaxDevicesReached() { - return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableAsFalse) - } + if ns.HasMaxDevices() { + switch { + case envs.IsCloud(): + removed, err := s.store.DeviceRemovedCount(ctx, ns.TenantID) + if err != nil { + return nil, 0, NewErrDeviceRemovedCount(err) + } + + if ns.HasLimitDevicesReached(removed) { + limitReached = true + } + case envs.IsEnterprise() || envs.IsCommunity(): + if ns.HasMaxDevicesReached() { + limitReached = true } } } - return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableIfNotAccepted) + return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, limitReached) } func (s *service) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) { diff --git a/api/services/mocks/services.go b/api/services/mocks/services.go index 52ad6ca0254..7ef1ddd3ea1 100644 --- a/api/services/mocks/services.go +++ b/api/services/mocks/services.go @@ -1104,9 +1104,9 @@ func (_m *Service) ListAPIKeys(ctx context.Context, req *requests.ListAPIKey) ([ return r0, r1, r2 } -// ListDevices provides a mock function with given fields: ctx, req -func (_m *Service) ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) { - ret := _m.Called(ctx, req) +// ListDevices provides a mock function with given fields: ctx, tenant, req +func (_m *Service) ListDevices(ctx context.Context, tenant string, req *requests.DeviceList) ([]models.Device, int, error) { + ret := _m.Called(ctx, tenant, req) if len(ret) == 0 { panic("no return value specified for ListDevices") @@ -1115,25 +1115,25 @@ func (_m *Service) ListDevices(ctx context.Context, req *requests.DeviceList) ([ var r0 []models.Device var r1 int var r2 error - if rf, ok := ret.Get(0).(func(context.Context, *requests.DeviceList) ([]models.Device, int, error)); ok { - return rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, string, *requests.DeviceList) ([]models.Device, int, error)); ok { + return rf(ctx, tenant, req) } - if rf, ok := ret.Get(0).(func(context.Context, *requests.DeviceList) []models.Device); ok { - r0 = rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, string, *requests.DeviceList) []models.Device); ok { + r0 = rf(ctx, tenant, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]models.Device) } } - if rf, ok := ret.Get(1).(func(context.Context, *requests.DeviceList) int); ok { - r1 = rf(ctx, req) + if rf, ok := ret.Get(1).(func(context.Context, string, *requests.DeviceList) int); ok { + r1 = rf(ctx, tenant, req) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(context.Context, *requests.DeviceList) error); ok { - r2 = rf(ctx, req) + if rf, ok := ret.Get(2).(func(context.Context, string, *requests.DeviceList) error); ok { + r2 = rf(ctx, tenant, req) } else { r2 = ret.Error(2) } diff --git a/api/store/device.go b/api/store/device.go index c157b54aa6b..aa0834ad5f0 100644 --- a/api/store/device.go +++ b/api/store/device.go @@ -7,18 +7,6 @@ import ( "github.com/shellhub-io/shellhub/pkg/models" ) -type DeviceAcceptable uint - -const ( - // DeviceAcceptableIfNotAccepted is used to indicate the all devices not accepted will be defined as "acceptabled". - DeviceAcceptableIfNotAccepted DeviceAcceptable = iota + 1 - // DeviceAcceptableFromRemoved is used to indicate that the namepsace's device maxium number of devices has been - // reached and should set the "acceptable" value to true for devices that were recently removed. - DeviceAcceptableFromRemoved - // DeviceAcceptableAsFalse set acceptable to false to all returned devices. - DeviceAcceptableAsFalse -) - type DeviceResolver uint const ( @@ -28,7 +16,7 @@ const ( ) type DeviceStore interface { - DeviceList(ctx context.Context, status models.DeviceStatus, pagination query.Paginator, filters query.Filters, sorter query.Sorter, acceptable DeviceAcceptable) ([]models.Device, int, error) + DeviceList(ctx context.Context, status models.DeviceStatus, pagination query.Paginator, filters query.Filters, sorter query.Sorter, full bool) ([]models.Device, int, error) // DeviceResolve fetches a device using a specific resolver within a given tenant ID. // diff --git a/api/store/mocks/store.go b/api/store/mocks/store.go index 0a2bf5c9aeb..63fb9a6def8 100644 --- a/api/store/mocks/store.go +++ b/api/store/mocks/store.go @@ -438,9 +438,9 @@ func (_m *Store) DeviceGetTags(ctx context.Context, tenant string) ([]string, in return r0, r1, r2 } -// DeviceList provides a mock function with given fields: ctx, status, pagination, filters, sorter, acceptable -func (_m *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pagination query.Paginator, filters query.Filters, sorter query.Sorter, acceptable store.DeviceAcceptable) ([]models.Device, int, error) { - ret := _m.Called(ctx, status, pagination, filters, sorter, acceptable) +// DeviceList provides a mock function with given fields: ctx, status, pagination, filters, sorter, full +func (_m *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pagination query.Paginator, filters query.Filters, sorter query.Sorter, full bool) ([]models.Device, int, error) { + ret := _m.Called(ctx, status, pagination, filters, sorter, full) if len(ret) == 0 { panic("no return value specified for DeviceList") @@ -449,25 +449,25 @@ func (_m *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pag var r0 []models.Device var r1 int var r2 error - if rf, ok := ret.Get(0).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, store.DeviceAcceptable) ([]models.Device, int, error)); ok { - return rf(ctx, status, pagination, filters, sorter, acceptable) + if rf, ok := ret.Get(0).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, bool) ([]models.Device, int, error)); ok { + return rf(ctx, status, pagination, filters, sorter, full) } - if rf, ok := ret.Get(0).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, store.DeviceAcceptable) []models.Device); ok { - r0 = rf(ctx, status, pagination, filters, sorter, acceptable) + if rf, ok := ret.Get(0).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, bool) []models.Device); ok { + r0 = rf(ctx, status, pagination, filters, sorter, full) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]models.Device) } } - if rf, ok := ret.Get(1).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, store.DeviceAcceptable) int); ok { - r1 = rf(ctx, status, pagination, filters, sorter, acceptable) + if rf, ok := ret.Get(1).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, bool) int); ok { + r1 = rf(ctx, status, pagination, filters, sorter, full) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, store.DeviceAcceptable) error); ok { - r2 = rf(ctx, status, pagination, filters, sorter, acceptable) + if rf, ok := ret.Get(2).(func(context.Context, models.DeviceStatus, query.Paginator, query.Filters, query.Sorter, bool) error); ok { + r2 = rf(ctx, status, pagination, filters, sorter, full) } else { r2 = ret.Error(2) } diff --git a/api/store/mongo/device.go b/api/store/mongo/device.go index ea5f875ff72..a50a2aa4760 100644 --- a/api/store/mongo/device.go +++ b/api/store/mongo/device.go @@ -19,7 +19,7 @@ import ( ) // DeviceList returns a list of devices based on the given filters, pagination and sorting. -func (s *Store) DeviceList(ctx context.Context, status models.DeviceStatus, paginator query.Paginator, filters query.Filters, sorter query.Sorter, acceptable store.DeviceAcceptable) ([]models.Device, int, error) { +func (s *Store) DeviceList(ctx context.Context, status models.DeviceStatus, paginator query.Paginator, filters query.Filters, sorter query.Sorter, full bool) ([]models.Device, int, error) { query := []bson.M{ { "$match": bson.M{ @@ -63,58 +63,36 @@ func (s *Store) DeviceList(ctx context.Context, status models.DeviceStatus, pagi }}, query...) } - // When the listing mode is [store.DeviceListModeMaxDeviceReached], we should evaluate the `removed_devices` - // collection to check its `accetable` status. - switch acceptable { - case store.DeviceAcceptableFromRemoved: - query = append(query, []bson.M{ - { - "$lookup": bson.M{ - "from": "removed_devices", - "localField": "uid", - "foreignField": "device.uid", - "as": "removed", - }, - }, - { - "$addFields": bson.M{ - "acceptable": bson.M{ - "$cond": bson.M{ - "if": bson.M{ - "$and": bson.A{ - bson.M{"$ne": bson.A{"$status", models.DeviceStatusAccepted}}, - bson.M{"$anyElementTrue": []interface{}{"$removed"}}, - }, - }, - "then": true, - "else": false, - }, - }, - }, - }, - { - "$unset": "removed", - }, - }...) - case store.DeviceAcceptableAsFalse: - query = append(query, bson.M{ - "$addFields": bson.M{ - "acceptable": false, + query = append(query, []bson.M{ + { + "$lookup": bson.M{ + "from": "removed_devices", + "localField": "uid", + "foreignField": "device.uid", + "as": "removed", }, - }) - case store.DeviceAcceptableIfNotAccepted: - query = append(query, bson.M{ + }, + { "$addFields": bson.M{ "acceptable": bson.M{ "$cond": bson.M{ - "if": bson.M{"$ne": bson.A{"$status", models.DeviceStatusAccepted}}, - "then": true, + "if": bson.M{"$ne": bson.A{"$status", models.DeviceStatusAccepted}}, + "then": bson.M{ + "$cond": bson.M{ + "if": bson.M{"$anyElementTrue": []any{"$removed"}}, + "then": true, + "else": bson.M{"$not": full}, + }, + }, "else": false, }, }, }, - }) - } + }, + { + "$unset": "removed", + }, + }...) queryMatch, err := queries.FromFilters(&filters) if err != nil { diff --git a/api/store/mongo/namespace.go b/api/store/mongo/namespace.go index e9b59b4d27e..175b84ecab5 100644 --- a/api/store/mongo/namespace.go +++ b/api/store/mongo/namespace.go @@ -414,5 +414,9 @@ func (s *Store) NamespaceIncrementDeviceCount(ctx context.Context, tenantID stri return store.ErrNoDocuments } + if err := s.cache.Delete(ctx, strings.Join([]string{"namespace", tenantID}, "/")); err != nil { + log.Error(err) + } + return nil } diff --git a/pkg/api/requests/device.go b/pkg/api/requests/device.go index 0e9ee014b37..cb50cd6d52e 100644 --- a/pkg/api/requests/device.go +++ b/pkg/api/requests/device.go @@ -6,7 +6,6 @@ import ( ) type DeviceList struct { - TenantID string `header:"X-Tenant-ID"` DeviceStatus models.DeviceStatus `query:"status"` // TODO: validate query.Paginator query.Sorter