99 "github.com/redhat-et/MCU/mcv/pkg/cache"
1010)
1111
12- // CompareVLLMCacheManifestToGPU compares VLLM manifest entries to GPU info using Triton comparison logic
12+ // CompareVLLMCacheManifestToGPU compares VLLM manifest entries to GPU info
13+ // Handles both triton cache (legacy) and binary cache (new) formats
1314func CompareVLLMCacheManifestToGPU (manifestPath string , devInfo []devices.TritonGPUInfo ) error {
1415 data , err := os .ReadFile (manifestPath )
1516 if err != nil {
@@ -22,16 +23,78 @@ func CompareVLLMCacheManifestToGPU(manifestPath string, devInfo []devices.Triton
2223 }
2324
2425 for _ , entry := range manifest .VLLM {
25- convertedEntries := make ([]cache.TritonCacheMetadata , len (entry .TritonCacheEntries ))
26- for i , e := range entry .TritonCacheEntries {
27- if metadata , ok := e .(cache.TritonCacheMetadata ); ok {
28- convertedEntries [i ] = metadata
29- } else {
30- return fmt .Errorf ("failed to assert type cache.TritonCacheMetadata for entry: %v" , e )
26+ // Check if this is a binary cache format
27+ if entry .CacheFormat == "binary" && len (entry .BinaryCacheEntries ) > 0 {
28+ if err := compareBinaryCacheEntriesToGPU (entry .BinaryCacheEntries , devInfo ); err != nil {
29+ return err
3130 }
31+ } else if len (entry .TritonCacheEntries ) > 0 {
32+ // Handle triton cache format (legacy)
33+ convertedEntries := make ([]cache.TritonCacheMetadata , len (entry .TritonCacheEntries ))
34+ for i , e := range entry .TritonCacheEntries {
35+ if metadata , ok := e .(cache.TritonCacheMetadata ); ok {
36+ convertedEntries [i ] = metadata
37+ } else {
38+ return fmt .Errorf ("failed to assert type cache.TritonCacheMetadata for entry: %v" , e )
39+ }
40+ }
41+ if err := CompareTritonEntriesToGPU (convertedEntries , devInfo ); err != nil {
42+ return err
43+ }
44+ }
45+ }
46+
47+ return nil
48+ }
49+
50+ // compareBinaryCacheEntriesToGPU validates binary cache entries against GPU hardware
51+ func compareBinaryCacheEntriesToGPU (entries []cache.BinaryCacheMetadata , devInfo []devices.TritonGPUInfo ) error {
52+ for _ , entry := range entries {
53+ // Extract hardware info from the binary cache metadata
54+ backend := entry .TargetDevice
55+ if backend == "" {
56+ backend = "cuda" // Default if not specified
57+ }
58+
59+ // Determine arch and warpSize based on backend and env vars
60+ arch := "unknown"
61+ warpSize := 32 // Default for CUDA
62+
63+ switch backend {
64+ case "rocm" , "hip" :
65+ warpSize = 64 // AMD GPUs use 64-wide wavefronts
66+ // Try to extract GPU architecture from env
67+ if env , ok := entry .Env ["VLLM_ROCM_CUSTOM_PAGED_ATTN" ]; ok && env != nil {
68+ arch = "gfx90a" // Common MI250/MI300 arch, could be extracted more precisely
69+ }
70+ case "cuda" :
71+ // Try to extract CUDA architecture
72+ if mainVersion , ok := entry .Env ["VLLM_MAIN_CUDA_VERSION" ]; ok {
73+ if version , ok := mainVersion .(string ); ok {
74+ arch = "sm_" + version
75+ }
76+ }
77+ case "tpu" :
78+ warpSize = 128 // TPU uses different parallelism model
79+ case "cpu" :
80+ warpSize = 1 // CPU doesn't have warp concept
3281 }
33- if err := CompareTritonEntriesToGPU (convertedEntries , devInfo ); err != nil {
34- return err
82+
83+ // Check if any GPU matches this binary cache entry
84+ matched := false
85+ for _ , gpu := range devInfo {
86+ backendMatches := backend == gpu .Backend
87+ archMatches := arch == gpu .Arch
88+ warpMatches := warpSize == gpu .WarpSize
89+
90+ if backendMatches && archMatches && warpMatches {
91+ matched = true
92+ break
93+ }
94+ }
95+
96+ if ! matched {
97+ return fmt .Errorf ("binary cache entry (backend=%s, arch=%s, warpSize=%d) does not match any available GPU" , backend , arch , warpSize )
3598 }
3699 }
37100
0 commit comments