|
23 | 23 | from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
24 | 24 | from collections import UserDict, defaultdict
|
25 | 25 | from collections.abc import Iterable, Mapping
|
| 26 | +from dataclasses import dataclass, field |
26 | 27 | from functools import lru_cache, partial, wraps
|
27 | 28 | from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
28 |
| - Dict, Generic, Hashable, List, Literal, Optional, |
29 |
| - OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) |
| 29 | + Dict, Generator, Generic, Hashable, List, Literal, |
| 30 | + Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union, |
| 31 | + overload) |
30 | 32 | from uuid import uuid4
|
31 | 33 |
|
32 | 34 | import numpy as np
|
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
|
1664 | 1666 | # Finally kill the parent
|
1665 | 1667 | with contextlib.suppress(ProcessLookupError):
|
1666 | 1668 | os.kill(pid, signal.SIGKILL)
|
| 1669 | + |
| 1670 | + |
@dataclass
class MemorySnapshot:
    """Point-in-time record of torch's CUDA memory usage.

    ``measure()`` samples the CUDA caching allocator and the wall clock;
    two snapshots can be subtracted (``a - b``) to obtain the deltas
    between measurements.
    """
    torch_peak_in_bytes: int = 0  # peak bytes ever allocated by torch
    torch_memory_in_bytes: int = 0  # bytes currently allocated by torch
    timestamp: float = 0.0  # wall-clock time of the measurement

    def measure(self):
        """Sample the current CUDA allocator statistics into this snapshot."""
        stats = torch.cuda.memory_stats()
        self.torch_peak_in_bytes = stats["allocated_bytes.all.peak"]
        self.torch_memory_in_bytes = stats["allocated_bytes.all.current"]
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        """Support ``a - b``: component-wise difference of two snapshots."""
        peak_delta = self.torch_peak_in_bytes - other.torch_peak_in_bytes
        current_delta = self.torch_memory_in_bytes - other.torch_memory_in_bytes
        elapsed = self.timestamp - other.timestamp
        return MemorySnapshot(torch_peak_in_bytes=peak_delta,
                              torch_memory_in_bytes=current_delta,
                              timestamp=elapsed)
| 1693 | + |
| 1694 | + |
@dataclass
class MemoryProfilingResult:
    """Memory profiling result.

    Aggregates the numbers produced by the ``memory_profiling`` context
    manager: the byte counts attributed to torch / non-torch usage, the
    raw before/after snapshots, and the elapsed profiling time.
    """ # noqa
    # Memory used by everything other than this instance (measured upfront).
    baseline_memory_in_bytes: int = 0
    # Total non-KV-cache memory: weights + torch peak increase + non-torch.
    non_kv_cache_memory_in_bytes: int = 0
    # Growth of torch's peak allocated bytes over the profiled region.
    torch_peak_increase_in_bytes: int = 0
    # GPU memory growth not tracked by the torch allocator (e.g. NCCL).
    non_torch_increase_in_bytes: int = 0
    # Memory taken by the model weights, as passed in by the caller.
    # NOTE(review): annotated float while sibling fields are int — presumably
    # intentional for downstream arithmetic; confirm before changing.
    weights_memory_in_bytes: float = 0
    # Allocator snapshots taken just before / just after the profiled region.
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    # Wall-clock seconds spent inside the profiled region.
    profile_time: float = 0.0
| 1707 | + |
| 1708 | + |
@contextlib.contextmanager
def memory_profiling(
    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.

    baseline_memory_in_bytes: memory used by all the components other than
    the current vLLM instance. It contains: memory used by other processes, memory
    used by another vLLM instance in the same process, etc. It is usually measured
    before the current vLLM instance initializes the device. And we assume it is
    constant during the profiling of the current vLLM instance.
    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
    Note that, before loading the model weights, we also initialize the device
    and distributed environment, which may consume some memory. This part is not
    included in the weights_memory_in_bytes because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.

    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]`` after profiling gives (b.).

    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
    """ # noqa
    # Reset the peak counter so the measured peak reflects only the
    # profiled region, not whatever happened before (e.g. weight loading).
    torch.cuda.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.baseline_memory_in_bytes = baseline_memory_in_bytes
    # the part of memory used for holding the model weights
    result.weights_memory_in_bytes = weights_memory_in_bytes

    result.before_profile.measure()

    yield result

    # Reclaim freed activation tensors before taking the "after" snapshot,
    # so only persistent allocations are counted in the post-profile state.
    gc.collect()
    torch.cuda.empty_cache()

    result.after_profile.measure()

    diff = result.after_profile - result.before_profile
    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
    # Query free/total once so both values come from the same instant;
    # two separate mem_get_info() calls could disagree if other processes
    # allocate or free GPU memory in between.
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    current_cuda_memory_bytes = total_gpu_memory - free_gpu_memory
    # Category-3 growth: total device usage minus everything accounted for
    # by the baseline, the weights, and torch's own allocator delta.
    result.non_torch_increase_in_bytes = (current_cuda_memory_bytes -
                                          baseline_memory_in_bytes -
                                          weights_memory_in_bytes -
                                          diff.torch_memory_in_bytes)
    result.profile_time = diff.timestamp
    result.non_kv_cache_memory_in_bytes = (
        result.non_torch_increase_in_bytes +
        result.torch_peak_increase_in_bytes +
        result.weights_memory_in_bytes)  # see (a.)+(b.)+(c.) in the docstring
0 commit comments