@@ -7,12 +7,14 @@ import (
77 "errors"
88 "fmt"
99 "strings"
10-
11- "github.com/raystack/frontier/internal/bootstrap/schema"
10+ "time"
1211
1312 "github.com/doug-martin/goqu/v9"
13+ "github.com/jmoiron/sqlx"
1414 "github.com/raystack/frontier/core/namespace"
1515 "github.com/raystack/frontier/core/policy"
16+ "github.com/raystack/frontier/internal/bootstrap/schema"
17+ "github.com/raystack/frontier/pkg/auditrecord"
1618 "github.com/raystack/frontier/pkg/db"
1719)
1820
@@ -199,15 +201,42 @@ func (r PolicyRepository) Upsert(ctx context.Context, pol policy.Policy) (policy
199201 "principal_type" : pol .PrincipalType ,
200202 "metadata" : marshaledMetadata ,
201203 }).OnConflict (goqu .DoUpdate ("role_id, resource_id, resource_type, principal_id, principal_type" , goqu.Record {
202- "metadata" : marshaledMetadata ,
204+ "metadata" : marshaledMetadata ,
205+ "updated_at" : goqu .L ("now()" ),
203206 })).Returning (& PolicyCols {}).ToSQL ()
204207 if err != nil {
205208 return policy.Policy {}, fmt .Errorf ("%w: %w" , queryErr , err )
206209 }
207210
211+ // Check if policy exists before upsert
212+ _ , exists := r .getPolicyByConstraint (ctx , pol )
213+
208214 var policyDB Policy
209- if err = r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Upsert" , func (ctx context.Context ) error {
210- return r .dbc .QueryRowxContext (ctx , query , params ... ).StructScan (& policyDB )
215+ if err = r .dbc .WithTxn (ctx , sql.TxOptions {}, func (tx * sqlx.Tx ) error {
216+ return r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Upsert" , func (ctx context.Context ) error {
217+ if err := tx .QueryRowxContext (ctx , query , params ... ).StructScan (& policyDB ); err != nil {
218+ return err
219+ }
220+
221+ var (
222+ event auditrecord.Event
223+ timestamp time.Time
224+ additionalMetadata map [string ]any
225+ )
226+ if exists {
227+ event = auditrecord .PolicyUpdatedEvent
228+ timestamp = policyDB .UpdatedAt
229+ additionalMetadata = map [string ]any {
230+ "updated_metadata" : pol .Metadata ,
231+ }
232+ } else {
233+ event = auditrecord .PolicyCreatedEvent
234+ timestamp = policyDB .CreatedAt
235+ }
236+
237+ auditRecord := r .buildPolicyAuditRecord (ctx , tx , event , policyDB , timestamp , additionalMetadata )
238+ return InsertAuditRecordInTx (ctx , tx , auditRecord )
239+ })
211240 }); err != nil {
212241 err = checkPostgresError (err )
213242 switch {
@@ -225,6 +254,13 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
225254 if strings .TrimSpace (toUpdate .ID ) == "" {
226255 return "" , policy .ErrInvalidID
227256 }
257+
258+ // Fetch existing policy for audit record
259+ existingPolicy , err := r .Get (ctx , toUpdate .ID )
260+ if err != nil {
261+ return "" , err
262+ }
263+
228264 marshaledMetadata , err := json .Marshal (toUpdate .Metadata )
229265 if err != nil {
230266 return "" , fmt .Errorf ("%w: %s" , parseErr , err )
@@ -236,14 +272,32 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
236272 "updated_at" : goqu .L ("now()" ),
237273 }).Where (goqu.Ex {
238274 "id" : toUpdate .ID ,
239- }).Returning ("id" ).ToSQL ()
275+ }).Returning ("id" , "updated_at" ).ToSQL ()
240276 if err != nil {
241277 return "" , fmt .Errorf ("%w: %s" , queryErr , err )
242278 }
243279
244280 var policyID string
245- if err = r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Update" , func (ctx context.Context ) error {
246- return r .dbc .QueryRowxContext (ctx , query , params ... ).Scan (& policyID )
281+ var updatedAt time.Time
282+ if err = r .dbc .WithTxn (ctx , sql.TxOptions {}, func (tx * sqlx.Tx ) error {
283+ return r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Update" , func (ctx context.Context ) error {
284+ if err := tx .QueryRowxContext (ctx , query , params ... ).Scan (& policyID , & updatedAt ); err != nil {
285+ return err
286+ }
287+
288+ policyDB := Policy {
289+ ID : existingPolicy .ID ,
290+ RoleID : existingPolicy .RoleID ,
291+ ResourceID : existingPolicy .ResourceID ,
292+ ResourceType : existingPolicy .ResourceType ,
293+ PrincipalID : existingPolicy .PrincipalID ,
294+ PrincipalType : existingPolicy .PrincipalType ,
295+ }
296+ auditRecord := r .buildPolicyAuditRecord (ctx , tx , auditrecord .PolicyUpdatedEvent , policyDB , updatedAt , map [string ]any {
297+ "updated_metadata" : toUpdate .Metadata ,
298+ })
299+ return InsertAuditRecordInTx (ctx , tx , auditRecord )
300+ })
247301 }); err != nil {
248302 err = checkPostgresError (err )
249303 switch {
@@ -264,20 +318,35 @@ func (r PolicyRepository) Update(ctx context.Context, toUpdate policy.Policy) (s
264318}
265319
266320func (r PolicyRepository ) Delete (ctx context.Context , id string ) error {
267- query , params , err := dialect .Delete (TABLE_POLICIES ).Where (
268- goqu.Ex {
269- "id" : id ,
270- },
271- ).ToSQL ()
321+ // Fetch policy for audit record
322+ existingPolicy , err := r .Get (ctx , id )
272323 if err != nil {
273- return fmt .Errorf ("%w: %s" , queryErr , err )
274- }
275-
276- if err = r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Delete" , func (ctx context.Context ) error {
277- if _ , err = r .dbc .DB .ExecContext (ctx , query , params ... ); err != nil {
278- return err
279- }
280- return nil
324+ return err
325+ }
326+
327+ if err := r .dbc .WithTxn (ctx , sql.TxOptions {}, func (tx * sqlx.Tx ) error {
328+ return r .dbc .WithTimeout (ctx , TABLE_POLICIES , "Delete" , func (ctx context.Context ) error {
329+ deleteQuery , deleteParams , err := dialect .Delete (TABLE_POLICIES ).
330+ Where (goqu.Ex {"id" : id }).
331+ ToSQL ()
332+ if err != nil {
333+ return fmt .Errorf ("%w: %w" , queryErr , err )
334+ }
335+ if _ , err := tx .ExecContext (ctx , deleteQuery , deleteParams ... ); err != nil {
336+ return err
337+ }
338+
339+ policyDB := Policy {
340+ ID : existingPolicy .ID ,
341+ RoleID : existingPolicy .RoleID ,
342+ ResourceID : existingPolicy .ResourceID ,
343+ ResourceType : existingPolicy .ResourceType ,
344+ PrincipalID : existingPolicy .PrincipalID ,
345+ PrincipalType : existingPolicy .PrincipalType ,
346+ }
347+ auditRecord := r .buildPolicyAuditRecord (ctx , tx , auditrecord .PolicyDeletedEvent , policyDB , time .Now (), nil )
348+ return InsertAuditRecordInTx (ctx , tx , auditRecord )
349+ })
281350 }); err != nil {
282351 err = checkPostgresError (err )
283352 switch {
@@ -398,3 +467,98 @@ func (r PolicyRepository) OrgMemberCount(ctx context.Context, id string) (policy
398467
399468 return result , nil
400469}
470+
471+ // buildPolicyAuditRecord builds an audit record for policy events
472+ func (r PolicyRepository ) buildPolicyAuditRecord (ctx context.Context , tx * sqlx.Tx , event auditrecord.Event , pol Policy , timestamp time.Time , additionalMetadata map [string ]any ) AuditRecord {
473+ orgID , resourceName := r .getResourceInfo (ctx , tx , pol .ResourceType , pol .ResourceID )
474+
475+ targetMetadata := map [string ]any {
476+ "role_id" : pol .RoleID ,
477+ "principal_id" : pol .PrincipalID ,
478+ "principal_type" : pol .PrincipalType ,
479+ }
480+ for k , v := range additionalMetadata {
481+ targetMetadata [k ] = v
482+ }
483+
484+ return BuildAuditRecord (
485+ ctx ,
486+ event ,
487+ AuditResource {
488+ ID : pol .ResourceID ,
489+ Type : mapResourceTypeToAuditType (pol .ResourceType ),
490+ Name : resourceName ,
491+ },
492+ & AuditTarget {
493+ ID : pol .ID ,
494+ Type : auditrecord .PolicyType ,
495+ Metadata : targetMetadata ,
496+ },
497+ orgID ,
498+ nil ,
499+ timestamp ,
500+ )
501+ }
502+
503+ // getPolicyByConstraint fetches a policy by unique constraint fields
504+ // Returns the policy and true if found, empty policy and false if not found
505+ func (r PolicyRepository ) getPolicyByConstraint (ctx context.Context , pol policy.Policy ) (Policy , bool ) {
506+ query , params , _ := dialect .From (TABLE_POLICIES ).
507+ Select ("id" , "resource_type" , "resource_id" , "principal_id" , "principal_type" , "role_id" ).
508+ Where (goqu.Ex {
509+ "role_id" : pol .RoleID ,
510+ "resource_id" : pol .ResourceID ,
511+ "resource_type" : pol .ResourceType ,
512+ "principal_id" : pol .PrincipalID ,
513+ "principal_type" : pol .PrincipalType ,
514+ }).
515+ Limit (1 ).
516+ ToSQL ()
517+
518+ var existing Policy
519+ if err := r .dbc .QueryRowxContext (ctx , query , params ... ).StructScan (& existing ); err != nil {
520+ return Policy {}, false
521+ }
522+ return existing , true
523+ }
524+
525+ // getResourceInfo fetches org ID and resource name based on resource type
526+ func (r PolicyRepository ) getResourceInfo (ctx context.Context , tx * sqlx.Tx , resourceType , resourceID string ) (string , string ) {
527+ var orgID , resourceName string
528+ switch resourceType {
529+ case schema .OrganizationNamespace :
530+ orgID = resourceID
531+ orgQuery , orgParams , _ := dialect .From (TABLE_ORGANIZATIONS ).
532+ Select ("title" ).
533+ Where (goqu.Ex {"id" : resourceID }).
534+ ToSQL ()
535+ _ = tx .QueryRowContext (ctx , orgQuery , orgParams ... ).Scan (& resourceName )
536+ case schema .ProjectNamespace :
537+ projQuery , projParams , _ := dialect .From (TABLE_PROJECTS ).
538+ Select ("org_id" , "title" ).
539+ Where (goqu.Ex {"id" : resourceID }).
540+ ToSQL ()
541+ _ = tx .QueryRowContext (ctx , projQuery , projParams ... ).Scan (& orgID , & resourceName )
542+ case schema .GroupNamespace :
543+ grpQuery , grpParams , _ := dialect .From (TABLE_GROUPS ).
544+ Select ("org_id" , "title" ).
545+ Where (goqu.Ex {"id" : resourceID }).
546+ ToSQL ()
547+ _ = tx .QueryRowContext (ctx , grpQuery , grpParams ... ).Scan (& orgID , & resourceName )
548+ }
549+ return orgID , resourceName
550+ }
551+
552+ // mapResourceTypeToAuditType maps resource namespace to audit entity type
553+ func mapResourceTypeToAuditType (resourceType string ) auditrecord.EntityType {
554+ switch resourceType {
555+ case schema .OrganizationNamespace :
556+ return auditrecord .OrganizationType
557+ case schema .ProjectNamespace :
558+ return auditrecord .ProjectType
559+ case schema .GroupNamespace :
560+ return auditrecord .GroupType
561+ default :
562+ return auditrecord .EntityType (resourceType )
563+ }
564+ }
0 commit comments