Support for DeepseekV32ForCausalLM with DeepSeek Sparse Attention (DSA)#21149
Support for DeepseekV32ForCausalLM with DeepSeek Sparse Attention (DSA)#21149fairydreaming wants to merge 33 commits intoggml-org:masterfrom
Conversation
…e attention). Needs manual change of add_bos_token to true in tokenizer_config.json before conversion.
…I think it's best not to quantize them.
…er implementation
…indexer implementation since the former fails for large tensors even when using CCCL.
… of llama_kv_cache and new llama_ik_cache (lightning indexer key cache). model : used new llama_kv_cache_dsa instead of modified llama_kv_cache with indexer keys in DeepseekV32ForCausalLM model : removed non-MLA path in DeepseekV32ForCausalLM
…lar to torch scatter_ operation.
…e can get rid of ggml_cast() calls in sparse attention implementation
…rm implementations
…orCausalLM-based models.
Hmmm, it should not crash, but fall back... |
I will check it again and report a bug if needed. |
This was already added in #21038 just not as an op, I suggest removing this one for obvious reasons and instead moving that implementation to an op, leaving backend implementation to others. |
Lol, so now I'm supposed to choose between @ikawrakow and @ggerganov Hadamard transform implementation? Thanks @CISC, very helpful of you. 😅 |
It's not a matter of choosing, I am genuinely being helpful, you know why this is contentious, and the implementation is already here and quite trivial, besides it's preferable (and in our policy) not to include backend changes in non-backend PRs. |
So I shouldn't include code from you-know-who in my PR because you-know-why? 😂 (btw I have Iwan permission to use this code in llama.cpp) If there is an official set of project-wide rules to follow regarding this (apparently highly radioactive) matter then it probably should be formalized in CONTRIBUTING.md file so that:
That's what I would call helpful in this matter. |
|
@fairydreaming Hi, I'm /u/digger412 on reddit, figured I'd migrate the convo here. I've got the electrical outlet installed last week and waiting on a new rack case to arrive to house everything. I think I can have 4 of the 6000 Pros up and running later today (with some hodgepodging and jank setup). If you can upload a quant that will fit into 384GiB of VRAM then I can try to run it, or I guess I could download the weights and convert it myself with your PR 🤔 Might take a few days but I will get to test this, I promise! |
That's the current state of affairs, yes ;) |
src/llama-ik-cache.h
Outdated
There was a problem hiding this comment.
I could be wrong, but I assume the IK cache (I assume you mean index K) is K-only cache; Can this be replaced by using this instead? #19067
There was a problem hiding this comment.
@ngxson Yes, but llama_kv_cache reads tensor dimensions, number of heads etc directly from hparams, so I can't simply instantiate another instance of the cache with different parameter values. I'm not satisfied with my current solution either, as it duplicates a lot of code.
Alternative solution would be to stuff the indexer key tensors in existing kv cache along with currently stored MLA latent representation + RoPE prefix tensors and make a view with an offset to read the cache. But that would make both MLA KV cache and indexer cache non-contiguous, not sure if that's a good idea.
There was a problem hiding this comment.
Hmm yeah I see. Because the indexer uses different size than the main attention block, duplicating the class is probably the cleanest way we can do for now.
In near future, we can also refactor KV cache, such that K and V are 2 separated llama-vec-cache. The "vector cache" can be reused across different types of cache, including index K, KV, iswa. CC @ggerganov for visibility
There was a problem hiding this comment.
We can decouple the kv cache implementation from the struct llama_model and struct llama_hparams. Would need to introduce struct llama_kv_cache_params and use that within the implementation without reference the model and it's hparams.
This way you should be able to instantiate two different KV caches with different llama_kv_cache_params. Would that work?
There was a problem hiding this comment.
@ggerganov Yes, that should work. But I did a quick check and llama_kv_cache.cpp currently uses:
hparams.has_kv(il)
hparams.is_mla()
hparams.is_n_embd_v_gqa_variable()
hparams.is_swa(il)
hparams.n_embd_head_k(il)
hparams.n_embd_head_v(il)
hparams.n_embd_k_gqa(il)
hparams.n_embd_v_gqa(il)
hparams.n_embd_v_gqa_max()
hparams.n_head_kv(il)
hparams.n_layer
hparams.n_layer_kv()
hparams.n_lora_kv
hparams.no_alloc
hparams.n_pos_per_embd()
hparams.n_rel_attn_bkts
hparams.n_rot(il)
hparams.rope_type
hparams.use_alibi
model.arch
model.dev_layer(il)
model.get_rope_factors()
model.get_rope_freq_base()
model.get_rope_freq_scale()
So I'm afraid it's not a trivial endeavor, but a major refactoring effort.
There was a problem hiding this comment.
As a first step, you can try to pass hparams separately from model and see if this will help deduplicate the llama_kv_cache/llama_ik_cache implementations.
So add a constructor:
llama_kv_cache(
const llama_model & model,
const llama_hparams hparams, // <--- custom hparams, can be overridden for indexing caches
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);This should be a small change and if it works, we can prepare a small refactor to support that.
There was a problem hiding this comment.
As a first step, you can try to pass
hparamsseparately frommodeland see if this will help deduplicate thellama_kv_cache/llama_ik_cacheimplementations.So add a constructor:
llama_kv_cache( const llama_model & model, const llama_hparams hparams, // <--- custom hparams, can be overridden for indexing caches ggml_type type_k, ggml_type type_v, bool v_trans, bool offload, bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, const layer_filter_cb & filter, const layer_reuse_cb & reuse);This should be a small change and if it works, we can prepare a small refactor to support that.
@ggerganov So basically: make a copy of this huge pile of parameters, tweak some of them so that the second llama_kv_cache instance works as intended for caching indexer tensors and hope it won't break in the future? Horrible solution looking from the software engineering point of view, but matches the llama.cpp spirit well. Will try.
Probably you can get more context from PR number 19726 |
Great to hear from you! No need to hurry, I think I'd rather prefer some larger quant tested with all 8 cards, so that quantization won't affect the model cognitive performance. Also more VRAM = more concurrent requests. It's getting late today, so tomorrow I will create a discussion about testing the implementation and we can plan there in details. |
@ngxson Initially I did some reading on this and the origins, but I had more questions than answers afterwards and overall it just made me sad. |
|
I just looked at the CUDA code briefly. For the scatter, you should extend |
@am17an I've read #21038 in more detail today and this approach indeed may be applicable to my PR. I suppose I just have to wait until the dust in |
|
|
||
| const auto & kq_mask = inp->get_kq_mask(); | ||
|
|
||
| // prepare new kq mask - starts filled with -INFINITY | ||
| ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); | ||
|
|
||
| // modify it by unmasking tokens that are in top_k indices | ||
| ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0); | ||
|
|
||
| // combine with the original kq mask | ||
| kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); | ||
|
|
There was a problem hiding this comment.
I wonder, instead of masking the KV cache, wouldn't it be more efficient to extract a new K and KQ mask using ggml_get_rows(..., top_k) and perform the attention on those smaller tensors?
There was a problem hiding this comment.
I wonder, instead of masking the KV cache, wouldn't it be more efficient to extract a new K and KQ mask using
ggml_get_rows(..., top_k)and perform the attention on those smaller tensors?
@ggerganov I thought about this solution, but decided to go with the simplest possible one for now. By the way I think for KQ mask in this case we would need something like "for each row get elements that are in the corresponding top_k indices row". Do we have GGML OP like this?
There was a problem hiding this comment.
I see, it's not so simple as I thought.
Having ggml_scatter() seems useful to have anyway.
There was a problem hiding this comment.
GGML_OP_FILL can be extended to provide a list of indices to fill?
There was a problem hiding this comment.
I see, it's not so simple as I thought.
Having
ggml_scatter()seems useful to have anyway.
@ggerganov AFAIK torch gather works like I mentioned - gathers values from an axis based on specified indices (the way it's needed for KQ mask in this case), so it would be another new GGML OP (kind of symmetric to scatter that puts values on axis based on specified indices). My scatter is somewhat crippled anyway since it only accepts single scalar value, not tensor of values. So maybe it's a better idea to implement GGML_OP_GATHER to get KQ mask elements indicated by top_k and then use ggml_get_rows() to perform attention only on cached vectors that are in top_k indices.
I'm currently waiting for #21038 regarding the Hadamard transform implementation, so I can try to implement this solution in the meantime and see what comes of it.
There was a problem hiding this comment.
How large is the n_top_k typical?
There was a problem hiding this comment.
How large is the
n_top_ktypical?
@ggerganov DeepSeek V3.2 sets it to 2048. If n_kv is shorter than 2048 the result will be shorter too but for long sequences it would always take top 2048 cached k/v vectors.
There was a problem hiding this comment.
Do we know that the indexer improves prefill performance? I remember reading (and it is also obvious) that decoding (i.e. batch size 1) will be much faster with the indexer, but I think that for large batch sizes, we won't benefit much compared to simply doing the regular masked attention without re-gathering the indexed KV data. The reason is that at batch size 512 for example, each token "activating" 2048 KV cells would usually activate the entire cache anyway.
Just want to know if we should focus on a solution that works for small batches (e.g. less than 32), which might be much simpler.
There was a problem hiding this comment.
@ggerganov The DeepSeek V3.2 paper said:
for short-sequence prefilling, we specially implement a masked MHA mode to simulate DSA, which can achieve higher efficiency under short-context conditions
So I guess the optimal solution is a hybrid one and we need both (masked dense attention for short sequences and sparse attention for longer) - and that applies both for prefill and decode.
Regarding your remark about entire cache activation - I doubt the lightning indexer top k position selection would activate the entire cache, likely it's trained to attend only to most relevant positions and omit irrelevant ones, so if you had n_kv of 100k the activated n_top_k cache positions would be similar (largely overlapping) for all 512 ubatch tokens. But I can't support this with any data, this is just my intuition.
There was a problem hiding this comment.
Another idea that I think does not require any changes in ggml_get_rows() or reshaping flash attn arguments:
- Find all KV cache indices that at least one of the ubatch tokens attends to (this will be union of top k indices for a whole ubatch).
- Remove KQ mask columns that are not in this set (these columns will be all -INF anyway).
- Perform
ggml_get_rows()on K and V cache with indices from point 1 to get only cells contributing the the attention output for at least one ubatch token. - Do attention as usual.
This wastes more compute than my previous approach, but maybe would be good enough. Depends on the structure of top k indices for a whole ubatch.
I guess I will do some experiments first to see how top k indices look like when processing a whole ubatch with large KV cache.
|
I implemented @ggerganov idea to get rid of llama_ik_cache by creating another llama_kv_cache instance with tweaked hparams and it works - but during testing I started encountering llama-server crashes - already twice (sorry no detailed debug info, but looks like calling a method on deleted object or a corrupted pointer): They are "works OK for 2 hours and then suddenly dies" crashes and I'm not sure if it's my fault (could be) or some code from recent rebase, so I'm leaving it here in (unlikely) case someone knows what's going on. Back to debugging. Update: I have some leads, looks like llama_kv_cache_dsa_context * is being static_cast to llama_kv_cache_context * in some places and this wreaks havoc. My fault for being lazy and not implementing Update 2: I think it's fixed now, no more crashes observed so far. Also switched from ggml_hadamard() to rotation matrices multiplication, all looks good. |
…based on tweaked hparams.
…ase/no suffix was used for MLA part and _dsa/_ik were used for lightning indexer part, to make names more obvious I renamed _base/no suffix to _mla and _dsa/_ik to _lid.
…matrix multiplication. ggml : remove unused GGML_OP_HADAMARD
Overview
This PR adds support for DeepseekV32ForCausalLM (DeepSeek V3.2 Exp, DeepSeek V3.2, DeepSeek V3.2 Speciale) models. It contains implementation of the lightning indexer and DeepSeek Sparse Attention (DSA) - both implemented in the simplest possible way as a proof of concept. So far only CPU and CUDA backends are supported.
Due to the way it's currently implemented it doesn't improve long context performance yet, more work is needed for this.
Some GGUFs for testing are available here (-light models), I uploaded Q8_0/Q4_K_M quants, so you need over 700GB/400GB of RAM/VRAM to run them.
I also created a 16GB baby DeepSeek V3.2 GGUF for VRAM-deprived people. It outputs incoherent gibberish, but should be useful for testing and optimizing this implementation even with limited resources.
I really could use some help with verifying the implementation correctness. If you have large GPU cluster and can run some benchmarks to compare results with official reported benchmark results for DeepSeek V3.2 models then go for it. More details in #21183.
Fixes #16331, #20363
Additional information
Decisions I made when implementing this:
DEEPSEEK32was added (mostly a copy of existingGLM_DSAarch),GGML_OP_SCATTERthat works similar to torch scatter_ operation but is currently limited to setting tensor elements at specified indices to a given scalar value,GGML_OP_HADAMARDwith implementation borrowed from ik_llama.cpp (thx @ikawrakow),llama_kv_cache_dsaclass which aggregates the usualllama_kv_cachethat caches MLA latent representations (same as before for DeepSeek V3) and another newllama_ik_cacheclass (basically a copy of llama_kv_cache stripped of code related to V vector) that caches lightning indexer keys,Requirements
Due to limitations of the current CUDAggml_top_k()implementation NVIDIA CUDA CCCL library (version >3.2) and enabling GGML_CUDA_USE_CUB during CUDA backend compilation is needed, otherwise the CUDA implementation will crash for context sizes larger than (I think) 1024 tokens. I use it with CUDA 13.2 and CCCL 13.2.27.Bug in
ggml_top_k()is now fixed, fix is merged, so it should work even on 2.[89] CUDA without CCCL.Also if you want to convert the model by yourself, set
add_bos_tokento true intokenizer_config.jsonbefore the model conversion - this is needed for DeepSeek V3.2 and DeepSeek V3.2 Speciale. The conversion script has assert that checks this.Next Steps