Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,16 @@ extern "C" {
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);

// GPU-resident checkpoint for speculative decoding (fast save/restore)
// Returns false if the memory implementation does not support GPU checkpoints
LLAMA_API bool llama_memory_checkpoint_save(llama_memory_t mem, llama_seq_id seq_id);

LLAMA_API bool llama_memory_checkpoint_restore(llama_memory_t mem, llama_seq_id seq_id);

LLAMA_API void llama_memory_checkpoint_delete(llama_memory_t mem);

LLAMA_API bool llama_memory_checkpoint_supported(llama_memory_t mem);

//
// State / sessions
//
Expand Down
16 changes: 16 additions & 0 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3292,6 +3292,22 @@ bool llama_memory_can_shift(llama_memory_t mem) {
return mem->get_can_shift();
}

// Save a GPU-resident checkpoint of the memory state for seq_id.
// Returns false when the memory implementation does not support checkpoints.
bool llama_memory_checkpoint_save(llama_memory_t mem, llama_seq_id seq_id) {
    const bool res = mem->checkpoint_save(seq_id);
    return res;
}

// Restore the previously saved checkpoint for seq_id.
// Returns false when no checkpoint exists or checkpoints are unsupported.
bool llama_memory_checkpoint_restore(llama_memory_t mem, llama_seq_id seq_id) {
    const bool res = mem->checkpoint_restore(seq_id);
    return res;
}

// Discard any saved checkpoint held by the memory (no-op when none exists).
void llama_memory_checkpoint_delete(llama_memory_t mem) {
    mem->checkpoint_delete();
}

// Query whether this memory implementation supports GPU-resident checkpoints.
bool llama_memory_checkpoint_supported(llama_memory_t mem) {
    const bool res = mem->checkpoint_supported();
    return res;
}

// llama state API

// deprecated
Expand Down
146 changes: 146 additions & 0 deletions src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,152 @@ llama_memory_recurrent::llama_memory_recurrent(
}
}

// Lazily allocate the shadow (checkpoint) tensors — one r/s pair per layer,
// mirroring the primary r_l/s_l tensors on the same backend buffer types.
// Returns true when the shadows are (already) available, false on failure.
// On failure all partially-built checkpoint state is rolled back so that a
// later call starts from a clean slate (ckpt.allocated stays false).
bool llama_memory_recurrent::checkpoint_alloc_shadows() {
    if (ckpt.allocated) {
        return true;
    }

    const int32_t n_layer = hparams.n_layer;

    // assign (not resize): a previous failed attempt may have left stale
    // pointers behind; make sure every slot starts as nullptr
    ckpt.r_l_shadow.assign(n_layer, nullptr);
    ckpt.s_l_shadow.assign(n_layer, nullptr);

    // Mirror the primary tensor allocation pattern:
    // group tensors by buffer type, one ggml_context per buffer type
    struct ggml_backend_buft_comparator {
        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
        }
    };

    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

    // roll back partial checkpoint state; freeing shadow_bufs releases any
    // buffers already allocated during this attempt
    auto fail = [&]() -> bool {
        ckpt.r_l_shadow.clear();
        ckpt.s_l_shadow.clear();
        ckpt.shadow_bufs.clear();
        return false;
    };

    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            // no_alloc context: only tensor metadata lives here, the data is
            // placed in the backend buffer allocated below
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u * n_layer * ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };
            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }
            ctx_map.emplace(buft, ctx);
            return ctx;
        }
        return it->second.get();
    };

    for (int i = 0; i < n_layer; i++) {
        if (r_l[i] == nullptr) {
            continue;
        }

        ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(r_l[i]->buffer);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            LLAMA_LOG_ERROR("%s: failed to create ggml context for shadow tensors\n", __func__);
            return fail();
        }

        // NOTE(review): assumes s_l[i] is non-null whenever r_l[i] is — this
        // matches the guard used by checkpoint_save/restore; confirm it holds
        // for all recurrent architectures
        ggml_tensor * r_shadow = ggml_dup_tensor(ctx, r_l[i]);
        ggml_tensor * s_shadow = ggml_dup_tensor(ctx, s_l[i]);
        ggml_format_name(r_shadow, "cache_r_l%d_shadow", i);
        ggml_format_name(s_shadow, "cache_s_l%d_shadow", i);

        ckpt.r_l_shadow[i] = r_shadow;
        ckpt.s_l_shadow[i] = s_shadow;
    }

    // Allocate backend buffers, one per buffer type (same pattern as constructor)
    for (auto & [buft, ctx] : ctx_map) {
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            LLAMA_LOG_ERROR("%s: failed to allocate buffer for shadow tensors\n", __func__);
            return fail();
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s shadow buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf),
                ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
        ckpt.shadow_bufs.emplace_back(std::move(ctx), buf);
    }

    ckpt.allocated = true;
    return true;
}

// Checkpoints are supported as soon as at least one layer owns a
// recurrent-state tensor (r_l).
bool llama_memory_recurrent::checkpoint_supported() const {
    for (size_t il = 0; il < r_l.size(); ++il) {
        if (r_l[il] != nullptr) {
            return true;
        }
    }
    return false;
}

// Save the current recurrent state into the GPU-resident checkpoint:
// a host-side snapshot of the cell metadata plus a device-side copy of the
// per-layer r/s tensors into the shadow tensors.
bool llama_memory_recurrent::checkpoint_save(llama_seq_id seq_id) {
    GGML_UNUSED(seq_id); // single-checkpoint, full-tensor copy — seq_id for future use

    if (!checkpoint_alloc_shadows()) {
        return false;
    }

    // snapshot the (small, host-side) cell metadata first
    ckpt.cells_snapshot = cells;
    ckpt.head_snapshot  = head;
    ckpt.used_snapshot  = used;
    ckpt.rs_z_snapshot  = rs_z;

    // then mirror the per-layer states into the shadow tensors
    // (device-to-device when source and destination share a device)
    const int32_t n_layer = hparams.n_layer;
    for (int32_t il = 0; il < n_layer; ++il) {
        if (r_l[il] != nullptr) {
            ggml_backend_tensor_copy(r_l[il], ckpt.r_l_shadow[il]);
            ggml_backend_tensor_copy(s_l[il], ckpt.s_l_shadow[il]);
        }
    }

    ckpt.saved = true;
    return true;
}

// Restore the recurrent state from the GPU-resident checkpoint.
// Fails (with an error log) when no checkpoint has been saved.
bool llama_memory_recurrent::checkpoint_restore(llama_seq_id seq_id) {
    GGML_UNUSED(seq_id); // single-checkpoint, full-tensor copy

    if (!ckpt.saved) {
        LLAMA_LOG_ERROR("%s: no checkpoint saved\n", __func__);
        return false;
    }

    // bring back the host-side cell metadata snapshot
    cells = ckpt.cells_snapshot;
    head  = ckpt.head_snapshot;
    used  = ckpt.used_snapshot;
    rs_z  = ckpt.rs_z_snapshot;

    // then copy the shadow tensors back over the primaries
    const int32_t n_layer = hparams.n_layer;
    for (int32_t il = 0; il < n_layer; ++il) {
        if (r_l[il] != nullptr) {
            ggml_backend_tensor_copy(ckpt.r_l_shadow[il], r_l[il]);
            ggml_backend_tensor_copy(ckpt.s_l_shadow[il], s_l[il]);
        }
    }

    return true;
}

// Invalidate the saved checkpoint. The shadow tensors/buffers are NOT freed:
// ckpt.allocated is left unchanged, so a subsequent checkpoint_save() can
// reuse the existing allocations without re-allocating.
void llama_memory_recurrent::checkpoint_delete() {
    ckpt.saved = false;
}

void llama_memory_recurrent::clear(bool data) {
for (int32_t i = 0; i < (int32_t) size; ++i) {
cells[i].pos = -1;
Expand Down
29 changes: 29 additions & 0 deletions src/llama-memory-recurrent.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class llama_memory_recurrent : public llama_memory_i {
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

// GPU-to-GPU checkpoint: fast save/restore for speculative decoding
bool checkpoint_save(llama_seq_id seq_id) override;
bool checkpoint_restore(llama_seq_id seq_id) override;
void checkpoint_delete() override;
bool checkpoint_supported() const override;

uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
Expand Down Expand Up @@ -112,6 +118,29 @@ class llama_memory_recurrent : public llama_memory_i {
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

// GPU-resident checkpoint: shadow tensors + cell metadata snapshot
    // GPU-resident checkpoint state: one full snapshot (shadow tensors +
    // host-side cell metadata) used by checkpoint_save/restore.
    struct gpu_checkpoint {
        // cell metadata snapshot (small, host-side)
        std::vector<mem_cell> cells_snapshot;
        uint32_t head_snapshot = 0;
        uint32_t used_snapshot = 0;
        int32_t rs_z_snapshot = -1;

        // shadow tensors (same device as primaries);
        // nullptr for layers whose primary r_l tensor is nullptr
        std::vector<ggml_tensor *> r_l_shadow;
        std::vector<ggml_tensor *> s_l_shadow;

        // ggml contexts + buffers for shadow tensors
        std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> shadow_bufs;

        bool allocated = false; // shadow tensors/buffers allocated (lazily, on first save)
        bool saved = false; // true after a successful save
    };

gpu_checkpoint ckpt;

bool checkpoint_alloc_shadows();

size_t total_size() const;

size_t size_r_bytes() const;
Expand Down
12 changes: 12 additions & 0 deletions src/llama-memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ struct llama_memory_i {

virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;

//
// GPU-resident checkpoint for speculative decoding (optional, default no-op)
//

    virtual bool checkpoint_save(llama_seq_id /*seq_id*/) { return false; } // base: checkpoints unsupported

    virtual bool checkpoint_restore(llama_seq_id /*seq_id*/) { return false; } // base: nothing to restore

    virtual void checkpoint_delete() {} // base: nothing to delete

    virtual bool checkpoint_supported() const { return false; } // base: checkpoints unsupported
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
Loading