@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything
 
 
@@ -28,4 +29,25 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid
+        # this, we sync the param tensor after its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
+
+
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader
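
For context, here is a minimal, self-contained sketch of the narrow/copy idiom the NOTE describes, using plain CPU tensors. The `weight_loader` signature below is hypothetical, and the internal `torch._sync` call is omitted since its effect is only observable on lazily evaluated XLA/TPU tensors:

```python
import torch


# Hypothetical weight loader: copies loaded weights into a narrowed view
# of the parameter, relying on the view sharing storage with param.data.
def weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor,
                  offset: int) -> None:
    narrowed = param.data.narrow(0, offset, loaded_weight.shape[0])
    narrowed.copy_(loaded_weight)


param = torch.nn.Parameter(torch.zeros(4, 2), requires_grad=False)
weight_loader(param, torch.ones(2, 2), offset=2)

# On CPU/GPU the narrowed view aliases param.data, so the copy is
# immediately visible in the base tensor. On TPU/XLA the propagation is
# lazy, which is the redundancy the synced wrapper works around.
assert param.data[2:].eq(1).all()
```

Note that `_synced_weight_loader` forwards all positional and keyword arguments to the original loader before syncing, so it can wrap any `weight_loader` attribute without changing its call signature.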