
Commit 870845d

Production audit: fix race conditions, memory leaks, and security issues
Fixes:
- Fix race condition in DynamicRateLimit middleware (thread-safe implementation)
- Add missing logger field to EnhancedAnalyzer struct
- Fix SSRF vulnerability by removing file:// URL scheme support
- Fix SQL placeholder syntax (PostgreSQL $1 -> SQLite ?) in repository
- Add batch job TTL cleanup to prevent unbounded memory growth

Changes:
- ratelimit.go: Rewrite DynamicRateLimit() to not mutate shared config
- enhanced_analyzer.go: Add logger field and initialize in constructors
- validator.go: Block file:// scheme for security
- repository.go: Fix ListHLSAnalyses, ListReports, ListQualityComparisons
- main.go: Add cleanupBatchJobs goroutine (1hr TTL, 5min cleanup interval)
1 parent d73fe9a commit 870845d
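
Note: the ratelimit.go and validator.go changes listed above are not reproduced in the diff view below. The DynamicRateLimit rewrite is described only by the commit message (compute the per-request limit from a copy of the shared config instead of mutating it). For the SSRF fix, the following is a minimal, standalone sketch of rejecting the file:// scheme before a URL is fetched; the function and variable names are hypothetical and the project's actual validator.go will differ.

package main

import (
    "fmt"
    "net/url"
)

// allowedSchemes is an assumed allow-list; the real validator may differ.
var allowedSchemes = map[string]bool{"http": true, "https": true}

// validateSourceURL rejects file:// and any other scheme outside the allow-list.
func validateSourceURL(raw string) error {
    u, err := url.Parse(raw)
    if err != nil {
        return fmt.Errorf("invalid URL: %w", err)
    }
    if !allowedSchemes[u.Scheme] {
        return fmt.Errorf("URL scheme %q is not allowed", u.Scheme)
    }
    return nil
}

func main() {
    fmt.Println(validateSourceURL("https://example.com/video.mp4")) // <nil>
    fmt.Println(validateSourceURL("file:///etc/passwd"))            // URL scheme "file" is not allowed
}

Using an allow-list of http/https rather than a deny-list of file:// also rejects other local or exotic schemes by default.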

File tree

5 files changed: 165 additions, 96 deletions


cmd/rendiff-probe/main.go

Lines changed: 52 additions & 8 deletions
@@ -43,14 +43,16 @@ import (

 // Production constants
 const (
-    maxFileSize       = 5 * 1024 * 1024 * 1024 // 5GB max file size
-    maxRequestBodyMB  = 10                      // 10MB max JSON request body
-    maxBatchItems     = 100                     // Max items in batch processing
-    defaultTimeout    = 60 * time.Second
-    maxTimeout        = 30 * time.Minute
-    shutdownTimeout   = 30 * time.Second
-    wsReadBufferSize  = 1024
-    wsWriteBufferSize = 1024
+    maxFileSize        = 5 * 1024 * 1024 * 1024 // 5GB max file size
+    maxRequestBodyMB   = 10                      // 10MB max JSON request body
+    maxBatchItems      = 100                     // Max items in batch processing
+    defaultTimeout     = 60 * time.Second
+    maxTimeout         = 30 * time.Minute
+    shutdownTimeout    = 30 * time.Second
+    wsReadBufferSize   = 1024
+    wsWriteBufferSize  = 1024
+    batchJobTTL        = 1 * time.Hour   // TTL for completed batch jobs before cleanup
+    batchCleanupPeriod = 5 * time.Minute // How often to run batch job cleanup
 )

 // Global instances for services

@@ -167,6 +169,10 @@ func main() {

     appLogger.Info().Msg("All services initialized successfully")

+    // Start batch job cleanup goroutine
+    go cleanupBatchJobs()
+    appLogger.Info().Dur("ttl", batchJobTTL).Dur("period", batchCleanupPeriod).Msg("Batch job cleanup started")
+
     // Create Gin router with production settings
     router := gin.New()

@@ -285,6 +291,44 @@ func closeAllWebSocketConnections() {
     wsConnections = make(map[string]*websocket.Conn)
 }

+// cleanupBatchJobs periodically removes expired batch jobs to prevent memory leaks
+func cleanupBatchJobs() {
+    ticker := time.NewTicker(batchCleanupPeriod)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-shutdownCtx.Done():
+            appLogger.Debug().Msg("Batch job cleanup goroutine stopped")
+            return
+        case <-ticker.C:
+            now := time.Now()
+            var toDelete []string
+
+            batchLock.RLock()
+            for id, job := range batchJobs {
+                // Remove completed/cancelled/failed jobs older than TTL
+                if job.Status == "completed" || job.Status == "cancelled" || job.Status == "failed" {
+                    if now.Sub(job.UpdatedAt) > batchJobTTL {
+                        toDelete = append(toDelete, id)
+                    }
+                }
+            }
+            batchLock.RUnlock()
+
+            if len(toDelete) > 0 {
+                batchLock.Lock()
+                for _, id := range toDelete {
+                    delete(batchJobs, id)
+                    appLogger.Debug().Str("job_id", id).Msg("Cleaned up expired batch job")
+                }
+                batchLock.Unlock()
+                appLogger.Info().Int("count", len(toDelete)).Msg("Batch job cleanup completed")
+            }
+        }
+    }
+}
+
 // requestLoggingMiddleware logs HTTP requests
 func requestLoggingMiddleware() gin.HandlerFunc {
     return func(c *gin.Context) {

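The cleanupBatchJobs goroutine above reads package-level state that sits outside this hunk. Below is a minimal standalone sketch of what those declarations presumably look like; only the names batchJobs, batchLock, shutdownCtx, job.Status, and job.UpdatedAt are confirmed by the diff, everything else (field set, status values) is an assumption. The two-phase pattern in the goroutine - scan under the read lock, delete under the write lock - keeps the exclusive lock held only for the brief delete loop.

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// BatchJob mirrors the fields the cleanup loop reads; only Status and
// UpdatedAt are confirmed by the diff, the other fields are illustrative.
type BatchJob struct {
    ID        string
    Status    string    // e.g. "pending", "processing", "completed", "cancelled", "failed"
    UpdatedAt time.Time // assumed to be refreshed on every status change
}

// Assumed package-level state referenced by cleanupBatchJobs.
var (
    batchJobs   = make(map[string]*BatchJob) // keyed by job ID
    batchLock   sync.RWMutex                 // guards batchJobs
    shutdownCtx context.Context              // cancelled during graceful shutdown
)

func main() {
    // Demonstrate the TTL test the cleanup loop applies to a finished job.
    const batchJobTTL = 1 * time.Hour
    job := &BatchJob{ID: "demo", Status: "completed", UpdatedAt: time.Now().Add(-2 * time.Hour)}

    batchLock.Lock()
    batchJobs[job.ID] = job
    batchLock.Unlock()

    expired := job.Status == "completed" && time.Since(job.UpdatedAt) > batchJobTTL
    fmt.Println("eligible for cleanup:", expired) // true
}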
internal/database/repository.go

Lines changed: 30 additions & 56 deletions
@@ -416,11 +416,9 @@ func (r *SQLiteRepository) ListHLSAnalyses(ctx context.Context, userID *uuid.UUI
     baseQuery := `FROM hls_analyses h JOIN analyses a ON h.analysis_id = a.id`
     whereClause := ""
     args := []interface{}{}
-    argCount := 0

     if userID != nil {
-        argCount++
-        whereClause = fmt.Sprintf(" WHERE a.user_id = $%d", argCount)
+        whereClause = " WHERE a.user_id = ?"
         args = append(args, *userID)
     }

@@ -431,24 +429,19 @@
         return nil, 0, err
     }

-    // Get paginated results
-    argCount++
-    limitArg := argCount
-    args = append(args, limit)
-    argCount++
-    offsetArg := argCount
-    args = append(args, offset)
+    // Get paginated results - SQLite uses ? placeholders
+    paginatedArgs := append(args, limit, offset)

     query := fmt.Sprintf(`
         SELECT h.id, h.analysis_id, h.manifest_path, h.manifest_type, h.manifest_data,
                h.segment_count, h.total_duration, h.bitrate_variants, h.segment_duration,
                h.playlist_version, h.status, h.processing_time, h.created_at, h.completed_at, h.error_msg
         %s %s
         ORDER BY h.created_at DESC
-        LIMIT $%d OFFSET $%d
-    `, baseQuery, whereClause, limitArg, offsetArg)
+        LIMIT ? OFFSET ?
+    `, baseQuery, whereClause)

-    err = r.db.DB.SelectContext(ctx, &analyses, query, args...)
+    err = r.db.DB.SelectContext(ctx, &analyses, query, paginatedArgs...)
     if err != nil {
         return nil, 0, err
     }

@@ -575,31 +568,26 @@ func (r *SQLiteRepository) ListReports(ctx context.Context, userID *uuid.UUID, a
     baseQuery := "FROM reports"
     whereConditions := []string{}
     args := []interface{}{}
-    argCount := 0

     if userID != nil {
-        argCount++
-        whereConditions = append(whereConditions, fmt.Sprintf("user_id = $%d", argCount))
+        whereConditions = append(whereConditions, "user_id = ?")
         args = append(args, *userID)
     }

     if analysisID != "" {
         if analysisUUID, err := uuid.Parse(analysisID); err == nil {
-            argCount++
-            whereConditions = append(whereConditions, fmt.Sprintf("analysis_id = $%d", argCount))
+            whereConditions = append(whereConditions, "analysis_id = ?")
             args = append(args, analysisUUID)
         }
     }

     if reportType != "" {
-        argCount++
-        whereConditions = append(whereConditions, fmt.Sprintf("report_type = $%d", argCount))
+        whereConditions = append(whereConditions, "report_type = ?")
         args = append(args, reportType)
     }

     if format != "" {
-        argCount++
-        whereConditions = append(whereConditions, fmt.Sprintf("format = $%d", argCount))
+        whereConditions = append(whereConditions, "format = ?")
         args = append(args, format)
     }

@@ -615,23 +603,18 @@
         return nil, 0, err
     }

-    // Get paginated results
-    argCount++
-    limitArg := argCount
-    args = append(args, limit)
-    argCount++
-    offsetArg := argCount
-    args = append(args, offset)
+    // Get paginated results - SQLite uses ? placeholders
+    paginatedArgs := append(args, limit, offset)

     query := fmt.Sprintf(`
         SELECT id, analysis_id, user_id, report_type, format, title, description,
                file_path, file_size, download_count, is_public, expires_at, created_at, last_download
         %s %s
         ORDER BY created_at DESC
-        LIMIT $%d OFFSET $%d
-    `, baseQuery, whereClause, limitArg, offsetArg)
+        LIMIT ? OFFSET ?
+    `, baseQuery, whereClause)

-    err = r.db.DB.SelectContext(ctx, &reports, query, args...)
+    err = r.db.DB.SelectContext(ctx, &reports, query, paginatedArgs...)
     if err != nil {
         return nil, 0, err
     }

@@ -701,54 +684,50 @@ func (r *SQLiteRepository) UpdateQualityComparison(ctx context.Context, comparis

     _, err := r.db.DB.ExecContext(
         ctx, query,
-        comparison.ID, comparison.Status, comparison.ResultSummary, comparison.ProcessingTime,
-        comparison.CompletedAt, comparison.ErrorMsg,
+        comparison.Status, comparison.ResultSummary, comparison.ProcessingTime,
+        comparison.CompletedAt, comparison.ErrorMsg, comparison.ID,
     )
     return err
 }

 func (r *SQLiteRepository) UpdateQualityComparisonStatus(ctx context.Context, id uuid.UUID, status models.AnalysisStatus) error {
     query := "UPDATE quality_comparisons SET status = ?, updated_at = datetime('now') WHERE id = ?"
-    _, err := r.db.DB.ExecContext(ctx, query, id, status)
+    _, err := r.db.DB.ExecContext(ctx, query, status, id)
     return err
 }

 func (r *SQLiteRepository) ListQualityComparisons(ctx context.Context, userID *uuid.UUID, referenceID, distortedID, status string, limit, offset int) ([]*models.QualityComparison, int, error) {
     var comparisons []*models.QualityComparison
     var total int

-    baseQuery := `FROM quality_comparisons qc
+    baseQuery := `FROM quality_comparisons qc
         LEFT JOIN analyses ref ON qc.reference_id = ref.id
         LEFT JOIN analyses dist ON qc.distorted_id = dist.id`
     whereConditions := []string{}
     args := []interface{}{}
-    argCount := 0

     if userID != nil {
-        argCount++
-        whereConditions = append(whereConditions, fmt.Sprintf("(ref.user_id = $%d OR dist.user_id = $%d)", argCount, argCount))
-        args = append(args, *userID)
+        // SQLite uses ? placeholders - need to add the same arg twice for OR condition
+        whereConditions = append(whereConditions, "(ref.user_id = ? OR dist.user_id = ?)")
+        args = append(args, *userID, *userID)
     }

     if referenceID != "" {
         if refUUID, err := uuid.Parse(referenceID); err == nil {
-            argCount++
-            whereConditions = append(whereConditions, fmt.Sprintf("qc.reference_id = $%d", argCount))
+            whereConditions = append(whereConditions, "qc.reference_id = ?")
             args = append(args, refUUID)
         }
     }

     if distortedID != "" {
         if distUUID, err := uuid.Parse(distortedID); err == nil {
-            argCount++
-            whereConditions = append(whereConditions, fmt.Sprintf("qc.distorted_id = $%d", argCount))
+            whereConditions = append(whereConditions, "qc.distorted_id = ?")
             args = append(args, distUUID)
         }
     }

     if status != "" {
-        argCount++
-        whereConditions = append(whereConditions, fmt.Sprintf("qc.status = $%d", argCount))
+        whereConditions = append(whereConditions, "qc.status = ?")
         args = append(args, status)
     }

@@ -764,23 +743,18 @@
         return nil, 0, err
     }

-    // Get paginated results
-    argCount++
-    limitArg := argCount
-    args = append(args, limit)
-    argCount++
-    offsetArg := argCount
-    args = append(args, offset)
+    // Get paginated results - SQLite uses ? placeholders
+    paginatedArgs := append(args, limit, offset)

     query := fmt.Sprintf(`
         SELECT qc.id, qc.reference_id, qc.distorted_id, qc.comparison_type, qc.status,
                qc.result_summary, qc.processing_time, qc.created_at, qc.completed_at, qc.error_msg
         %s %s
         ORDER BY qc.created_at DESC
-        LIMIT $%d OFFSET $%d
-    `, baseQuery, whereClause, limitArg, offsetArg)
+        LIMIT ? OFFSET ?
+    `, baseQuery, whereClause)

-    err = r.db.DB.SelectContext(ctx, &comparisons, query, args...)
+    err = r.db.DB.SelectContext(ctx, &comparisons, query, paginatedArgs...)
     if err != nil {
         return nil, 0, err
     }

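The common thread in the repository.go hunks is that SQLite binds positional ? placeholders strictly in argument order, unlike PostgreSQL's numbered $1, $2 style. The arguments therefore have to line up with the query text: limit and offset are appended last, the userID is passed twice for the OR condition, and the UPDATE statements now pass the id last to match WHERE id = ?. The following is a small standalone illustration of that ordering; it uses database/sql with the mattn/go-sqlite3 driver purely for the demo, which is not necessarily the driver this project uses.

package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/mattn/go-sqlite3" // assumed driver for the demo only
)

func main() {
    db, err := sql.Open("sqlite3", ":memory:")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    if _, err := db.Exec(`CREATE TABLE reports (id INTEGER PRIMARY KEY, user_id TEXT, created_at TEXT)`); err != nil {
        log.Fatal(err)
    }
    for i := 0; i < 5; i++ {
        if _, err := db.Exec(`INSERT INTO reports (user_id, created_at) VALUES (?, datetime('now'))`, "u1"); err != nil {
            log.Fatal(err)
        }
    }

    // With ? placeholders the arguments bind strictly in order: user_id first,
    // then LIMIT, then OFFSET - which is why the fix appends limit and offset
    // to the args slice last.
    rows, err := db.Query(`SELECT id FROM reports WHERE user_id = ? ORDER BY created_at DESC LIMIT ? OFFSET ?`, "u1", 2, 1)
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    for rows.Next() {
        var id int
        if err := rows.Scan(&id); err != nil {
            log.Fatal(err)
        }
        fmt.Println("id:", id)
    }
}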
internal/ffmpeg/enhanced_analyzer.go

Lines changed: 13 additions & 10 deletions
@@ -30,6 +30,7 @@ type EnhancedAnalyzer struct {
     pseAnalyzer               *PSEAnalyzer
     streamDispositionAnalyzer *StreamDispositionAnalyzer
     dataIntegrityAnalyzer     *DataIntegrityAnalyzer
+    logger                    zerolog.Logger
 }

 // NewEnhancedAnalyzer creates a new enhanced analyzer

@@ -53,6 +54,7 @@ func NewEnhancedAnalyzer(ffprobePath string, logger zerolog.Logger) *EnhancedAna
         pseAnalyzer:               NewPSEAnalyzer(ffprobePath, logger),
         streamDispositionAnalyzer: NewStreamDispositionAnalyzer(ffprobePath, logger),
         dataIntegrityAnalyzer:     NewDataIntegrityAnalyzer(ffprobePath, logger),
+        logger:                    logger,
     }
 }

@@ -78,6 +80,7 @@ func NewEnhancedAnalyzerWithContentAnalysis(ffmpegPath string, ffprobePath strin
         pseAnalyzer:               NewPSEAnalyzer(ffprobePath, logger),
         streamDispositionAnalyzer: NewStreamDispositionAnalyzer(ffprobePath, logger),
         dataIntegrityAnalyzer:     NewDataIntegrityAnalyzer(ffprobePath, logger),
+        logger:                    logger,
     }
 }

@@ -158,7 +161,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     timecodeAnalysis, err := ea.timecodeAnalyzer.AnalyzeTimecode(ctx, filePath, result.Streams)
     if err != nil {
         // Log error but don't fail entire analysis - some files may not have timecode
-        fmt.Printf("Warning: timecode analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("timecode analysis failed")
     } else {
         result.EnhancedAnalysis.TimecodeAnalysis = timecodeAnalysis
     }

@@ -169,7 +172,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     afdAnalysis, err := ea.afdAnalyzer.AnalyzeAFD(ctx, filePath, result.Streams)
     if err != nil {
         // Log error but don't fail entire analysis - some files may not have AFD
-        fmt.Printf("Warning: AFD analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("AFD analysis failed")
     } else {
         result.EnhancedAnalysis.AFDAnalysis = afdAnalysis
     }

@@ -180,7 +183,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     transportAnalysis, err := ea.transportStreamAnalyzer.AnalyzeTransportStream(ctx, filePath, result.Streams, result.Format)
     if err != nil {
         // Log error but don't fail entire analysis - only applies to transport streams
-        fmt.Printf("Warning: transport stream analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("transport stream analysis failed")
     } else {
         result.EnhancedAnalysis.TransportStreamAnalysis = transportAnalysis
     }

@@ -191,7 +194,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     endiannessAnalysis, err := ea.endiannessAnalyzer.AnalyzeEndianness(ctx, filePath, result.Streams, result.Format)
     if err != nil {
         // Log error but don't fail entire analysis - endianness may not be detectable for all formats
-        fmt.Printf("Warning: endianness analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("endianness analysis failed")
     } else {
         result.EnhancedAnalysis.EndiannessAnalysis = endiannessAnalysis
     }

@@ -202,7 +205,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     audioWrappingAnalysis, err := ea.audioWrappingAnalyzer.AnalyzeAudioWrapping(ctx, filePath, result.Streams, result.Format)
     if err != nil {
         // Log error but don't fail entire analysis - not all formats have professional audio wrapping
-        fmt.Printf("Warning: audio wrapping analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("audio wrapping analysis failed")
     } else {
         result.EnhancedAnalysis.AudioWrappingAnalysis = audioWrappingAnalysis
     }

@@ -213,7 +216,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     imfAnalysis, err := ea.imfAnalyzer.AnalyzeIMF(ctx, filePath)
     if err != nil {
         // Log error but don't fail entire analysis - only applies to IMF packages
-        fmt.Printf("Warning: IMF analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("IMF analysis failed")
     } else {
         result.EnhancedAnalysis.IMFAnalysis = imfAnalysis
     }

@@ -224,7 +227,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     mxfAnalysis, err := ea.mxfAnalyzer.AnalyzeMXF(ctx, filePath)
     if err != nil {
         // Log error but don't fail entire analysis - only applies to MXF files
-        fmt.Printf("Warning: MXF analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("MXF analysis failed")
     } else {
         result.EnhancedAnalysis.MXFAnalysis = mxfAnalysis
     }

@@ -235,7 +238,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     deadPixelAnalysis, err := ea.deadPixelAnalyzer.AnalyzeDeadPixels(ctx, filePath)
     if err != nil {
         // Log error but don't fail entire analysis - analysis may fail on some video types
-        fmt.Printf("Warning: dead pixel analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("dead pixel analysis failed")
     } else {
         result.EnhancedAnalysis.DeadPixelAnalysis = deadPixelAnalysis
     }

@@ -246,7 +249,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     pseAnalysis, err := ea.pseAnalyzer.AnalyzePSERisk(ctx, filePath)
     if err != nil {
         // Log error but don't fail entire analysis - analysis may fail on some video types
-        fmt.Printf("Warning: PSE analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("PSE analysis failed")
     } else {
         result.EnhancedAnalysis.PSEAnalysis = pseAnalysis
     }

@@ -257,7 +260,7 @@ func (ea *EnhancedAnalyzer) AnalyzeResultWithAdvancedQC(ctx context.Context, res
     dispositionAnalysis, err := ea.streamDispositionAnalyzer.AnalyzeStreamDisposition(ctx, filePath, result.Streams)
     if err != nil {
         // Log error but don't fail entire analysis
-        fmt.Printf("Warning: stream disposition analysis failed: %v\n", err)
+        ea.logger.Warn().Err(err).Msg("stream disposition analysis failed")
     } else {
         result.EnhancedAnalysis.StreamDispositionAnalysis = dispositionAnalysis
     }

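With the logger field stored on the struct, the analyzer's warnings go through zerolog instead of fmt.Printf, so they carry a level, timestamp, and structured error field and land in the same sink as the rest of the service's logs. Below is a minimal standalone example of what a caller might construct and what the new Warn().Err().Msg() calls emit; the exact logger configuration in this project may differ.

package main

import (
    "errors"
    "os"

    "github.com/rs/zerolog"
)

func main() {
    // Build a logger the way a caller might before passing it to
    // NewEnhancedAnalyzer (the project's actual construction may differ).
    logger := zerolog.New(os.Stderr).With().Timestamp().Logger()

    // The analyzer now emits structured warnings instead of fmt.Printf, e.g.:
    err := errors.New("no timecode track found")
    logger.Warn().Err(err).Msg("timecode analysis failed")
    // Output (one JSON line on stderr), roughly:
    // {"level":"warn","time":"...","error":"no timecode track found","message":"timecode analysis failed"}
}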