Skip to content

Commit 5c88000

Browse files
committed
mcv: binary cache extraction
Signed-off-by: Maryam Tahhan <mtahhan@redhat.com>
1 parent c27558c commit 5c88000

File tree

1 file changed

+72
-9
lines changed

1 file changed

+72
-9
lines changed

mcv/pkg/preflightcheck/vllm.go

Lines changed: 72 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@ import (
99
"github.com/redhat-et/MCU/mcv/pkg/cache"
1010
)
1111

12-
// CompareVLLMCacheManifestToGPU compares VLLM manifest entries to GPU info using Triton comparison logic
12+
// CompareVLLMCacheManifestToGPU compares VLLM manifest entries to GPU info
13+
// It handles both the legacy Triton cache format and the newer binary cache format.
1314
func CompareVLLMCacheManifestToGPU(manifestPath string, devInfo []devices.TritonGPUInfo) error {
1415
data, err := os.ReadFile(manifestPath)
1516
if err != nil {
@@ -22,16 +23,78 @@ func CompareVLLMCacheManifestToGPU(manifestPath string, devInfo []devices.Triton
2223
}
2324

2425
for _, entry := range manifest.VLLM {
25-
convertedEntries := make([]cache.TritonCacheMetadata, len(entry.TritonCacheEntries))
26-
for i, e := range entry.TritonCacheEntries {
27-
if metadata, ok := e.(cache.TritonCacheMetadata); ok {
28-
convertedEntries[i] = metadata
29-
} else {
30-
return fmt.Errorf("failed to assert type cache.TritonCacheMetadata for entry: %v", e)
26+
// Check if this is a binary cache format
27+
if entry.CacheFormat == "binary" && len(entry.BinaryCacheEntries) > 0 {
28+
if err := compareBinaryCacheEntriesToGPU(entry.BinaryCacheEntries, devInfo); err != nil {
29+
return err
3130
}
31+
} else if len(entry.TritonCacheEntries) > 0 {
32+
// Handle triton cache format (legacy)
33+
convertedEntries := make([]cache.TritonCacheMetadata, len(entry.TritonCacheEntries))
34+
for i, e := range entry.TritonCacheEntries {
35+
if metadata, ok := e.(cache.TritonCacheMetadata); ok {
36+
convertedEntries[i] = metadata
37+
} else {
38+
return fmt.Errorf("failed to assert type cache.TritonCacheMetadata for entry: %v", e)
39+
}
40+
}
41+
if err := CompareTritonEntriesToGPU(convertedEntries, devInfo); err != nil {
42+
return err
43+
}
44+
}
45+
}
46+
47+
return nil
48+
}
49+
50+
// compareBinaryCacheEntriesToGPU validates binary cache entries against GPU hardware
51+
func compareBinaryCacheEntriesToGPU(entries []cache.BinaryCacheMetadata, devInfo []devices.TritonGPUInfo) error {
52+
for _, entry := range entries {
53+
// Extract hardware info from the binary cache metadata
54+
backend := entry.TargetDevice
55+
if backend == "" {
56+
backend = "cuda" // Default if not specified
57+
}
58+
59+
// Determine arch and warpSize based on backend and env vars
60+
arch := "unknown"
61+
warpSize := 32 // Default for CUDA
62+
63+
switch backend {
64+
case "rocm", "hip":
65+
warpSize = 64 // AMD GPUs use 64-wide wavefronts
66+
// Try to extract GPU architecture from env
67+
if env, ok := entry.Env["VLLM_ROCM_CUSTOM_PAGED_ATTN"]; ok && env != nil {
68+
arch = "gfx90a" // gfx90a is the MI200-series (MI210/MI250) arch; MI300 parts report gfx942. TODO: derive the real arch instead of hard-coding it
69+
}
70+
case "cuda":
71+
// Try to extract CUDA architecture
72+
if mainVersion, ok := entry.Env["VLLM_MAIN_CUDA_VERSION"]; ok {
73+
if version, ok := mainVersion.(string); ok {
74+
arch = "sm_" + version
75+
}
76+
}
77+
case "tpu":
78+
warpSize = 128 // TPU uses different parallelism model
79+
case "cpu":
80+
warpSize = 1 // CPU doesn't have warp concept
3281
}
33-
if err := CompareTritonEntriesToGPU(convertedEntries, devInfo); err != nil {
34-
return err
82+
83+
// Check if any GPU matches this binary cache entry
84+
matched := false
85+
for _, gpu := range devInfo {
86+
backendMatches := backend == gpu.Backend
87+
archMatches := arch == gpu.Arch
88+
warpMatches := warpSize == gpu.WarpSize
89+
90+
if backendMatches && archMatches && warpMatches {
91+
matched = true
92+
break
93+
}
94+
}
95+
96+
if !matched {
97+
return fmt.Errorf("binary cache entry (backend=%s, arch=%s, warpSize=%d) does not match any available GPU", backend, arch, warpSize)
3598
}
3699
}
37100

0 commit comments

Comments
 (0)