
Commit 8879ec4

brb-nvchzblych authored and committed
[https://nvbugs/5501557][fix] Fix out-of-bounds vector access for model with multiple layer types (#7636)
Signed-off-by: Balaram Buddharaju <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent ab915fb commit 8879ec4
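
For context, the new helper added below groups only the attention layers owned by the current pipeline-parallel rank by window size, cycling through maxAttentionWindowVec when it has fewer entries than there are layers. The following is a minimal standalone sketch of that cycling lookup; the layer range and window sizes are hypothetical, and only the modulo lookup mirrors the added code.

// Standalone sketch (not part of the commit): the cycling window-size lookup used by
// calculateCacheSizePerTokenForDisagg, with hypothetical layer ids and window sizes.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main()
{
    // The per-layer window vector may be shorter than the layer count (here 2 sizes, 5 attention layers).
    std::vector<int32_t> const maxAttentionWindowVec{128, 256};
    int32_t const startAttnLayerId = 0; // first attention layer on this rank (hypothetical)
    int32_t const endAttnLayerId = 5;   // one past the last attention layer on this rank (hypothetical)

    std::map<int32_t, std::vector<int32_t>> windowToLayers;
    for (int32_t layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; ++layerIdx)
    {
        // Cycling with % keeps the lookup in bounds even when the vector has fewer
        // entries than there are layers; indexing with layerIdx directly would not.
        auto const windowSize
            = maxAttentionWindowVec.at(layerIdx % static_cast<int32_t>(maxAttentionWindowVec.size()));
        windowToLayers[windowSize].push_back(layerIdx);
    }

    for (auto const& [windowSize, layers] : windowToLayers)
    {
        std::cout << "window " << windowSize << " -> " << layers.size() << " layers\n";
    }
    return 0;
}

With these hypothetical values the sketch prints 3 layers for window 128 and 2 for window 256, which is the same split the new unit test checks in its second case.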

File tree

3 files changed: +133 -23 lines changed


cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Lines changed: 35 additions & 23 deletions

@@ -85,6 +85,40 @@ using tensorrt_llm::batch_manager::CacheTransceiverFactory;
 namespace tensorrt_llm::batch_manager
 {
 
+std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+    ModelConfig const& modelConfig, WorldConfig const& worldConfig,
+    std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
+{
+    // These are the number of attention layers on this PP rank.
+    auto const numLocalAttnLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // These are the number of attention layers on all previous PP ranks.
+    auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+        worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // Use global ranks of attention layers to lookup from maxAttentionWindowVec.
+    auto const startAttnLayerId = numLowerRankAttnLayers;
+    auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+    for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+    {
+        // maxAttentionWindowVec may or may not be stretched to the length of numLayers yet.
+        // If not stretched yet, we cycle through the window sizes.
+        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+    }
+    std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
+    for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
+    {
+        auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
+            modelConfig, globalLayerIds, isCrossAttention, kvFactor);
+        auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
+        cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
+    }
+
+    return cacheSizeBytesPerTokenPerWindow;
+};
+
 bool TrtGptModelInflightBatching::executorConfigIsValid(
     ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
 {

@@ -266,32 +300,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
     }
     if (mModelConfig.isTransformerBased() && modelConfig.isKVCacheEnabled())
     {
-
-        auto calculateCacheSizePerToken
-            = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
-                  std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
-        {
-            auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-                worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-            auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-            auto windowSizeLayers
-                = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
-            std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-            for (auto const& [windowSize, managedLayers] : windowSizeLayers)
-            {
-                auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                    modelConfig, managedLayers, isCrossAttention, kvFactor);
-                auto const cacheSizeBytesPerToken
-                    = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
-                cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
-            }
-
-            return cacheSizeBytesPerTokenPerWindow;
-        };
         auto cacheTransceiverConfig
             = executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig());
 
-        auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerToken(
+        auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerTokenForDisagg(
            mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2);
         auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(
             cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
Lines changed: 13 additions & 0 deletions

@@ -152,6 +152,19 @@ class TrtGptModelInflightBatching : public TrtGptModel
 
     ~TrtGptModelInflightBatching() override;
 
+    /// @brief Calculate the cache size per token for the disaggregated serving.
+    /// @param modelConfig Model configuration.
+    /// @param worldConfig World configuration.
+    /// @param maxAttentionWindowVec Maximum attention window vector. (may have fewer elements than numLayers, in which
+    /// case it cycles)
+    /// @param isCrossAttention Whether the attention is cross attention.
+    /// @param kvFactor KV factor.
+    /// @return Cache size per token for the disaggregated layers. Note that window size is not included in the result
+    /// here.
+    [[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerTokenForDisagg(
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);
+
     void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;
 
     /// @brief Terminate request in the next forwardSync call that includes the request.
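
As a usage note for the declaration above, here is a minimal call sketch that closely mirrors the unit-test helper added in executorTestSmall.cpp further down. It assumes that test file's headers and namespace aliases, and the configuration values are illustrative only, not a recommended setup.

// Sketch only (not from the commit): calling the new static helper directly.
// Assumes the includes and the tensorrt_llm::runtime / batch_manager namespaces
// available in the unit test below.
std::map<runtime::SizeType32, runtime::SizeType32> cacheBytesPerTokenExample()
{
    auto modelConfig = runtime::ModelConfig(/*vocabSize=*/32, /*nbLayers=*/5, /*nbAttentionLayers=*/3,
        /*nbRnnLayers=*/2, /*nbHeads=*/8, /*hiddenSize=*/512, nvinfer1::DataType::kFLOAT);
    modelConfig.useGptAttentionPlugin(true);
    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
    auto const worldConfig = runtime::WorldConfig();

    // Window vector shorter than the layer count: its entries are cycled over the attention layers,
    // and the returned map is keyed by unique window size with bytes-per-token values.
    std::vector<runtime::SizeType32> const maxAttentionWindowVec{128, 256};
    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
        modelConfig, worldConfig, maxAttentionWindowVec, /*isCrossAttention=*/false, /*kvFactor=*/2);
}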

cpp/tests/unit_tests/executor/executorTestSmall.cpp
Lines changed: 85 additions & 0 deletions

@@ -11,6 +11,7 @@
 
 #include <random>
 #include <tuple>
+#include <unordered_map>
 
 namespace tensorrt_llm::testing
 {

@@ -201,4 +202,88 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
         return nameStringStream.str();
     });
 
+// Helper function to test calculateCacheSizePerToken with given parameters.
+std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
+    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
+    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
+    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
+    bool isCrossAttention = false)
+{
+    // Create minimal ModelConfig for testing.
+    auto modelConfig = runtime::ModelConfig(
+        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
+    modelConfig.useGptAttentionPlugin(true);
+    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
+    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
+
+    auto const worldConfig = runtime::WorldConfig();
+
+    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+        modelConfig, worldConfig, maxAttentionWindowVec, isCrossAttention, kvFactor);
+}
+
+// Test for TrtGptModelInflightBatching::calculateCacheSizePerToken function with different layer types.
+TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
+{
+    // Common parameters.
+    constexpr runtime::SizeType32 nbLayers = 5;
+    constexpr runtime::SizeType32 hiddenSize = 512;
+    constexpr runtime::SizeType32 kvFactor = 2;
+    constexpr runtime::SizeType32 vocabSize = 32;
+    constexpr runtime::SizeType32 nbHeads = 8;
+    // Test case 1: Single attention window size - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 2: Multiple attention window sizes - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 3;
+        auto const nbAttentionLayersIn256Window = 2;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 3: Single attention window size - attention and rnn layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 4: Multiple attention window sizes - attention and rnn layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 2;
+        auto const nbAttentionLayersIn256Window = 1;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+}
+
 } // namespace tensorrt_llm::testing
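
The EXPECT_EQ values in the test above all come from one per-layer, per-token term. The compile-time sketch below reproduces that arithmetic; the reading of kvFactor as the K and V entries, each of hiddenSize float elements per token, is our interpretation and not stated in the test itself.

// Standalone sketch: the arithmetic behind the expected values in the test above.
#include <cstdint>

constexpr int32_t kKvFactor = 2;      // K and V (our reading of kvFactor)
constexpr int32_t kHiddenSize = 512;  // hiddenSize used by the test config
constexpr int32_t kBytesPerFloat = 4; // nvinfer1::DataType::kFLOAT
constexpr int32_t kBytesPerTokenPerLayer = kKvFactor * kHiddenSize * kBytesPerFloat; // 4096

// Test case 2: 3 of the 5 attention layers land in the 128 window, 2 in the 256 window.
static_assert(3 * kBytesPerTokenPerLayer == 12288, "expected result.at(128)");
static_assert(2 * kBytesPerTokenPerLayer == 8192, "expected result.at(256)");

int main()
{
    return 0;
}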
