@@ -1923,55 +1923,72 @@ def kill_process_tree(pid: int):
 @dataclass
 class MemorySnapshot:
     """Memory snapshot."""
-    torch_peak_in_bytes: int = 0
-    torch_memory_in_bytes: int = 0
+    torch_peak: int = 0
+    cuda_memory: int = 0
+    torch_memory: int = 0
+    non_torch_memory: int = 0
     timestamp: float = 0.0
+    auto_measure: bool = True
+
+    def __post_init__(self):
+        if self.auto_measure:
+            self.measure()
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # we measure the torch peak memory usage via allocated_bytes,
+        # rather than `torch.cuda.memory_reserved()`.
+        # After `torch.cuda.reset_peak_memory_stats()`,
+        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
+        # when we call `torch.cuda.empty_cache()` or OOM happens.
+        self.torch_peak = torch.cuda.memory_stats().get(
+            "allocated_bytes.all.peak", 0)
+
+        self.cuda_memory = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+
         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
-        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
+        # this is used to measure the non-torch memory usage
+        self.torch_memory = torch.cuda.memory_reserved()
+
+        self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
-        """support a - b"""
         return MemorySnapshot(
-            torch_peak_in_bytes=self.torch_peak_in_bytes -
-            other.torch_peak_in_bytes,
-            torch_memory_in_bytes=self.torch_memory_in_bytes -
-            other.torch_memory_in_bytes,
-            timestamp=self.timestamp - other.timestamp)
+            torch_peak=self.torch_peak - other.torch_peak,
+            cuda_memory=self.cuda_memory - other.cuda_memory,
+            torch_memory=self.torch_memory - other.torch_memory,
+            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
+            timestamp=self.timestamp - other.timestamp,
+            auto_measure=False,
+        )
 
 
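For orientation, here is a minimal sketch of how the new snapshot arithmetic behaves, assuming a CUDA device is available; the driver code is hypothetical and only `MemorySnapshot` comes from this diff:

```python
import torch

before = MemorySnapshot()  # auto_measure=True, so __post_init__ runs measure()
buf = torch.empty(1024, 1024, device="cuda")  # ~4 MiB allocated through torch
after = MemorySnapshot()

# __sub__ gives field-wise deltas; auto_measure=False keeps the result
# from re-measuring on construction.
delta = after - before
print(delta.torch_memory, delta.non_torch_memory, delta.timestamp)
```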
 @dataclass
 class MemoryProfilingResult:
-    """Memory profiling result.
-    """  # noqa
-    baseline_memory_in_bytes: int = 0
-    non_kv_cache_memory_in_bytes: int = 0
-    torch_peak_increase_in_bytes: int = 0
-    non_torch_increase_in_bytes: int = 0
-    weights_memory_in_bytes: float = 0
+    """Memory profiling result. All numbers are in bytes.
+    """
+    non_kv_cache_memory: int = 0
+    torch_peak_increase: int = 0
+    non_torch_increase: int = 0
+    weights_memory: float = 0
+    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
     before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0
 
 
 @contextlib.contextmanager
 def memory_profiling(
-    baseline_memory_in_bytes: int, weights_memory_in_bytes: int
-) -> Generator[MemoryProfilingResult, None, None]:
+        baseline_snapshot: MemorySnapshot,
+        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
     """Memory profiling context manager.
-    baseline_memory_in_bytes: memory used by all the components other than
-    the current vLLM instance. It contains: memory used by other processes, memory
-    used by another vLLM instance in the same process, etc. It is usually measured
-    before the current vLLM instance initialize the device. And we assume it is
-    constant during the profiling of the current vLLM instance.
-    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
     Note that, before loading the model weights, we also initialize the device
     and distributed environment, which may consume some memory. This part is not
-    included in the weights_memory_in_bytes because PyTorch does not control it.
+    included in the weights_memory because PyTorch does not control it.
 
     The memory in one GPU can be classified into 3 categories:
     1. memory used by anything other than the current vLLM instance.
@@ -2006,20 +2023,21 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)
 
-    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
 
-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
 
-    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
     """  # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
 
     result = MemoryProfilingResult()
 
-    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    result.before_create = baseline_snapshot
     # the part of memory used for holding the model weights
-    result.weights_memory_in_bytes = weights_memory_in_bytes
+    result.weights_memory = weights_memory
 
     result.before_profile.measure()
 
@@ -2030,13 +2048,12 @@ def memory_profiling(
 
     result.after_profile.measure()
 
-    diff = result.after_profile - result.before_profile
-    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
-    current_cuda_memory_bytes = torch.cuda.mem_get_info(
-    )[1] - torch.cuda.mem_get_info()[0]
-    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
-    result.profile_time = diff.timestamp
-    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+    result.torch_peak_increase = diff_profile.torch_peak
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa
 
 
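Putting the pieces together, a rough usage sketch of the new interface; the model, sizes, and the way `weights_memory` is measured below are illustrative stand-ins, not part of the change:

```python
import torch

baseline = MemorySnapshot()  # taken before this instance initializes the device

model = torch.nn.Linear(4096, 4096).cuda()      # stand-in for weight loading
weights_memory = torch.cuda.memory_allocated()  # bytes held by the weights

with memory_profiling(baseline, weights_memory) as result:
    model(torch.randn(512, 4096, device="cuda"))  # stand-in for the profile run

# The result fields are filled in when the with-block exits:
# non_kv_cache_memory = weights_memory + torch_peak_increase + non_torch_increase,
# i.e. everything the instance needs besides the KV cache.
print(result.non_kv_cache_memory)
```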
 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630  # noqa: E501