Skip to content

Add GDS alternative for hierarchical kv cache#7896

Closed
didoteebin wants to merge 13 commits intosgl-project:mainfrom
didoteebin:main
Closed

Add GDS alternative for hierarchical kv cache#7896
didoteebin wants to merge 13 commits intosgl-project:mainfrom
didoteebin:main

Conversation

@didoteebin
Copy link
Copy Markdown

Motivation

Add GDS alternative for hierarchical kv cache

Add options for GPU kvcache direct store to SSD , potential TTFT benefit may gain.

Modifications

python/sglang/srt/mem_cache/memory_pool_host.py to add GDS support in MHATokenToKVPoolHost
python/sglang/srt/mem_cache/gds.py to add a GDS class based on NIXL lib
python/sglang/global_config.py to add a switch for user to decide whether to enable gds or not

Checklist

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @didoteebin, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the hierarchical KV cache management by integrating GPU Direct Storage (GDS) capabilities. The primary goal is to optimize the "Time To First Token" (TTFT) by enabling direct data transfers between GPU memory and SSDs for KV cache pages, thereby reducing reliance on traditional host memory for these operations. This change provides a configurable option for users to leverage GDS, particularly within the multi-head attention (MHA) token pool, to potentially improve performance for large language models.

Highlights

  • GPU Direct Storage (GDS) Integration: Implemented support for GPU Direct Storage (GDS) to enable direct KV cache transfers between GPU memory and SSDs, aiming to reduce Time To First Token (TTFT) by offloading data from host memory.
  • GDS Configuration: Introduced enable_gds and gds_path options in global_config.py, allowing users to enable and configure GDS functionality via environment variables.
  • Hierarchical KV Cache Enhancements: Modified MHATokenToKVPoolHost to conditionally leverage GDS for writing and loading KV cache pages across layers, providing an alternative to traditional host memory transfers.
  • Memory Pool Abstraction: Added new abstract methods (transfer, get_flat_data, get_flat_data_by_layer, assign_flat_data) to the HostKVCache base class, formalizing the interface for memory operations within the hierarchical KV cache system.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a GPU Direct Storage (GDS) alternative for the hierarchical KV cache. My review identified a critical issue with incorrect function arguments, a high-severity issue with improper handling of an environment variable, and another high-severity issue regarding a potential resource leak. Additionally, there are several medium-severity concerns related to code clarity and feature completeness.

Comment on lines +60 to +63
logger.info("xxxx")
self.size = int(host_size * 1e9 // self.size_per_token)
else:
logger.info(f"{self.device_pool.size},{host_to_device_ratio},{self.page_size},{host_size}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logger.info calls added in this __init__ method (e.g., lines 60, 63, 75, 91) appear to be for debugging purposes. They are either uninformative (e.g., "xxxx") or contain typos and extra characters (e.g., line 91 has "devcie" and "##...###"). Remove or change these logs to logger.debug to avoid cluttering production logs and improve code clarity.

Comment on lines +383 to +422
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)

def get_flat_data(self, indices):
return self.kv_buffer[:, indices]

def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id - self.start_layer, indices]

def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data

def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)

def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The GDS logic has been added for MHATokenToKVPoolHost, but not for MLATokenToKVPoolHost. The new methods in MLATokenToKVPoolHost only implement the non-GDS path. This creates an inconsistency where GDS will not be used for MLA-style models even if enabled. Add the GDS implementation for MLATokenToKVPoolHost to ensure consistent behavior.

didoteebin and others added 3 commits July 9, 2025 17:21
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@msharmavikram
Copy link
Copy Markdown

@didoteebin can you enable gds via dynamo/nixl? This would enable not just gds but all backend.

This should help significantly compared to just gds that allows file system.

@xinranwang17
Copy link
Copy Markdown

Great job! Is it possible to share some benchmark data here? Also, I wonder is there any plan to use disk GDS access as L3 cache instead of replacing L2 cpu memory directly?

@didoteebin
Copy link
Copy Markdown
Author

@xinranwang17 Thanks , working on the benchmark to show TTFT benefit.
@xiezhq-hermann Question on the loading process .It seems that current logic is load whatever kvcache in cpu specified by host_indices to GPU area which specified by device_indices, there is no visibility of these kv cache will be eventually match or not when doing this loading process , am I right ?

@xiezhq-hermann
Copy link
Copy Markdown
Collaborator

@xinranwang17 Thanks , working on the benchmark to show TTFT benefit. @xiezhq-hermann Question on the loading process .It seems that current logic is load whatever kvcache in cpu specified by host_indices to GPU area which specified by device_indices, there is no visibility of these kv cache will be eventually match or not when doing this loading process , am I right ?

@didoteebin the loading would only start after matching the tokens if that's your question

@sohil-bst
Copy link
Copy Markdown

Hello, is there an update here? I'm looking for a solution to use my local ssd to extend kv cache. @didoteebin

@msharmavikram
Copy link
Copy Markdown

You can use SGL hicache with NIXL to get GDS support.

@sohil-bst
Copy link
Copy Markdown

I see both POSIX and GDS available for NIXL and plan to try both:

python3 -c "from nixl._api import nixl_agent, nixl_agent_config; \
    agent = nixl_agent('test', nixl_agent_config(backends=[])); \
    print(agent.get_plugin_list())"

_api.py:251 Initialized NIXL agent: test
['GDS', 'GDS_MT', 'GPUNETIO', 'GUSLI', 'OBJ', 'POSIX', 'UCX']

This PR led me to believe GDS was still a WIP though.

@xiezhq-hermann
Copy link
Copy Markdown
Collaborator

close for no update, but feel free to re-open it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants