     PoolingMode,
     rounded_row_size_in_bytes,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
     BaseBatchedEmbeddingBag,
@@ -237,13 +238,16 @@ def __init__(
         super().__init__(config, pg, device)
 
         managed: List[EmbeddingLocation] = []
+        is_virtual_table: bool = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._is_weighted: Optional[bool] = config.is_weighted
@@ -284,6 +288,8 @@ def __init__(
 
         if self.lengths_to_tbe:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        elif is_virtual_table:
+            tbe_clazz = KVEmbeddingInference
         else:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen
 
@@ -448,13 +454,16 @@ def __init__(
         super().__init__(config, pg, device)
 
         managed: List[EmbeddingLocation] = []
+        is_virtual_table = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -465,37 +474,40 @@ def __init__(
         )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=[
+        embedding_clazz = (
+            KVEmbeddingInference
+            if is_virtual_table
+            else IntNBitTableBatchedEmbeddingBagsCodegen
+        )
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            embedding_specs=[
+                (
+                    table.name,
+                    local_rows,
                     (
-                        table.name,
-                        local_rows,
-                        (
-                            local_cols
-                            if self._quant_state_dict_split_scale_bias
-                            else table.embedding_dim
-                        ),
-                        data_type_to_sparse_type(table.data_type),
-                        location,
-                    )
-                    for local_rows, local_cols, table, location in zip(
-                        self._local_rows,
-                        self._local_cols,
-                        config.embedding_tables,
-                        managed,
-                    )
-                ],
-                device=device,
-                pooling_mode=PoolingMode.NONE,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
-            )
+                        local_cols
+                        if self._quant_state_dict_split_scale_bias
+                        else table.embedding_dim
+                    ),
+                    data_type_to_sparse_type(table.data_type),
+                    location,
+                )
+                for local_rows, local_cols, table, location in zip(
+                    self._local_rows,
+                    self._local_cols,
+                    config.embedding_tables,
+                    managed,
+                )
+            ],
+            device=device,
+            pooling_mode=PoolingMode.NONE,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
         )
         if device is not None:
             self._emb_module.initialize_weights()
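
For context, a minimal sketch of the class-selection pattern this diff introduces in both kernels: if any table in the grouped config is a virtual table, the whole group is routed to `KVEmbeddingInference`, otherwise the standard `IntNBitTableBatchedEmbeddingBagsCodegen` is used. `TableConfig` and `select_tbe_class` below are hypothetical stand-ins for illustration only, not TorchRec or FBGEMM APIs.

```python
from dataclasses import dataclass
from typing import List, Type


@dataclass
class TableConfig:
    # Hypothetical stand-in for a TorchRec table config; only the one
    # field this diff branches on is modeled.
    use_virtual_table: bool = False


class IntNBitTableBatchedEmbeddingBagsCodegen:
    """Placeholder for the default quantized TBE inference module."""


class KVEmbeddingInference:
    """Placeholder for the KV-backed inference TBE module."""


def select_tbe_class(tables: List[TableConfig]) -> Type:
    # Mirrors the diff's logic: any virtual table in the group switches
    # the whole group to KVEmbeddingInference.
    is_virtual_table = any(t.use_virtual_table for t in tables)
    return KVEmbeddingInference if is_virtual_table else IntNBitTableBatchedEmbeddingBagsCodegen


print(select_tbe_class([TableConfig(), TableConfig(use_virtual_table=True)]).__name__)
# -> KVEmbeddingInference
```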