|
| 1 | +package workloads |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "fmt" |
| 7 | + "os" |
| 8 | + "path/filepath" |
| 9 | + "strings" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/adrg/xdg" |
| 13 | + "github.com/gofrs/flock" |
| 14 | + |
| 15 | + "github.com/stacklok/toolhive/pkg/logger" |
| 16 | +) |
| 17 | + |
// Settings shared by the file-based status manager.
const (
	// statusesPrefix is the path element used for status files in the XDG
	// data directory.
	statusesPrefix = "statuses"
	// lockTimeout caps how long a caller waits to obtain a file lock.
	lockTimeout = time.Second
	// lockRetryInterval is the pause between successive lock attempts.
	lockRetryInterval = 100 * time.Millisecond
)
| 26 | + |
| 27 | +// NewFileStatusManager creates a new file-based StatusManager. |
| 28 | +// Status files will be stored in the XDG data directory under "statuses/". |
| 29 | +func NewFileStatusManager() StatusManager { |
| 30 | + // Get the base directory using XDG data directory |
| 31 | + baseDir, err := xdg.DataFile(statusesPrefix) |
| 32 | + if err != nil { |
| 33 | + // Fallback to a basic path if XDG fails |
| 34 | + baseDir = filepath.Join(os.TempDir(), "toolhive", statusesPrefix) |
| 35 | + } |
| 36 | + // Remove the filename part to get just the directory |
| 37 | + baseDir = filepath.Dir(baseDir) |
| 38 | + |
| 39 | + return &fileStatusManager{ |
| 40 | + baseDir: baseDir, |
| 41 | + } |
| 42 | +} |
| 43 | + |
// fileStatusManager is an implementation of StatusManager that persists
// workload status to files on disk with JSON serialization and file locking
// to prevent concurrent access issues.
type fileStatusManager struct {
	// baseDir is the directory that holds the per-workload "<name>.json"
	// status files and "<name>.lock" lock files.
	baseDir string
}
| 50 | + |
// workloadStatusFile represents the JSON structure stored on disk
// for a single workload.
type workloadStatusFile struct {
	// Status is the workload's current status value.
	Status WorkloadStatus `json:"status"`
	// StatusContext is an optional message accompanying the status; it is
	// omitted from the JSON when empty.
	StatusContext string `json:"status_context,omitempty"`
	// CreatedAt is set once, when the status file is first written.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is refreshed every time the status is changed.
	UpdatedAt time.Time `json:"updated_at"`
}
| 58 | + |
| 59 | +// CreateWorkloadStatus creates the initial `starting` status for a new workload. |
| 60 | +// It will return an error if the workload already exists. |
| 61 | +func (f *fileStatusManager) CreateWorkloadStatus(ctx context.Context, workloadName string) error { |
| 62 | + return f.withFileLock(ctx, workloadName, func(statusFilePath string) error { |
| 63 | + // Check if file already exists |
| 64 | + if _, err := os.Stat(statusFilePath); err == nil { |
| 65 | + return fmt.Errorf("workload %s already exists", workloadName) |
| 66 | + } else if !os.IsNotExist(err) { |
| 67 | + return fmt.Errorf("failed to check if workload %s exists: %w", workloadName, err) |
| 68 | + } |
| 69 | + |
| 70 | + // Create initial status |
| 71 | + now := time.Now() |
| 72 | + statusFile := workloadStatusFile{ |
| 73 | + Status: WorkloadStatusStarting, |
| 74 | + StatusContext: "", |
| 75 | + CreatedAt: now, |
| 76 | + UpdatedAt: now, |
| 77 | + } |
| 78 | + |
| 79 | + if err := f.writeStatusFile(statusFilePath, statusFile); err != nil { |
| 80 | + return fmt.Errorf("failed to write status file for workload %s: %w", workloadName, err) |
| 81 | + } |
| 82 | + |
| 83 | + logger.Debugf("workload %s created with starting status", workloadName) |
| 84 | + return nil |
| 85 | + }) |
| 86 | +} |
| 87 | + |
| 88 | +// GetWorkloadStatus retrieves the status of a workload by its name. |
| 89 | +func (f *fileStatusManager) GetWorkloadStatus(ctx context.Context, workloadName string) (*WorkloadStatus, string, error) { |
| 90 | + var result *WorkloadStatus |
| 91 | + var statusContext string |
| 92 | + |
| 93 | + err := f.withFileReadLock(ctx, workloadName, func(statusFilePath string) error { |
| 94 | + // Check if file exists |
| 95 | + if _, err := os.Stat(statusFilePath); os.IsNotExist(err) { |
| 96 | + return fmt.Errorf("workload %s not found", workloadName) |
| 97 | + } else if err != nil { |
| 98 | + return fmt.Errorf("failed to check status file for workload %s: %w", workloadName, err) |
| 99 | + } |
| 100 | + |
| 101 | + statusFile, err := f.readStatusFile(statusFilePath) |
| 102 | + if err != nil { |
| 103 | + return fmt.Errorf("failed to read status for workload %s: %w", workloadName, err) |
| 104 | + } |
| 105 | + |
| 106 | + result = &statusFile.Status |
| 107 | + statusContext = statusFile.StatusContext |
| 108 | + return nil |
| 109 | + }) |
| 110 | + |
| 111 | + return result, statusContext, err |
| 112 | +} |
| 113 | + |
| 114 | +// SetWorkloadStatus sets the status of a workload by its name. |
| 115 | +// This method will do nothing if the workload does not exist, following the interface contract. |
| 116 | +func (f *fileStatusManager) SetWorkloadStatus( |
| 117 | + ctx context.Context, workloadName string, status WorkloadStatus, contextMsg string, |
| 118 | +) { |
| 119 | + err := f.withFileLock(ctx, workloadName, func(statusFilePath string) error { |
| 120 | + // Check if file exists |
| 121 | + if _, err := os.Stat(statusFilePath); os.IsNotExist(err) { |
| 122 | + // File doesn't exist, do nothing as per interface contract |
| 123 | + logger.Debugf("workload %s does not exist, skipping status update", workloadName) |
| 124 | + return nil |
| 125 | + } else if err != nil { |
| 126 | + return fmt.Errorf("failed to check status file for workload %s: %w", workloadName, err) |
| 127 | + } |
| 128 | + |
| 129 | + // Read existing file to preserve created_at timestamp |
| 130 | + statusFile, err := f.readStatusFile(statusFilePath) |
| 131 | + if err != nil { |
| 132 | + return fmt.Errorf("failed to read existing status for workload %s: %w", workloadName, err) |
| 133 | + } |
| 134 | + |
| 135 | + // Update status and context |
| 136 | + statusFile.Status = status |
| 137 | + statusFile.StatusContext = contextMsg |
| 138 | + statusFile.UpdatedAt = time.Now() |
| 139 | + |
| 140 | + if err := f.writeStatusFile(statusFilePath, *statusFile); err != nil { |
| 141 | + return fmt.Errorf("failed to write updated status for workload %s: %w", workloadName, err) |
| 142 | + } |
| 143 | + |
| 144 | + logger.Debugf("workload %s set to status %s (context: %s)", workloadName, status, contextMsg) |
| 145 | + return nil |
| 146 | + }) |
| 147 | + |
| 148 | + if err != nil { |
| 149 | + logger.Errorf("error updating workload %s status: %v", workloadName, err) |
| 150 | + } |
| 151 | +} |
| 152 | + |
| 153 | +// DeleteWorkloadStatus removes the status of a workload by its name. |
| 154 | +func (f *fileStatusManager) DeleteWorkloadStatus(ctx context.Context, workloadName string) error { |
| 155 | + return f.withFileLock(ctx, workloadName, func(statusFilePath string) error { |
| 156 | + // Remove status file |
| 157 | + if err := os.Remove(statusFilePath); err != nil && !os.IsNotExist(err) { |
| 158 | + return fmt.Errorf("failed to delete status file for workload %s: %w", workloadName, err) |
| 159 | + } |
| 160 | + |
| 161 | + // Remove lock file (best effort) - done by withFileLock after this function returns |
| 162 | + logger.Debugf("workload %s status deleted", workloadName) |
| 163 | + return nil |
| 164 | + }) |
| 165 | +} |
| 166 | + |
| 167 | +// getStatusFilePath returns the file path for a given workload's status file. |
| 168 | +func (f *fileStatusManager) getStatusFilePath(workloadName string) string { |
| 169 | + return filepath.Join(f.baseDir, fmt.Sprintf("%s.json", workloadName)) |
| 170 | +} |
| 171 | + |
| 172 | +// getLockFilePath returns the lock file path for a given workload. |
| 173 | +func (f *fileStatusManager) getLockFilePath(workloadName string) string { |
| 174 | + return filepath.Join(f.baseDir, fmt.Sprintf("%s.lock", workloadName)) |
| 175 | +} |
| 176 | + |
| 177 | +// ensureBaseDir creates the base directory if it doesn't exist. |
| 178 | +func (f *fileStatusManager) ensureBaseDir() error { |
| 179 | + return os.MkdirAll(f.baseDir, 0750) |
| 180 | +} |
| 181 | + |
| 182 | +// withFileLock executes the provided function while holding a write lock on the workload's lock file. |
| 183 | +func (f *fileStatusManager) withFileLock(ctx context.Context, workloadName string, fn func(string) error) error { |
| 184 | + // Validate workload name |
| 185 | + if strings.Contains(workloadName, "..") || strings.ContainsAny(workloadName, "/\\") { |
| 186 | + return fmt.Errorf("invalid workload name '%s': contains forbidden characters", workloadName) |
| 187 | + } |
| 188 | + if err := f.ensureBaseDir(); err != nil { |
| 189 | + return fmt.Errorf("failed to create base directory: %w", err) |
| 190 | + } |
| 191 | + |
| 192 | + statusFilePath := f.getStatusFilePath(workloadName) |
| 193 | + lockFilePath := f.getLockFilePath(workloadName) |
| 194 | + |
| 195 | + // Create file lock |
| 196 | + fileLock := flock.New(lockFilePath) |
| 197 | + defer func() { |
| 198 | + if err := fileLock.Unlock(); err != nil { |
| 199 | + logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) |
| 200 | + } |
| 201 | + // Attempt to remove lock file (best effort) |
| 202 | + if err := os.Remove(lockFilePath); err != nil && !os.IsNotExist(err) { |
| 203 | + logger.Warnf("failed to remove lock file for workload %s: %v", workloadName, err) |
| 204 | + } |
| 205 | + }() |
| 206 | + |
| 207 | + // Create context with timeout |
| 208 | + lockCtx, cancel := context.WithTimeout(ctx, lockTimeout) |
| 209 | + defer cancel() |
| 210 | + |
| 211 | + // Acquire lock with context |
| 212 | + locked, err := fileLock.TryLockContext(lockCtx, lockRetryInterval) |
| 213 | + if err != nil { |
| 214 | + return fmt.Errorf("failed to acquire lock for workload %s: %w", workloadName, err) |
| 215 | + } |
| 216 | + if !locked { |
| 217 | + return fmt.Errorf("could not acquire lock for workload %s: timeout after %v", workloadName, lockTimeout) |
| 218 | + } |
| 219 | + |
| 220 | + return fn(statusFilePath) |
| 221 | +} |
| 222 | + |
| 223 | +// withFileReadLock executes the provided function while holding a read lock on the workload's lock file. |
| 224 | +func (f *fileStatusManager) withFileReadLock(ctx context.Context, workloadName string, fn func(string) error) error { |
| 225 | + statusFilePath := f.getStatusFilePath(workloadName) |
| 226 | + lockFilePath := f.getLockFilePath(workloadName) |
| 227 | + |
| 228 | + // Create file lock |
| 229 | + fileLock := flock.New(lockFilePath) |
| 230 | + defer func() { |
| 231 | + if err := fileLock.Unlock(); err != nil { |
| 232 | + logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) |
| 233 | + } |
| 234 | + }() |
| 235 | + |
| 236 | + // Create context with timeout |
| 237 | + lockCtx, cancel := context.WithTimeout(ctx, lockTimeout) |
| 238 | + defer cancel() |
| 239 | + |
| 240 | + // Acquire read lock with context |
| 241 | + locked, err := fileLock.TryRLockContext(lockCtx, lockRetryInterval) |
| 242 | + if err != nil { |
| 243 | + return fmt.Errorf("failed to acquire read lock for workload %s: %w", workloadName, err) |
| 244 | + } |
| 245 | + if !locked { |
| 246 | + return fmt.Errorf("could not acquire read lock for workload %s: timeout after %v", workloadName, lockTimeout) |
| 247 | + } |
| 248 | + |
| 249 | + return fn(statusFilePath) |
| 250 | +} |
| 251 | + |
| 252 | +// readStatusFile reads and parses a workload status file from disk. |
| 253 | +func (*fileStatusManager) readStatusFile(statusFilePath string) (*workloadStatusFile, error) { |
| 254 | + data, err := os.ReadFile(statusFilePath) //nolint:gosec // file path is constructed by our own function |
| 255 | + if err != nil { |
| 256 | + return nil, fmt.Errorf("failed to read status file: %w", err) |
| 257 | + } |
| 258 | + |
| 259 | + var statusFile workloadStatusFile |
| 260 | + if err := json.Unmarshal(data, &statusFile); err != nil { |
| 261 | + return nil, fmt.Errorf("failed to unmarshal status file: %w", err) |
| 262 | + } |
| 263 | + |
| 264 | + return &statusFile, nil |
| 265 | +} |
| 266 | + |
| 267 | +// writeStatusFile writes a workload status file to disk with proper formatting. |
| 268 | +func (*fileStatusManager) writeStatusFile(statusFilePath string, statusFile workloadStatusFile) error { |
| 269 | + data, err := json.MarshalIndent(statusFile, "", " ") |
| 270 | + if err != nil { |
| 271 | + return fmt.Errorf("failed to marshal status file: %w", err) |
| 272 | + } |
| 273 | + |
| 274 | + if err := os.WriteFile(statusFilePath, data, 0600); err != nil { |
| 275 | + return fmt.Errorf("failed to write status file: %w", err) |
| 276 | + } |
| 277 | + |
| 278 | + return nil |
| 279 | +} |
0 commit comments