Faster long context TG on CUDA for GLM-4.5/4.6/4.7/AIR #1183
Conversation
This makes no sense. It has improved mixed CPU+GPU inference from 5.54 tokens/s (in this comment) to 6.20 t/s under the same conditions. The output has changed, but seems very coherent. I need to experiment a bit more with other GLM models, but this looks very good. Thanks for this upgrade!
uh oh? Have you measured the PPL?
When you split FA into two parts, that modifies the order in which multiply-adds are accumulated, and that changes the accumulated result due to the finite precision of floating point arithmetic. So, not getting the exact same sequence of tokens (with the same random number seed) is expected. But yes, I did indeed verify that PPL is the same within numerical round-off precision.
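As a toy illustration of the point above (not tied to the actual CUDA kernels): floating point addition is not associative, so regrouping the same numbers can change the result.

```python
# Floating-point addition is not associative: regrouping a sum, as
# happens when the FA computation is split into two parts, can change
# the result due to round-off.
a, b, c = 1.0e16, -1.0e16, 1.0

left = (a + b) + c   # cancel first, then add 1.0 -> 1.0
right = a + (b + c)  # 1.0 is absorbed into -1e16 first -> 0.0

print(left, right)   # 1.0 0.0
```

The effect in a real attention kernel is far smaller, of course, but it is enough to occasionally flip which token wins the sampling step.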
I unfortunately seem to be getting worse performance with this PR below 32k context. I haven't tested above this context size (perhaps performance is better past 32k), but otherwise the results aren't promising. Maybe it has something to do with my specific override-tensor configuration, or my hardware?
OS: Arch Linux
Command structure:
Commit: 2a7cc09
PR
The GLM4-MoE models are notorious for a strong decline in inference performance with increasing context length, which is due to their unfortunate GQA ratio of 12.
This PR remedies the situation to some extent. It uses a technique similar to the one in PR #1182 to improve long-context TG performance on CUDA for the GLM4-MoE series of models. But unlike #1182, where there is a single KV head and hence simple views are sufficient to split the FA computation into two parts, here we have 8 KV heads (fewer with split mode `graph`), so one needs to create two contiguous copies of the `Q` tensor to obtain the required splits.

Caveat: the PR does not improve performance when a quantized KV cache is used. Implementing the optimization for a quantized KV cache is a bit more involved, so it is left for a follow-up PR.
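For readers unfamiliar with the split-FA trick, here is a rough NumPy sketch (names and shapes are hypothetical; this mirrors the idea, not the actual CUDA implementation) of how attention computed over two halves of the KV cache can be recombined, exact up to round-off, via the usual flash-attention max/sum-exp rescaling:

```python
# Illustrative sketch of splitting flash attention over two halves of
# the KV cache and recombining the partial results.
import numpy as np

def attn_part(q, k, v):
    # Attention over one KV slice; returns the unnormalized partial
    # output plus the per-row max and sum of exponentials needed to
    # merge slices later.
    s = q @ k.T                          # (n_q, n_kv) scores
    m = s.max(axis=-1, keepdims=True)    # per-row max for stability
    p = np.exp(s - m)
    return p @ v, m, p.sum(axis=-1, keepdims=True)

def attn_split(q, k, v, split):
    # Attention with k/v split into [0:split) and [split:), with both
    # partial results rescaled to a common max before combining.
    o1, m1, l1 = attn_part(q, k[:split], v[:split])
    o2, m2, l2 = attn_part(q, k[split:], v[split:])
    m = np.maximum(m1, m2)
    a1, a2 = np.exp(m1 - m), np.exp(m2 - m)
    return (a1 * o1 + a2 * o2) / (a1 * l1 + a2 * l2)

def attn_ref(q, k, v):
    # Single-pass softmax attention for reference.
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 64))
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 64))

# Agrees with the single-pass reference up to floating-point round-off.
print(np.allclose(attn_ref(q, k, v), attn_split(q, k, v, 128)))
```

The two parts can then run concurrently (e.g. on different GPUs or streams), which is where the long-context gain comes from.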
The following graph shows TG performance as a function of context length for GLM-4.5-AIR-IQ1_KT on a 4x3090 system (but for the 2x3090 data points only 2 GPUs are selected). Main branch and PR coincide up to a certain point because the split is only done above a given threshold that depends on the number of participating GPUs. For 8 GPUs the split only kicks in above 64k tokens, so it is not shown here. For split mode `layer` we gain ~30% at a context of 64k tokens. For 2 GPUs and split mode `graph`, where we can only go up to a context of 32k tokens with this model, the gain is about 13%. For 4 GPUs and split mode `graph` the speedup is ~10% at 64k tokens.