diff --git a/biz/adaptor/controller/core_api/conversation.go b/biz/adaptor/controller/core_api/conversation.go index 9e5940f..00c4bb8 100644 --- a/biz/adaptor/controller/core_api/conversation.go +++ b/biz/adaptor/controller/core_api/conversation.go @@ -7,6 +7,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/xh-polaris/psych-core-api/biz/adaptor/middleware" "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" "github.com/xh-polaris/psych-core-api/pkg/httpx" "github.com/xh-polaris/psych-core-api/provider" @@ -23,6 +24,7 @@ func CreateConversation(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.ConversationService.CreateConversation(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -39,6 +41,7 @@ func ListConversations(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.ConversationService.ListConversations(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -55,6 +58,7 @@ func GetConversation(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.ConversationService.GetConversation(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) diff --git a/biz/adaptor/controller/core_api/core_api.go b/biz/adaptor/controller/core_api/core_api.go index afb926a..45c2352 100644 --- a/biz/adaptor/controller/core_api/core_api.go +++ b/biz/adaptor/controller/core_api/core_api.go @@ -2,13 +2,14 @@ package core_api import ( "context" + "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/xh-polaris/psych-core-api/biz/adaptor/middleware" "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" "github.com/xh-polaris/psych-core-api/biz/cst" "github.com/xh-polaris/psych-core-api/pkg/httpx" "github.com/xh-polaris/psych-core-api/provider" - //"github.com/xh-polaris/psych-idl/kitex_gen/core_api" ) // ========================================== @@ -34,6 +35,7 @@ func DashboardGetDataOverview(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardGetDataOverview(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -58,6 +60,7 @@ func DashboardGetDataTrend(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardGetDataTrend(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -81,6 +84,7 @@ func DashboardListUnits(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardListUnits(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -105,6 +109,7 @@ func DashboardGetPsychTrend(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardGetPsychTrend(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -129,6 +134,7 @@ func DashboardGetAlarmOverview(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.AlarmService.Overview(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -158,6 +164,7 @@ func DashboardListAlarmRecords(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.AlarmService.ListRecords(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -184,6 +191,7 @@ func DashboardListClasses(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardListClasses(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -213,6 +221,7 @@ func DashboardListUsers(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardListUsers(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -594,6 +603,7 @@ func DashboardUserConvRecords(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardUserConvRecords(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -610,6 +620,7 @@ func DashboardUpdateAlarm(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.AlarmService.UpdateAlarm(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) @@ -626,6 +637,7 @@ func DashboardGetReport(ctx context.Context, c *app.RequestContext) { return } + middleware.StoreToken(ctx, c, &req) p := provider.Get() resp, err := p.DashboardService.DashboardGetReport(ctx, &req) httpx.PostProcess(ctx, c, &req, resp, err) diff --git a/biz/adaptor/middleware/auth.go b/biz/adaptor/middleware/auth.go new file mode 100644 index 0000000..0e5f870 --- /dev/null +++ b/biz/adaptor/middleware/auth.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "context" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/xh-polaris/psych-core-api/biz/cst" + "github.com/xh-polaris/psych-core-api/biz/infra/util" + "github.com/xh-polaris/psych-core-api/pkg/errorx" + "github.com/xh-polaris/psych-core-api/pkg/httpx" + "github.com/xh-polaris/psych-core-api/types/errno" +) + +func StoreToken(ctx context.Context, c *app.RequestContext, req any) { + authHeader := c.GetHeader("Authorization") + if len(authHeader) == 0 { + httpx.PostProcess(ctx, c, req, nil, errorx.New(errno.ErrUnAuth)) + c.Abort() + return + } + + // 验证JWT的有效性 + _, err := util.ParseJwt(string(authHeader)) + if err != nil { + httpx.PostProcess(ctx, c, req, nil, errorx.New(errno.ErrJWTPrase)) + c.Abort() + return + } + + // 使用context.WithValue传递token + newCtx := context.WithValue(ctx, cst.CtxKeyToken, string(authHeader)) + c.Set(cst.CtxKeyToken, newCtx) + c.Next(ctx) +} diff --git a/biz/application/service/alarm.go b/biz/application/service/alarm.go index feb483b..9c0a41c 100644 --- a/biz/application/service/alarm.go +++ b/biz/application/service/alarm.go @@ -2,11 +2,12 @@ package service import ( "context" - "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" - "github.com/xh-polaris/psych-core-api/biz/infra/util" "sync" "time" + "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" + "github.com/xh-polaris/psych-core-api/biz/infra/util" + "github.com/xh-polaris/psych-core-api/biz/infra/mapper/conversation" "github.com/xh-polaris/psych-core-api/biz/infra/mapper/report" @@ -40,6 +41,22 @@ var AlarmServiceSet = wire.NewSet( ) func (s *AlarmService) Overview(ctx context.Context, req *core_api.DashboardGetAlarmOverviewReq) (resp *core_api.DashboardGetAlarmOverviewResp, err error) { + // 鉴权 + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + + if req.UnitId != "" { + if !userMeta.HasUnitAdminAuth() || userMeta.UserId != req.UnitId { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + } + if req.UnitId == "" && !userMeta.HasSuperAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + + // 提取unitID unitOID, err := bson.ObjectIDFromHex(req.UnitId) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) @@ -48,7 +65,7 @@ func (s *AlarmService) Overview(ctx context.Context, req *core_api.DashboardGetA st, err := s.AlarmMapper.AggregateStats(ctx, unitOID, time.Time{}, time.Time{}) if err != nil { logs.Errorf("aggregate alarm error: %s", errorx.ErrorWithoutStack(err)) - return nil, err + return nil, errorx.New(errno.ErrDashboardAlarmUserStat) } return &core_api.DashboardGetAlarmOverviewResp{ @@ -60,12 +77,28 @@ func (s *AlarmService) Overview(ctx context.Context, req *core_api.DashboardGetA ProcessedChange: st.ProcessedChange, PendingChange: st.PendingChange, TrackChange: st.TrackChange, - Code: 200, + Code: 0, Msg: "success", }, nil } func (s *AlarmService) ListRecords(ctx context.Context, req *core_api.DashboardListAlarmRecordsReq) (resp *core_api.DashboardListAlarmRecordsResp, err error) { + // 鉴权 + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + + if req.UnitId != "" { + if !userMeta.HasUnitAdminAuth() || userMeta.UserId != req.UnitId { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + } + if req.UnitId == "" && !userMeta.HasSuperAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + + // 提取unitID unitOID, err := bson.ObjectIDFromHex(req.UnitId) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) @@ -76,7 +109,7 @@ func (s *AlarmService) ListRecords(ctx context.Context, req *core_api.DashboardL if total == 0 { return &core_api.DashboardListAlarmRecordsResp{ Pagination: util.PaginationRes(total, req.PaginationOptions), - Code: 200, + Code: 0, Msg: "success", }, nil } @@ -94,7 +127,7 @@ func (s *AlarmService) ListRecords(ctx context.Context, req *core_api.DashboardL return &core_api.DashboardListAlarmRecordsResp{ Records: completeAlarm, Pagination: util.PaginationRes(total, req.PaginationOptions), - Code: 200, + Code: 0, Msg: "success", }, err2 } @@ -180,6 +213,16 @@ func (s *AlarmService) completeAlarm(ctx context.Context, dbAlarms []*alarm.Alar } func (s *AlarmService) UpdateAlarm(ctx context.Context, req *core_api.DashboardUpdateAlarmReq) (resp *core_api.DashboardUpdateAlarmResp, err error) { + // 初步鉴权-需要有UnitAdmin权限 + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + // 参数校验 if req.Alarm == nil { return nil, errorx.New(errno.ErrMissingParams, errorx.KV("field", "预警信息")) @@ -192,6 +235,17 @@ func (s *AlarmService) UpdateAlarm(ctx context.Context, req *core_api.DashboardU return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "预警ID")) } + // 二次鉴权:需要在统一unit下 + oldAlarm, err := s.AlarmMapper.FindOneById(ctx, alarmId) + // optimize 查不到时考虑直接创建而非报错 + if err != nil { + logs.Errorf("find alarm error: %s", errorx.ErrorWithoutStack(err)) + return nil, errorx.New(errno.ErrNotFound) + } + if userMeta.UnitId != oldAlarm.UnitID.Hex() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + // 构建更新字段 update := bson.M{} @@ -225,13 +279,13 @@ func (s *AlarmService) UpdateAlarm(ctx context.Context, req *core_api.DashboardU if len(update) > 0 { if err = s.AlarmMapper.UpdateFields(ctx, alarmId, update); err != nil { logs.Errorf("update alarm error: %s", errorx.ErrorWithoutStack(err)) - return nil, err + return nil, errorx.New(errno.ErrInternalError) } } // 构造返回结果 return &core_api.DashboardUpdateAlarmResp{ - Code: 200, + Code: 0, Msg: "success", }, nil } diff --git a/biz/application/service/conversation.go b/biz/application/service/conversation.go index 2cc1f60..b37d577 100644 --- a/biz/application/service/conversation.go +++ b/biz/application/service/conversation.go @@ -2,6 +2,8 @@ package service import ( "context" + "time" + "github.com/google/wire" "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" "github.com/xh-polaris/psych-core-api/biz/cst" @@ -12,7 +14,6 @@ import ( "github.com/xh-polaris/psych-core-api/pkg/errorx" "github.com/xh-polaris/psych-core-api/types/errno" "go.mongodb.org/mongo-driver/v2/bson" - "time" ) type IConversationService interface { @@ -33,15 +34,20 @@ var ConversationServiceSet = wire.NewSet( ) func (c *ConversationService) CreateConversation(ctx context.Context, req *core_api.CreateConversationReq) (resp *core_api.CreateConversationResp, err error) { - //userMeta, err := util.ExtraUserMeta(ctx) - //if err != nil { - // return nil, err - //} + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + + userOID, err := bson.ObjectIDFromHex(userMeta.UserId) + if err != nil { + return nil, errorx.New(errno.ErrInvalidParams) + } temp := bson.NewObjectID() if err := c.ConversationMapper.Insert(ctx, &conversation.Conversation{ - ID: temp, - //UserID: userMeta.UserId, // TODO 鉴权时从token获得 + ID: temp, + UserID: userOID, CreateTime: time.Now(), UpdateTime: time.Now(), }); err != nil { @@ -50,19 +56,21 @@ func (c *ConversationService) CreateConversation(ctx context.Context, req *core_ return &core_api.CreateConversationResp{ ConversationId: temp.Hex(), - Code: 200, + Code: 0, Msg: "success", }, nil } func (c *ConversationService) ListConversations(ctx context.Context, req *core_api.ListConversationsReq) (resp *core_api.ListConversationsResp, err error) { - //userMeta, err := util.ExtraUserMeta(ctx) - //if err != nil { - // return nil, err - //} + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } - userId := bson.NewObjectID() // TODO - userId, _ = bson.ObjectIDFromHex("69abcc4f7f113a15afc12fda") + userId, err := bson.ObjectIDFromHex(userMeta.UserId) + if err != nil { + return nil, errorx.New(errno.ErrInvalidParams) + } total, err := c.ConversationMapper.CountByUser(ctx, userId) if err != nil { return nil, errorx.New(errno.ErrListConversation) @@ -71,7 +79,7 @@ func (c *ConversationService) ListConversations(ctx context.Context, req *core_a if total == 0 { return &core_api.ListConversationsResp{ Pagination: util.PaginationRes(0, req.PaginationOptions), - Code: 200, + Code: 0, Msg: "success", }, nil } @@ -96,18 +104,12 @@ func (c *ConversationService) ListConversations(ctx context.Context, req *core_a return &core_api.ListConversationsResp{ Pagination: util.PaginationRes(total, req.PaginationOptions), ConversationList: convs, - Code: 200, + Code: 0, Msg: "success", }, nil } func (c *ConversationService) GetConversation(ctx context.Context, req *core_api.GetConversationReq) (resp *core_api.GetConversationResp, err error) { - // 鉴权 - //userMeta, err := util.ExtraUserMeta(ctx) - //if err != nil { - // return nil, err - //} - // RetrieveMessage仅返回意外异常 消息搜索结果为空时返回空切片,和nil err // 非空时,返回index倒序的列表 rawMsgs, err := his.Mgr.RetrieveMessage(ctx, req.ConversationId, -1) diff --git a/biz/application/service/dashboard.go b/biz/application/service/dashboard.go index 7f8591f..4fec114 100644 --- a/biz/application/service/dashboard.go +++ b/biz/application/service/dashboard.go @@ -3,14 +3,15 @@ package service import ( "context" "errors" - "github.com/xh-polaris/psych-core-api/biz/application/dto/basic" - "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" - "github.com/xh-polaris/psych-core-api/biz/infra/util" "sort" "strconv" "sync" "time" + "github.com/xh-polaris/psych-core-api/biz/application/dto/basic" + "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" + "github.com/xh-polaris/psych-core-api/biz/infra/util" + "go.mongodb.org/mongo-driver/v2/mongo" "github.com/xh-polaris/psych-core-api/biz/domain/his" @@ -66,19 +67,36 @@ var DashboardServiceSet = wire.NewSet( ) func (s *DashboardService) DashboardGetDataOverview(ctx context.Context, req *core_api.DashboardGetDataOverviewReq) (*core_api.DashboardGetDataOverviewResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + now := time.Now() weekBefore := now.AddDate(0, 0, -7) twoWeeksBefore := now.AddDate(0, 0, -14) // 区分管理端 / 单位端 if req.UnitId == nil || req.GetUnitId() == "" { + // 管理端 - 需要管理员权限 + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } return s.dashboardOverviewAdmin(ctx, twoWeeksBefore, weekBefore, now) } + // 单位端 - 检查用户是否属于该单位 unitOID, err := bson.ObjectIDFromHex(req.GetUnitId()) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) } + + // 验证用户是否属于该单位(如果不是管理员) + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != req.GetUnitId() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + return s.dashboardOverviewUnit(ctx, unitOID, twoWeeksBefore, weekBefore, now) } @@ -329,6 +347,12 @@ func (s *DashboardService) dashboardOverviewUnit(ctx context.Context, unitOID bs } func (s *DashboardService) DashboardGetDataTrend(ctx context.Context, req *core_api.DashboardGetDataTrendReq) (*core_api.DashboardGetDataTrendResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + now := time.Now() // 计算本周一 00:00 和下周一 00:00(用于按周内 7 天切分) // Go 的 Weekday: Sunday=0, Monday=1 ... Saturday=6 @@ -342,11 +366,20 @@ func (s *DashboardService) DashboardGetDataTrend(ctx context.Context, req *core_ var unitOID *bson.ObjectID if req.UnitId != nil && req.GetUnitId() != "" { + // 单位端 - 验证用户权限 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != req.GetUnitId() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } id, err := bson.ObjectIDFromHex(req.GetUnitId()) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) } unitOID = &id + } else { + // 管理端 - 需要管理员权限 + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } } // 活跃趋势(按天) @@ -450,6 +483,17 @@ func (s *DashboardService) DashboardGetDataTrend(ctx context.Context, req *core_ } func (s *DashboardService) DashboardListUnits(ctx context.Context, req *core_api.DashboardListUnitsReq) (*core_api.DashboardListUnitsResp, error) { + // 提取用户Meta并检查管理员权限 + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + + // 需要管理员权限 + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + // 查询所有单位 units, err := s.UnitMapper.FindAll(ctx) if err != nil { @@ -504,14 +548,29 @@ func (s *DashboardService) DashboardListUnits(ctx context.Context, req *core_api } func (s *DashboardService) DashboardGetPsychTrend(ctx context.Context, req *core_api.DashboardGetPsychTrendReq) (*core_api.DashboardGetPsychTrendResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + unitIdStr := req.GetUnitId() var unitOID *bson.ObjectID if unitIdStr != "" { + // 单位端 - 验证用户权限 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != unitIdStr { + return nil, errorx.New(errno.ErrInsufficientAuth) + } id, err := bson.ObjectIDFromHex(unitIdStr) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) } unitOID = &id + } else { + // 管理端 - 需要管理员权限 + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } } // 统计风险等级分布(按性别拆分) @@ -594,7 +653,7 @@ func (s *DashboardService) DashboardGetPsychTrend(ctx context.Context, req *core EmotionRatio: emoRatio, Risks: riskDistributions, Keywords: keywords, - Code: 200, + Code: 0, Msg: "success", }, nil } @@ -648,10 +707,21 @@ func (s *DashboardService) getKeywords(ctx context.Context, unitOID *bson.Object } func (s *DashboardService) DashboardListClasses(ctx context.Context, req *core_api.DashboardListClassesReq) (*core_api.DashboardListClassesResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + unitOID, err := bson.ObjectIDFromHex(req.UnitId) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) } + + // 验证用户权限 - 必须是管理员或者属于该单位 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != req.UnitId { + return nil, errorx.New(errno.ErrInsufficientAuth) + } // 筛选参数 var grades, classes []int32 if req.Grade != nil { @@ -717,10 +787,21 @@ func aggregateAndSort(mapperRes []*user.ClassStatResult, clsTeachers user.ClassT } func (s *DashboardService) DashboardListUsers(ctx context.Context, req *core_api.DashboardListUsersReq) (*core_api.DashboardListUsersResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + unitOID, err := bson.ObjectIDFromHex(req.UnitId) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) } + + // 验证用户权限 - 必须是管理员或者属于该单位 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != req.UnitId { + return nil, errorx.New(errno.ErrInsufficientAuth) + } // 查找所有用户并按风险高→低排序 dbUsers, err := s.UserMapper.FindAllByUnitID(ctx, unitOID) if err != nil { @@ -823,18 +904,29 @@ func (s *DashboardService) completeRiskUser(ctx context.Context, pg *basic.Pagin } func (s *DashboardService) DashboardUserConvRecords(ctx context.Context, req *core_api.DashboardUserConvRecordsReq) (*core_api.DashboardUserConvRecordsResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + userOID, err := bson.ObjectIDFromHex(req.UserId) if err != nil { return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UserID"), errorx.KV("value", "用户ID")) } - // 获取用户基本信息 - usr, err := s.UserMapper.FindOneById(ctx, userOID) + // 首先获取目标用户信息以检查权限 + targetUser, err := s.UserMapper.FindOneById(ctx, userOID) if err != nil { logs.Errorf("get user info error: %s", errorx.ErrorWithoutStack(err)) return nil, errorx.New(errno.ErrDashboardGetUserInfo) } + // 验证权限:要么是管理员,要么是同一单位的用户 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != targetUser.UnitID.Hex() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } + // 获取用户对话频率趋势 userConvTrend, err := s.getUserConvTrend(ctx, userOID) if err != nil { @@ -849,16 +941,16 @@ func (s *DashboardService) DashboardUserConvRecords(ctx context.Context, req *co resp := &core_api.DashboardUserConvRecordsResp{ User: &core_api.User{ - Id: usr.ID.Hex(), - Name: usr.Name, - Gender: strconv.Itoa(usr.Gender), - Grade: usr.Grade, - Class: usr.Class, + Id: targetUser.ID.Hex(), + Name: targetUser.Name, + Gender: strconv.Itoa(targetUser.Gender), + Grade: targetUser.Grade, + Class: targetUser.Class, }, UserConvTrend: userConvTrend, ConvDetail: convDetail, Pagination: pagination, - Code: 200, + Code: 0, Msg: "success", } @@ -1015,9 +1107,27 @@ func (s *DashboardService) getPagedUserConvs(ctx context.Context, userOID bson.O } func (s *DashboardService) DashboardGetReport(ctx context.Context, req *core_api.DashboardGetReportReq) (*core_api.DashboardGetReportResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + convOID, err := bson.ObjectIDFromHex(req.ConversationId) if err != nil { - return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "UnitID"), errorx.KV("value", "单位ID")) + return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "ConversationId"), errorx.KV("value", "对话ID")) + } + + // 获取对话信息以检查权限 + conv, err := s.ConversationMapper.FindOneById(ctx, convOID) + if err != nil { + logs.Errorf("get conversation error: %s", errorx.ErrorWithoutStack(err)) + return nil, errorx.New(errno.ErrNotFound, errorx.KV("field", "对话")) + } + + // 管理员可查看所有报告,普通用户只能查看自己的对话报告 + if !userMeta.HasUnitAdminAuth() && conv.UserID.Hex() != userMeta.UserId { + return nil, errorx.New(errno.ErrInsufficientAuth) } rpt, err := s.ReportMapper.FindByConversation(ctx, convOID) @@ -1033,16 +1143,30 @@ func (s *DashboardService) DashboardGetReport(ctx context.Context, req *core_api Emotion: rpt.Emotion, Body: rpt.Body, NeedAlarm: rpt.NeedAlarm, - Code: 200, + Code: 0, Msg: "success", }, nil } func (s *DashboardService) DashboardUnitConvRecords(ctx context.Context, req *core_api.DashboardUnitConvRecordsReq) (*core_api.DashboardUnitConvRecordsResp, error) { + // 提取用户Meta + userMeta, err := util.ExtraUserMeta(ctx) + if err != nil { + return nil, err + } + if uid := req.GetUnitId(); uid != "" { + // 单位端 - 验证用户权限 + if !userMeta.HasUnitAdminAuth() && userMeta.UnitId != uid { + return nil, errorx.New(errno.ErrInsufficientAuth) + } return s.getOneUnitConvs(ctx, req) } + // 管理端 - 需要管理员权限 + if !userMeta.HasUnitAdminAuth() { + return nil, errorx.New(errno.ErrInsufficientAuth) + } return s.getAllUnitsConvs(ctx, req) } diff --git a/biz/application/service/unit.go b/biz/application/service/unit.go index f318647..61b0ff5 100644 --- a/biz/application/service/unit.go +++ b/biz/application/service/unit.go @@ -2,9 +2,10 @@ package service import ( "context" + "time" + "github.com/xh-polaris/psych-core-api/biz/application/dto/basic" "github.com/xh-polaris/psych-core-api/biz/application/dto/core_api" - "time" "github.com/xh-polaris/psych-core-api/biz/cst" "github.com/xh-polaris/psych-core-api/biz/infra/mapper/unit" @@ -122,7 +123,10 @@ func (u *UnitService) UnitSignUp(ctx context.Context, req *core_api.UnitSignUpRe }, nil } +// UnitSignIn 单位Admin登录 func (u *UnitService) UnitSignIn(ctx context.Context, req *core_api.UnitSignInReq) (*core_api.UnitSignInResp, error) { + // 后续使用synapse4b Query Unit获得unitID + // 参数校验 if req.AuthId == "" { return nil, errorx.New(errno.ErrMissingParams, errorx.KV("field", "电话号码")) @@ -156,10 +160,13 @@ func (u *UnitService) UnitSignIn(ctx context.Context, req *core_api.UnitSignInRe return nil, errorx.New(errno.ErrWrongAccountOrPassword) } - // 获得密码 + // 校验密码 if !encrypt.BcryptCheck(req.VerifyCode, unitDAO.Password) { return nil, errorx.New(errno.ErrWrongAccountOrPassword) } + + // TODO 签发UnitAdmin jwt + // 验证码登录 case cst.AuthTypeCode: return nil, errorx.New(errno.ErrUnImplement) // TODO: 验证码登录 diff --git a/biz/application/service/user.go b/biz/application/service/user.go index ef2ebbd..d6af653 100644 --- a/biz/application/service/user.go +++ b/biz/application/service/user.go @@ -2,6 +2,7 @@ package service import ( "context" + "github.com/xh-polaris/psych-core-api/biz/infra/util" "time" "github.com/xh-polaris/psych-core-api/biz/application/dto/basic" @@ -72,12 +73,12 @@ func (u *UserService) UserSignUp(ctx context.Context, req *core_api.UserSignUpRe return nil, errorx.New(errno.ErrPhoneAlreadyExist) } - // 密码加密 - hashedPwd, err := encrypt.BcryptEncrypt(req.User.Password) - if err != nil { - logs.Errorf("bcrypt encrypt error: %s", errorx.ErrorWithoutStack(err)) - return nil, err - } + // 不加密直接明文存密码 + //hashedPwd, err := encrypt.BcryptEncrypt(req.User.Password) + //if err != nil { + // logs.Errorf("bcrypt encrypt error: %s", errorx.ErrorWithoutStack(err)) + // return nil, err + //} // 转换枚举值 gender, ok := enum.ParseGender(req.User.Gender) @@ -85,8 +86,9 @@ func (u *UserService) UserSignUp(ctx context.Context, req *core_api.UserSignUpRe return nil, errorx.New(errno.ErrInvalidParams, errorx.KV("field", "性别")) } - // 转换ID + // 转换unitID var unitId bson.ObjectID + var err error if req.User.UnitId != "" { unitId, err = bson.ObjectIDFromHex(req.User.UnitId) if err != nil { @@ -100,7 +102,7 @@ func (u *UserService) UserSignUp(ctx context.Context, req *core_api.UserSignUpRe ID: bson.NewObjectID(), CodeType: enum.CodeTypePhone, Code: req.User.Code, - Password: hashedPwd, + Password: req.User.Password, Name: req.User.Name, Birth: time.Unix(req.User.Birth, 0), Gender: gender, @@ -149,16 +151,13 @@ func (u *UserService) UserSignUp(ctx context.Context, req *core_api.UserSignUpRe CreateTime: userDAO.CreateTime.Unix(), UpdateTime: userDAO.UpdateTime.Unix(), }, - Code: 200, + Code: 0, Msg: "success", }, nil } func (u *UserService) UserSignIn(ctx context.Context, req *core_api.UserSignInReq) (*core_api.UserSignInResp, error) { // 参数校验 - //if req.AuthType == "" { - // return nil, errorx.New(errno.ErrMissingParams, errorx.KV("field", "登录方式")) - //} if req.AuthId == "" { return nil, errorx.New(errno.ErrMissingParams, errorx.KV("field", "账号")) } @@ -168,7 +167,7 @@ func (u *UserService) UserSignIn(ctx context.Context, req *core_api.UserSignInRe switch req.AuthType { case cst.AuthTypeCode: return nil, errorx.New(errno.ErrUnImplement) // TODO: 验证码登录 - case cst.AuthTypePassword: + case cst.AuthTypePassword: // if req.VerifyCode == "" { return nil, errorx.New(errno.ErrMissingParams, errorx.KV("field", "密码")) } @@ -193,17 +192,30 @@ func (u *UserService) UserSignIn(ctx context.Context, req *core_api.UserSignInRe return nil, errorx.New(errno.ErrWrongAccountOrPassword) } - // 密码验证 - if !encrypt.BcryptCheck(req.VerifyCode, userDAO.Password) { + // 明文密码验证 + if req.VerifyCode != userDAO.Password { return nil, errorx.New(errno.ErrWrongAccountOrPassword) } codeType, _ := enum.GetCodeType(userDAO.CodeType) + + // 签发jwt + token, err := util.GenerateJwt(map[string]any{ + cst.JsonUnitID: req.UnitId, + cst.JsonUserID: userDAO.ID.Hex(), + cst.JsonCode: userDAO.Code, // 手机号或学号 后续可能需要区分 + cst.JsonAdmin: userDAO.Role, + }) + if err != nil { + logs.Errorf("generate token for UserSignIn error: %s", errorx.ErrorWithoutStack(err)) + } + return &core_api.UserSignInResp{ UnitId: userDAO.UnitID.Hex(), UserId: userDAO.ID.Hex(), CodeValue: userDAO.Code, CodeType: codeType, - Code: 200, + Token: token, + Code: 0, Msg: "success", }, nil } @@ -265,7 +277,7 @@ func (u *UserService) UserGetInfo(ctx context.Context, req *core_api.UserGetInfo UpdateTime: userDAO.UpdateTime.Unix(), DeleteTime: userDAO.DeleteTime.Unix(), }, - Code: 200, + Code: 0, Msg: "success", }, nil } @@ -330,7 +342,7 @@ func (u *UserService) UserUpdateInfo(ctx context.Context, req *core_api.UserUpda // 构造返回结果 return &basic.Response{ - Code: 200, + Code: 0, Msg: "success", }, nil } @@ -396,7 +408,7 @@ func (u *UserService) UserUpdatePassword(ctx context.Context, req *core_api.User // 构造返回结果 return &basic.Response{ - Code: 200, + Code: 0, Msg: "success", }, nil } diff --git a/biz/cst/consts.go b/biz/cst/consts.go index 09cca87..d182d24 100644 --- a/biz/cst/consts.go +++ b/biz/cst/consts.go @@ -13,6 +13,15 @@ const ( // Tool is the role of a tool, means the message is a tool call output. Tool = "tool" ToolEnum = 3 + + CtxKeyToken = "token" + + // userMeta-权限相关枚举值 与User.Role一致 + AuthLevelUnitStudent = 0 + AuthLevelUnitTeacher = 1 + AuthLevelUnitClassTeacher = 2 + AuthLevelUnitAdmin = 3 + AuthLevelSuperAdmin = 4 ) // json字段枚举 @@ -21,6 +30,7 @@ const ( JsonUnitID = "unitId" JsonConversationID = "conversationId" JsonCode = "code" + JsonAdmin = "admin" ) // mapper层字段枚举 diff --git a/biz/domain/usr/usr.go b/biz/domain/usr/usr.go index 279054d..6bbcecd 100644 --- a/biz/domain/usr/usr.go +++ b/biz/domain/usr/usr.go @@ -1,8 +1,22 @@ package usr +import "github.com/xh-polaris/psych-core-api/biz/cst" + type Meta struct { - UserId string `json:"user_id"` - UnitId string `json:"unit_id;omitempty"` + UserId string `json:"userId"` + UnitId string `json:"unitId;omitempty"` Code string `json:"code;omitempty"` Admin int `json:"admin"` // 权限等级(学生用户、学校管理、超管) } + +func (usrMeta *Meta) HasUnitTeacherAuth() bool { + return usrMeta.Admin >= cst.AuthLevelUnitTeacher +} + +func (usrMeta *Meta) HasUnitAdminAuth() bool { + return usrMeta.Admin >= cst.AuthLevelUnitAdmin +} + +func (usrMeta *Meta) HasSuperAdminAuth() bool { + return usrMeta.Admin >= cst.AuthLevelSuperAdmin +} diff --git a/biz/infra/mapper/alarm/mapper.go b/biz/infra/mapper/alarm/mapper.go index 730a8cc..d7e370c 100644 --- a/biz/infra/mapper/alarm/mapper.go +++ b/biz/infra/mapper/alarm/mapper.go @@ -27,6 +27,7 @@ const ( type IMongoMapper interface { Insert(ctx context.Context, alarm *Alarm) error + FindOneById(ctx context.Context, id bson.ObjectID) (*Alarm, error) UpdateFields(ctx context.Context, id bson.ObjectID, update bson.M) error RetrieveByTime(ctx context.Context, unitID bson.ObjectID, start, end time.Time, opt *options.FindOptionsBuilder) ([]*Alarm, error) CountByTime(ctx context.Context, unitID bson.ObjectID, start, end time.Time) (int32, error) diff --git a/biz/infra/mapper/conversation/mapper.go b/biz/infra/mapper/conversation/mapper.go index 54a9895..21457d3 100644 --- a/biz/infra/mapper/conversation/mapper.go +++ b/biz/infra/mapper/conversation/mapper.go @@ -2,9 +2,10 @@ package conversation import ( "context" + "time" + "github.com/xh-polaris/psych-core-api/biz/infra/mapper" "go.mongodb.org/mongo-driver/v2/mongo/options" - "time" "github.com/xh-polaris/psych-core-api/biz/conf" "github.com/xh-polaris/psych-core-api/biz/cst" @@ -29,6 +30,7 @@ type IMongoMapper interface { CountByUnit(ctx context.Context, unitId *bson.ObjectID) (int32, error) CountByUser(ctx context.Context, userId bson.ObjectID) (int32, error) // 查找 + FindOneById(ctx context.Context, id bson.ObjectID) (*Conversation, error) FindManyByUserId(ctx context.Context, userId bson.ObjectID, opt options.Lister[options.FindOptions]) ([]*Conversation, error) // 分页查找 FindAllByUserId(ctx context.Context, userId bson.ObjectID) ([]*Conversation, error) // 查找全部 // 聚合统计 diff --git a/biz/infra/util/jwt_util.go b/biz/infra/util/jwt_util.go index 3dc390b..90ab5d2 100644 --- a/biz/infra/util/jwt_util.go +++ b/biz/infra/util/jwt_util.go @@ -40,6 +40,7 @@ func ParseJwt(jwtStr string, options ...jwt.ParserOption) (jwt.MapClaims, error) return token.Claims.(jwt.MapClaims), nil } +// ExtraUserMeta 从ctx中提取出userId func ExtraUserMeta(ctx context.Context) (m *usr.Meta, err error) { var meta usr.Meta var c *app.RequestContext @@ -53,5 +54,7 @@ func ExtraUserMeta(ctx context.Context) (m *usr.Meta, err error) { meta.UserId = claims[cst.JsonUserID].(string) meta.UnitId = claims[cst.JsonUnitID].(string) meta.Code = claims[cst.JsonCode].(string) + meta.Admin = int(claims[cst.JsonAdmin].(float64)) + return &meta, nil } diff --git a/go.mod b/go.mod index 07b747a..ba890e4 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/hertz-contrib/websocket v0.2.0 github.com/rabbitmq/amqp091-go v1.9.0 github.com/xh-polaris/gopkg v0.0.0-20250312141711-7327267f4ea6 - github.com/xh-polaris/psych-idl v0.0.0-20260307124857-2a4d7dea5233 + github.com/xh-polaris/psych-idl v0.0.0-20260311165400-263c29b44b0f github.com/zeromicro/go-zero v1.9.0 go.opentelemetry.io/contrib/propagators/b3 v1.37.0 go.opentelemetry.io/otel v1.38.0 diff --git a/go.sum b/go.sum index 43e830a..3f3cedc 100644 --- a/go.sum +++ b/go.sum @@ -526,6 +526,12 @@ github.com/xh-polaris/gopkg v0.0.0-20250312141711-7327267f4ea6 h1:5LzTSKpK7rMv9T github.com/xh-polaris/gopkg v0.0.0-20250312141711-7327267f4ea6/go.mod h1:C+TEAEky4WkrDpxzCDCxVi968lxZ15HwrdVxWrbL42Y= github.com/xh-polaris/psych-idl v0.0.0-20260307124857-2a4d7dea5233 h1:lHiuChLDKd1t7yuTV+coHjLSE+9fADBmf/MeiR3d/p4= github.com/xh-polaris/psych-idl v0.0.0-20260307124857-2a4d7dea5233/go.mod h1:Mq9OKYzoq5fzibYWoxdO0xsybjShIVJ4XLKu/IpVWHw= +github.com/xh-polaris/psych-idl v0.0.0-20260311143959-eb88f4b7d0ff h1:+ke9nPSPkVMDMOmZqHX+9iuXLAoIi/bYk8BirqTIT40= +github.com/xh-polaris/psych-idl v0.0.0-20260311143959-eb88f4b7d0ff/go.mod h1:Mq9OKYzoq5fzibYWoxdO0xsybjShIVJ4XLKu/IpVWHw= +github.com/xh-polaris/psych-idl v0.0.0-20260311163535-7c73dd746f6f h1:KijRA+MqjVqcQ17z2UtPdZky20NSZL9M2JAhPK+l0Kc= +github.com/xh-polaris/psych-idl v0.0.0-20260311163535-7c73dd746f6f/go.mod h1:Mq9OKYzoq5fzibYWoxdO0xsybjShIVJ4XLKu/IpVWHw= +github.com/xh-polaris/psych-idl v0.0.0-20260311165400-263c29b44b0f h1:hbPtO2ZWbAxFMPxJPMhJxObzlBxSaVgtk/tCGN7oJcQ= +github.com/xh-polaris/psych-idl v0.0.0-20260311165400-263c29b44b0f/go.mod h1:Mq9OKYzoq5fzibYWoxdO0xsybjShIVJ4XLKu/IpVWHw= github.com/yanyiwu/gojieba v1.4.6 h1:9oKbZijSHBdoTabXK34romSWj4aQLvs+j1ctIQjSxPk= github.com/yanyiwu/gojieba v1.4.6/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY= github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= diff --git a/types/errno/common.go b/types/errno/common.go index 0beb001..a54cb68 100644 --- a/types/errno/common.go +++ b/types/errno/common.go @@ -21,6 +21,7 @@ const ( ErrPhoneAlreadyExist = 1009 ErrWrongPassword = 1010 ErrJWTPrase = 1011 + ErrInsufficientAuth = 1012 ) func init() { @@ -94,4 +95,9 @@ func init() { "JWT解析错误", code.WithAffectStability(false), ) + code.Register( + ErrInsufficientAuth, + "权限不足", + code.WithAffectStability(false), + ) } diff --git a/types/errno/dashboard.go b/types/errno/dashboard.go index ea9fcff..663c539 100644 --- a/types/errno/dashboard.go +++ b/types/errno/dashboard.go @@ -20,6 +20,7 @@ const ( ErrDashboardGetConvReports = 5014 // 获取对话报表失败 ErrDashboardGenerateWordCloud = 5015 // 生成词云失败 ErrDashboardGetReport = 5016 // 获取报表失败 + ErrDashboardAlarmOverview = 5017 ) func init() { @@ -98,4 +99,9 @@ func init() { "获取报表失败", code.WithAffectStability(false), ) + code.Register( + ErrDashboardAlarmOverview, + "预警总览数据获取失败", + code.WithAffectStability(false), + ) }