Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 70 additions & 35 deletions src/go/cmd/strelka-frontend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,7 @@ var (
loggerOnce sync.Once
)

func DebugLog(msg string, fields ...zap.Field) {
// The value here doesn't matter, we'll just check for its existence
if _, enableDebugLogging := os.LookupEnv("ENABLE_VERBOSE_LOGGING"); !enableDebugLogging {
return
}

func getLogger() *zap.Logger {
loggerOnce.Do(func() {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.MessageKey = "message"
Expand All @@ -88,8 +83,29 @@ func DebugLog(msg string, fields ...zap.Field) {
loggerInstance = logger
})

return loggerInstance
}

func Infof(format string, v ...any) {
logger := getLogger()
if logger != nil {
logger.Sugar().Infof(format, v...)
} else {
log.Printf(format, v...)
}
}

func Trace(msg string, fields ...zap.Field) {
// The value here doesn't matter, we'll just check for its existence
if _, enableTraceLogging := os.LookupEnv("ENABLE_TRACE_LOGGING"); !enableTraceLogging {
return
}

traceMsg := fmt.Sprintf("[TRACE] %s", msg)
if loggerInstance != nil {
loggerInstance.Info(msg, fields...)
loggerInstance.Info(traceMsg, fields...)
} else {
log.Println(traceMsg)
}
}

Expand Down Expand Up @@ -217,7 +233,7 @@ func (s *server) ScanFile(stream strelka.Frontend_ScanFileServer) error {
}
}

startQueueTask := time.Now()
startQueueScanFileTaskTime := time.Now()
if err := s.coordinator.cli.ZAdd(
stream.Context(),
"tasks",
Expand All @@ -226,19 +242,19 @@ func (s *server) ScanFile(stream strelka.Frontend_ScanFileServer) error {
Member: id,
},
).Err(); err != nil {
DebugLog("[VERBOSE] Failed to queue task",
Trace("ScanFile task failed to queue",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("took", time.Since(startQueueTask)),
zap.Int64("queue_time_ms", time.Since(startQueueScanFileTaskTime).Milliseconds()),
zap.Error(err),
)
return fmt.Errorf("sending task: %w", err)
} else {
DebugLog("[VERBOSE] Queued task",
Trace("ScanFile task queued",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Time("deadline", deadline),
zap.Duration("took", time.Since(startQueueTask)),
zap.Int64("deadline", deadline.Unix()),
zap.Int64("queue_time_ms", time.Since(startQueueScanFileTaskTime).Milliseconds()),
)
}

Expand All @@ -249,21 +265,36 @@ func (s *server) ScanFile(stream strelka.Frontend_ScanFileServer) error {
(*tx).Del(stream.Context(), sha)
}

startAwaitResponse := time.Now()
startResponseLoopTime := time.Now()
lastResponseTime := time.Now()
redisErrorCount := 0
for {
if err := stream.Context().Err(); err != nil {
DebugLog("[VERBOSE] Stream error",
Trace("stream closed",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("waited", time.Since(startAwaitResponse)),
zap.Int64("since_last_response_ms", time.Since(lastResponseTime).Milliseconds()),
zap.Int64("since_entering_loop_ms", time.Since(startResponseLoopTime).Milliseconds()),
zap.Int("redis_error_count", redisErrorCount),
zap.Error(err),
)
return fmt.Errorf("context closed: %w", err)
}

startPopTime := time.Now()
res, err := s.coordinator.cli.BLPop(stream.Context(), 5*time.Second, keye).Result()
if err != nil {
if err != redis.Nil {
redisErrorCount += 1
Trace("error performing response pop",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Int64("since_last_response_ms", time.Since(lastResponseTime).Milliseconds()),
zap.Int64("since_entering_loop_ms", time.Since(startResponseLoopTime).Milliseconds()),
zap.Int64("response_pop_took_ms", time.Since(startPopTime).Milliseconds()),
zap.Int("redis_error_count", redisErrorCount),
zap.Error(err),
)
// Delay to prevent fast looping over errors
time.Sleep(250 * time.Millisecond)
}
Expand All @@ -276,19 +307,23 @@ func (s *server) ScanFile(stream strelka.Frontend_ScanFileServer) error {

lpop := res[1]
if lpop == "FIN" {
DebugLog("[VERBOSE] FIN response received",
Trace("FIN response received",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("waited", time.Since(startAwaitResponse)),
zap.Int64("since_last_response_ms", time.Since(lastResponseTime).Milliseconds()),
zap.Int64("since_entering_loop_ms", time.Since(startResponseLoopTime).Milliseconds()),
zap.Int("redis_error_count", redisErrorCount),
)
break
}

DebugLog("[VERBOSE] Response received",
Trace("response received",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("waited", time.Since(startAwaitResponse)),
zap.Int64("since_last_response_ms", time.Since(lastResponseTime).Milliseconds()),
zap.Int64("since_entering_loop_ms", time.Since(startResponseLoopTime).Milliseconds()),
)
lastResponseTime = time.Now()

if tx != nil {
(*tx).RPush(stream.Context(), sha, lpop)
Expand Down Expand Up @@ -318,20 +353,20 @@ func (s *server) ScanFile(stream strelka.Frontend_ScanFileServer) error {

s.responses <- resp

startSendResponse := time.Now()
startSendResponseTime := time.Now()
if err := stream.Send(resp); err != nil {
DebugLog("[VERBOSE] Error sending response",
Trace("error sending response to caller",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("took", time.Since(startSendResponse)),
zap.Int64("send_response_took_ms", time.Since(startSendResponseTime).Milliseconds()),
zap.Error(err),
)
return fmt.Errorf("send stream: %w", err)
} else {
DebugLog("[VERBOSE] Sent response",
Trace("response sent to caller",
zap.String("request_id", req.Id),
zap.String("strelka_id", id),
zap.Duration("took", time.Since(startSendResponse)),
zap.Int64("send_response_took_ms", time.Since(startSendResponseTime).Milliseconds()),
)
}
}
Expand Down Expand Up @@ -416,7 +451,7 @@ func (s *server) CompileYara(stream strelka.Frontend_CompileYaraServer) error {
if err != nil {
if err != redis.Nil {
// Delay to prevent fast looping over errors
log.Printf("err: %v\n", err)
Infof("err: %v\n", err)
time.Sleep(250 * time.Millisecond)
}
continue
Expand Down Expand Up @@ -536,7 +571,7 @@ func (s *server) SyncYara(stream strelka.Frontend_SyncYaraServer) error {
if err != nil {
if err != redis.Nil {
// Delay to prevent fast looping over errors
log.Printf("err: %v\n", err)
Infof("err: %v\n", err)
time.Sleep(250 * time.Millisecond)
}
continue
Expand Down Expand Up @@ -684,7 +719,7 @@ func (s *server) SyncYaraV2(stream strelka.Frontend_SyncYaraV2Server) error {
if err != nil {
if err != redis.Nil {
// Delay to prevent fast looping over errors
log.Printf("err: %v\n", err)
Infof("err: %v\n", err)
time.Sleep(250 * time.Millisecond)
}
continue
Expand Down Expand Up @@ -870,17 +905,17 @@ func main() {
go func() {
rpc.LogResponses(responses, conf.Response.Log)
}()
log.Printf("responses will be logged to %v", conf.Response.Log)
Infof("responses will be logged to %v", conf.Response.Log)
} else if conf.Response.Report != 0 {
go func() {
rpc.ReportResponses(responses, conf.Response.Report)
}()
log.Printf("responses will be reported every %v", conf.Response.Report)
Infof("responses will be reported every %v", conf.Response.Report)
} else {
go func() {
rpc.DiscardResponses(responses)
}()
log.Println("responses will be discarded")
Infof("responses will be discarded")
}

cd := redis.NewClient(&redis.Options{
Expand Down Expand Up @@ -923,16 +958,16 @@ func main() {
}

go func() {
log.Printf("Waiting for shutdown\n")
Infof("Waiting for shutdown\n")
<-shutdownWorkersSig
st := time.Now()
log.Printf("Received shutdown signal, attempting graceful shutdown\n")
Infof("Received shutdown signal, attempting graceful shutdown\n")
s.GracefulStop()
log.Printf("Graceful shutdown completed in %v\n", time.Since(st))
Infof("Graceful shutdown completed in %v\n", time.Since(st))
}()

strelka.RegisterFrontendServer(s, opts)
grpc_health_v1.RegisterHealthServer(s, opts)
err = s.Serve(listen)
log.Printf("Shutting down. Serve err: %v\n", err)
Infof("Shutting down. Serve err: %v\n", err)
}
Loading