-
Notifications
You must be signed in to change notification settings - Fork 454
Support sleep, wake_up and load_weights for Omni Diffusion #376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+343
−6
Merged
Changes from 2 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
3070c9e
Fix pre-commit
knlnguyen1802 db68abf
Fix code
knlnguyen1802 28b7195
Add test
knlnguyen1802 72a15d3
Fix name and add test
knlnguyen1802 728162b
Fix pre-commit
knlnguyen1802 03321f0
Fix test
knlnguyen1802 adf821e
Rebase
knlnguyen1802 c36dbd8
Merge branch 'main' into diffusion_support
knlnguyen1802 5260232
Add test to pipeline
knlnguyen1802 6b966d9
Merge branch 'diffusion_support' of https://github.com/knlnguyen1802/…
knlnguyen1802 9f77c12
Merge branch 'main' into diffusion_support
knlnguyen1802 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
| import multiprocessing as mp | ||
| import os | ||
| import time | ||
| from collections.abc import Iterable | ||
| from contextlib import AbstractContextManager, nullcontext | ||
|
|
||
| import torch | ||
| import zmq | ||
|
|
@@ -42,7 +44,7 @@ def __init__( | |
| self.rank = rank | ||
| self.od_config = od_config | ||
| self.pipeline = None | ||
|
|
||
| self._sleep_saved_buffers: dict[str, torch.Tensor] = {} | ||
| self.init_device_and_model() | ||
|
|
||
| def init_device_and_model(self) -> None: | ||
|
|
@@ -71,11 +73,12 @@ def init_device_and_model(self) -> None: | |
| load_config = LoadConfig() | ||
| model_loader = DiffusersPipelineLoader(load_config) | ||
| time_before_load = time.perf_counter() | ||
| with DeviceMemoryProfiler() as m: | ||
| self.pipeline = model_loader.load_model( | ||
| od_config=self.od_config, | ||
| load_device=f"cuda:{rank}", | ||
| ) | ||
| with self._maybe_get_memory_pool_context(tag="weights"): | ||
| with DeviceMemoryProfiler() as m: | ||
| self.pipeline = model_loader.load_model( | ||
| od_config=self.od_config, | ||
| load_device=f"cuda:{rank}", | ||
| ) | ||
| time_after_load = time.perf_counter() | ||
|
|
||
| logger.info( | ||
|
|
@@ -107,6 +110,58 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi | |
| output = self.pipeline.forward(req) | ||
| return output | ||
|
|
||
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | ||
| return self.pipeline.loaded_weights(weights) | ||
|
|
||
| def sleep(self, level: int = 1) -> bool: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to add a function in engine to call this and expose an interface to user?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| from vllm.device_allocator.cumem import CuMemAllocator | ||
|
|
||
| free_bytes_before_sleep = torch.cuda.mem_get_info()[0] | ||
|
|
||
| # Save the buffers before level 2 sleep | ||
| if level == 2: | ||
| model = self.pipeline | ||
| self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} | ||
|
|
||
| allocator = CuMemAllocator.get_instance() | ||
| allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) | ||
| free_bytes_after_sleep, total = torch.cuda.mem_get_info() | ||
| freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep | ||
| used_bytes = total - free_bytes_after_sleep | ||
| assert freed_bytes >= 0, "Memory usage increased after sleeping." | ||
| logger.info( | ||
| "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", | ||
| freed_bytes / GiB_bytes, | ||
| used_bytes / GiB_bytes, | ||
| ) | ||
| return True | ||
|
|
||
| def wake_up(self, tags: list[str] | None = None) -> bool: | ||
| from vllm.device_allocator.cumem import CuMemAllocator | ||
|
|
||
| allocator = CuMemAllocator.get_instance() | ||
| allocator.wake_up(tags) | ||
|
|
||
| # Restore the buffers after level 2 sleep | ||
| if len(self._sleep_saved_buffers): | ||
| model = self.pipeline | ||
| for name, buffer in model.named_buffers(): | ||
| if name in self._sleep_saved_buffers: | ||
| buffer.data.copy_(self._sleep_saved_buffers[name].data) | ||
| self._sleep_saved_buffers = {} | ||
| return True | ||
|
|
||
| def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: | ||
| if self.od_config.enable_sleep_mode: | ||
| from vllm.device_allocator.cumem import CuMemAllocator | ||
|
|
||
| allocator = CuMemAllocator.get_instance() | ||
| if tag == "weights": | ||
| assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." | ||
| return allocator.use_memory_pool(tag=tag) | ||
| else: | ||
| return nullcontext() | ||
|
|
||
| def shutdown(self) -> None: | ||
| if torch.distributed.is_initialized(): | ||
| try: | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.