
Commit 85ae1e4
Add boilerplate for sharing MM/deepstack embeddings across TP model
1 parent c965413

4 files changed: +110, -9 lines changed

exllamav3/model/model_tp.py
Lines changed: 6 additions & 2 deletions

@@ -12,6 +12,7 @@
 from .model_tp_fn import *
 import uuid
 from ..util import log_tp, global_t0
+from ..tokenizer.mm_embedding import send_embeddings

 cleanupper = Cleanupper()

@@ -341,7 +342,6 @@ def unload_tp(self):
         cleanupper.unregister_atexit(self.destroy_tp_context)


-
     def prepare_inputs_for_tp(self, x: torch.Tensor, params: dict) -> torch.Tensor:
         self.tp_producer.clear()
         # Use ID of Cache object as reference to avoid having to pickle it
@@ -353,11 +353,15 @@ def prepare_inputs_for_tp(self, x: torch.Tensor, params: dict) -> torch.Tensor:
             "cache_seqlens",
             "positions",
             "position_ids",
-            # "indexed_embeddings",
         ]:
             p = params.get(tensor_param)
             if p is not None:
                 params[tensor_param] = self.tp_producer.send(p)
+
+        p = params.get("indexed_embeddings")
+        if p is not None:
+            params["indexed_embeddings"] = send_embeddings(self.tp_producer, p)
+
         return self.tp_producer.send(x)

exllamav3/model/model_tp_fn.py
Lines changed: 5 additions & 1 deletion

@@ -4,6 +4,7 @@
 from ..ext import exllamav3_ext as ext
 from functools import lru_cache
 from .model_tp_backend import TPBackendNCCL, TPBackendNative
+from ..tokenizer.mm_embedding import recv_embeddings
 from ..util import log_tp, set_t0

 def init_pg(device: int, active_devices: list[int], output_device: int, backend_args: dict, master: bool = False):
@@ -191,12 +192,15 @@ def mp_model_forward(
         "cache_seqlens",
         "positions",
         "position_ids",
-        # "indexed_embeddings",
     ]:
         p = params.get(tensor_param)
         if p is not None:
             params[tensor_param] = consumer.recv(p, cuda = True)

+    p = params.get("indexed_embeddings")
+    if p is not None:
+        params["indexed_embeddings"] = recv_embeddings(consumer, p)
+
     params["backend"] = backend

     x = consumer.recv(shared_input)
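Taken together with the producer-side hunk in model_tp.py above, this forms a symmetric pack/unpack step around the process boundary. The sketch below is toy code with stand-in names, not exllamav3's classes: it only illustrates the pattern of walking the same parameter keys on both sides, with "indexed_embeddings" routed through a dedicated helper because it is a list of objects rather than a single tensor.

# Toy illustration (stand-in names, not exllamav3 code) of the pack/unpack symmetry.
import torch

TENSOR_KEYS = ["cache_seqlens", "positions", "position_ids"]

def pack_params(params, send_tensor, send_embeddings):
    # Producer side: replace payloads with small export descriptors
    for key in TENSOR_KEYS:
        if params.get(key) is not None:
            params[key] = send_tensor(params[key])
    if params.get("indexed_embeddings") is not None:
        params["indexed_embeddings"] = send_embeddings(params["indexed_embeddings"])
    return params

def unpack_params(params, recv_tensor, recv_embeddings):
    # Worker side: walk the same keys and swap descriptors back for payloads
    for key in TENSOR_KEYS:
        if params.get(key) is not None:
            params[key] = recv_tensor(params[key])
    if params.get("indexed_embeddings") is not None:
        params["indexed_embeddings"] = recv_embeddings(params["indexed_embeddings"])
    return params

# Toy transport: a "descriptor" here just wraps the payload in a dict
packed = pack_params(
    {"positions": torch.arange(4), "indexed_embeddings": ["mm_emb_0", "mm_emb_1"]},
    send_tensor = lambda t: {"payload": t},
    send_embeddings = lambda es: {"payload": es},
)
restored = unpack_params(
    packed,
    recv_tensor = lambda d: d["payload"],
    recv_embeddings = lambda d: d["payload"],
)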

exllamav3/model/model_tp_shared.py
Lines changed: 42 additions & 2 deletions

@@ -6,6 +6,7 @@
 from .model_tp_cuda import cuda_host_register, cuda_host_unregister, CUDA_HOST_REGISTER_PORTABLE

 DEFAULT_BUFFER_SIZE = 2 * 1024 ** 3
+MAX_CACHE_PER_PROCESS = 4 * 1024**3

 _torch_dtypes = {
     "torch.uint8": torch.uint8,
@@ -37,13 +38,17 @@ def __init__(
         # Pre-touch buffer to avoid page faults later
         self.buf[: self.buffer_size: 4096] = 0

+        # Cache
+        self.cached_cpu_tensors = {}
+        self.cache_size = 0
+
     def export(self):
         return {
             "shm_name": self.shm_name,
             "buffer_size": self.buffer_size,
         }

-    def send(self, tensor: torch.Tensor | None) -> dict:
+    def send(self, tensor: torch.Tensor | None, cache_id: int = None) -> dict:

         # None tensor
         if tensor is None:
@@ -74,13 +79,30 @@ def send(self, tensor: torch.Tensor | None) -> dict:
         dst = np.ndarray((nbytes,), dtype = np.uint8, buffer = self.shm.buf, offset = offset)
         np.copyto(dst, src, casting = "no")

+        # Cache
+        if nbytes > MAX_CACHE_PER_PROCESS:
+            cache_id = None
+
+        if cache_id is not None:
+            if cache_id in self.cached_cpu_tensors:
+                # print("sending cache ref:", cache_id)
+                return {
+                    "method": "cached",
+                    "cache_id": cache_id,
+                }
+            while self.cache_size + nbytes > MAX_CACHE_PER_PROCESS:
+                self.cached_cpu_tensors.pop(next(iter(self.cached_cpu_tensors)))
+            self.cached_cpu_tensors[cache_id] = tensor
+            # print("caching send:", cache_id)
+
         # Data is now buffered in shared memory space, store metadata and offset
         return {
             "method": "buffer",
             "offset": offset,
             "nbytes": nbytes,
             "dtype": str(tensor.dtype),
             "shape": tuple(tensor.shape),
+            "cache_id": cache_id
         }

     def clear(self):
@@ -142,6 +164,11 @@ def get_local_tensor(shm_buf, _buffer_size):
         cuda_host_register(self.arena.data_ptr(), self.arena.numel(), flags = CUDA_HOST_REGISTER_PORTABLE)
         self.producer.buf_is_pinned = True

+        # Cache
+        self.cached_cpu_tensors = {}
+        self.cache_size = 0
+
+
     def recv(
         self,
         imp: dict,
@@ -158,6 +185,13 @@ def recv(
         if imp["method"] == "none_tensor":
             return None

+        # Send was cached
+        cache_id = imp["cache_id"]
+        if imp["method"] == "cached":
+            # print("receiving cached:", cache_id)
+            assert not cuda, "Cannot share cached tensor for CUDA"
+            return self.cached_cpu_tensors[imp["cache_id"]]
+
         # Fallback method
         if imp["method"] == "share_memory":
             tensor = imp["shared_tensor"]
@@ -169,6 +203,12 @@
             dtype = _torch_dtypes[imp["dtype"]]
             shape = imp["shape"]
             tensor = self.arena.narrow(0, offset, nbytes).view(dtype).view(shape)
+            if cache_id is not None:
+                # print("caching recv:", cache_id)
+                assert not cuda, "Cannot share cached tensor for CUDA"
+                while self.cache_size + nbytes > MAX_CACHE_PER_PROCESS:
+                    self.cached_cpu_tensors.pop(next(iter(self.cached_cpu_tensors)))
+                self.cached_cpu_tensors[cache_id] = tensor.clone(memory_format = torch.contiguous_format)

         # Slice before cloning
         if slice_dim is not None:
@@ -182,7 +222,7 @@
                 copy = True,
                 memory_format = torch.contiguous_format
             )
-        else:
+        elif imp["method"] != "share_memory" or not tensor.is_contiguous():
             tensor = tensor.clone(memory_format = torch.contiguous_format)

         return tensor
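The cache_id plumbing above is what lets the same embedding tensors be shipped once and then referenced on later forward passes. Below is a self-contained toy sketch of that protocol, not the commit's producer/consumer classes: the first send of a tensor under a given cache_id ships the data and caches it on both ends, later sends of the same cache_id ship only a {"method": "cached"} reference, and a plain insertion-ordered dict serves as a size-capped FIFO cache on each side.

# Toy sketch of the cache_id protocol (stand-in classes, illustrative cap)
import torch

MAX_CACHE = 1024  # toy cap in bytes; the commit uses MAX_CACHE_PER_PROCESS = 4 GiB

class ToyProducer:
    def __init__(self):
        self.cache, self.cache_size = {}, 0
    def send(self, tensor, cache_id = None):
        nbytes = tensor.numel() * tensor.element_size()
        if nbytes > MAX_CACHE:
            cache_id = None                    # too big to cache, always ship in full
        if cache_id is not None and cache_id in self.cache:
            return {"method": "cached", "cache_id": cache_id}
        if cache_id is not None:
            while self.cache_size + nbytes > MAX_CACHE:
                old = self.cache.pop(next(iter(self.cache)))   # evict oldest entry
                self.cache_size -= old.numel() * old.element_size()
            self.cache[cache_id] = tensor
            self.cache_size += nbytes
        # the real producer copies the bytes into shared memory here
        return {"method": "buffer", "tensor": tensor, "cache_id": cache_id}

class ToyConsumer:
    def __init__(self):
        self.cache = {}
    def recv(self, imp):
        if imp["method"] == "cached":
            return self.cache[imp["cache_id"]]   # no tensor data crossed the boundary
        tensor = imp["tensor"]
        if imp["cache_id"] is not None:
            self.cache[imp["cache_id"]] = tensor.clone()
        return tensor

producer, consumer = ToyProducer(), ToyConsumer()
emb = torch.randn(4, 16)
first  = consumer.recv(producer.send(emb, cache_id = id(emb)))   # data shipped and cached
second = consumer.recv(producer.send(emb, cache_id = id(emb)))   # only a reference shipped
assert torch.equal(first, second)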

exllamav3/tokenizer/mm_embedding.py
Lines changed: 57 additions & 4 deletions

@@ -28,12 +28,13 @@ class MMEmbedding:

     def __init__(
         self,
-        embeddings: torch.Tensor,
-        token_string: torch.Tensor,
+        embeddings: torch.Tensor | None = None,
+        token_string: torch.Tensor | None = None,
         text_alias: str | None = None,
         deepstack_embeddings: list[torch.Tensor] | None = None,
         grid_thw: tuple | None = None,
-        mrope_merge_size: int | None = None
+        mrope_merge_size: int | None = None,
+        imp: dict | None = None
     ):
         """
         :param embeddings:
@@ -46,6 +47,21 @@ def __init__(
             Text string to represent this embedding for tokenizing
         """

+        if imp:
+            self.metadata = imp["metadata"]
+            self.full_length = imp["full_length"]
+            self.mm_length = imp["mm_length"]
+            self.first_index = imp["first_index"]
+            self.last_index = imp["last_index"]
+            self.text_alias = imp["text_alias"]
+            self.grid_thw = imp["grid_thw"]
+            self.mrope_merge_size = imp["mrope_merge_size"]
+            self.embeddings = imp["embeddings"]
+            self.deepstack_embeddings = imp["deepstack_embeddings"]
+            self.token_string = None
+            self.token_list = None
+            return
+
         global global_allocator

         if deepstack_embeddings is not None:
@@ -65,8 +81,45 @@ def __init__(
         self.grid_thw = grid_thw
         self.mrope_merge_size = mrope_merge_size

+        # not exported for TP
         r = torch.arange(self.first_index, self.first_index + self.mm_length, dtype = torch.long)
         m = (token_string == -1)
         token_string.masked_scatter_(m, r)
         self.token_string = token_string
-        self.token_list = token_string[0].tolist()
+        self.token_list = token_string[0].tolist()
+
+
+def send_embeddings(producer, ies: list[MMEmbedding]):
+    return {
+        "method": "list",
+        "data": [
+            {
+                "metadata": ie.metadata,
+                "full_length": ie.full_length,
+                "mm_length": ie.mm_length,
+                "first_index": ie.first_index,
+                "last_index": ie.last_index,
+                "text_alias": ie.text_alias,
+                "grid_thw": ie.grid_thw,
+                "mrope_merge_size": ie.mrope_merge_size,
+                "embeddings": producer.send(ie.embeddings, cache_id = id(ie.embeddings)),
+                "deepstack_embeddings": [
+                    producer.send(dse, cache_id = id(dse))
+                    for dse in ie.deepstack_embeddings
+                ] if ie.deepstack_embeddings is not None else None
+            }
+            for ie in ies
+        ]
+    }
+
+
+def recv_embeddings(consumer, recv) -> list[MMEmbedding]:
+    result = []
+    assert recv.get("method") == "list", "Consumer expected list"
+    for imp in recv["data"]:
+        imp["embeddings"] = consumer.recv(imp["embeddings"])
+        imp["deepstack_embeddings"] = [
+            consumer.recv(dse) for dse in imp["deepstack_embeddings"]
+        ] if imp.get("deepstack_embeddings") else None
+        result.append(MMEmbedding(imp = imp))
+    return result
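For reference, below is the rough shape of what send_embeddings() hands across the process boundary for a single MMEmbedding. Every concrete value is a made-up placeholder (indices, grid size, cache id are illustrative only); the point is that each tensor field is replaced by a producer.send() descriptor so the whole structure stays picklable. On the worker, recv_embeddings() swaps the descriptors back for CPU tensors and rebuilds each object via the new MMEmbedding(imp = ...) path, skipping the token_string/allocator logic that only the producer process needs.

# Illustrative only: all values are placeholders, not real output
exported = {
    "method": "list",
    "data": [{
        "metadata": {},                       # passed through as-is
        "full_length": 144,
        "mm_length": 144,
        "first_index": 1000000000,            # placeholder token-index range
        "last_index": 1000000144,
        "text_alias": None,
        "grid_thw": (1, 24, 24),
        "mrope_merge_size": 2,
        # producer.send() descriptors in place of the tensors themselves, e.g. a
        # "buffer" export on the first send or a "cached" reference on repeats:
        "embeddings": {"method": "cached", "cache_id": 139812345678},
        "deepstack_embeddings": None,         # or one descriptor per deepstack layer
    }],
}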
