@@ -23,33 +23,10 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                  device: Union[torch.device, str], **kwargs):
         # Increasing max_num_batched_tokens by 3x to handle increase in
         # tensor size due to padding.
+        # TODO: Need to check if this override is still required
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)

-    def _update_base_metadata(
-        self,
-        mapping: "LoRAMapping",
-        lora_index_to_id: list[Optional[int]],
-        max_loras: int,
-        vocab_size: int,
-        extra_vocab_size: int,
-    ):
-        (
-            base_indices,
-            sampler_indices,
-            sampler_indices_padded,
-            embeddings_indices,
-            indices_len,
-        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
-                            extra_vocab_size, self.device)
-        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
-        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
-        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
-            sampler_indices_padded)
-        self._embeddings_indices[:embeddings_indices.
-                                 shape[0], :embeddings_indices.shape[1]].copy_(
-                                     embeddings_indices)
-        self.indices_len[:] = indices_len

     def add_lora_embedding(self,
                            y: torch.Tensor,
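Note on the removal (not part of the diff): with the `_update_base_metadata` override deleted, calls on instances of this wrapper fall through to the inherited `PunicaWrapperBase._update_base_metadata`, which presumably performs the same index bookkeeping. A minimal, self-contained sketch of that fallback mechanism, using generic placeholder classes rather than vLLM's actual ones:

```python
# Sketch of the inheritance fallback the removal relies on. `Base` and
# `Wrapper` are placeholders, not vLLM's real classes.
class Base:

    def _update_base_metadata(self, *args, **kwargs):
        # Stand-in for the base-class bookkeeping of LoRA index tensors.
        return "base-class implementation"


class Wrapper(Base):
    # No _update_base_metadata override here, mirroring the diff above:
    # attribute lookup falls through to the base class.
    pass


assert Wrapper()._update_base_metadata() == "base-class implementation"
```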