@@ -269,16 +269,28 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
         = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
             std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
     {
-        auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-        auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-        auto windowSizeLayers
-            = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
+        // Number of attention layers on this PP rank.
+        auto const numLocalAttnLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // Number of attention layers on all previous PP ranks.
+        auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(
+            ModelConfig::LayerType::kATTENTION, worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // Use global ranks of attention layers to look up entries in maxAttentionWindowVec.
+        auto const startAttnLayerId = numLowerRankAttnLayers;
+        auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+        auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+        std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+        for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+        {
+            // maxAttentionWindowVec may not have been stretched to numLayers entries yet;
+            // if it has not, cycle through the window sizes.
+            auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+            uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+        }
         std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-        for (auto const& [windowSize, managedLayers] : windowSizeLayers)
+        for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
         {
             auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                modelConfig, managedLayers, isCrossAttention, kvFactor);
+                modelConfig, globalLayerIds, isCrossAttention, kvFactor);
             auto const cacheSizeBytesPerToken
                 = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
             cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
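
For readers without the full file context, here is a minimal standalone sketch of the grouping step this diff introduces: bucketing the attention layers owned by one pipeline-parallel rank by their max attention window size, cycling through `maxAttentionWindowVec` when it holds fewer entries than there are layers. The function name and example values below are hypothetical illustrations, not part of the TensorRT-LLM API.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using SizeType32 = std::int32_t;

// Group the attention layers owned by one PP rank by window size, indexing
// maxAttentionWindowVec by global layer id modulo its length.
std::map<SizeType32, std::vector<SizeType32>> groupLocalLayersByWindowSize(
    std::vector<SizeType32> const& maxAttentionWindowVec,
    SizeType32 numLowerRankAttnLayers, // attention layers on all previous PP ranks
    SizeType32 numLocalAttnLayers)     // attention layers on this PP rank
{
    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
    for (SizeType32 layerIdx = numLowerRankAttnLayers;
         layerIdx < numLowerRankAttnLayers + numLocalAttnLayers; layerIdx++)
    {
        // Cycle through the window sizes when the vector is shorter than the layer count.
        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
    }
    return uniqueWindowSizeToLayers;
}

int main()
{
    // 8 global attention layers alternating 1024/4096 windows; rank 1 of 2 owns layers 4..7.
    auto const groups = groupLocalLayersByWindowSize({1024, 4096},
        /*numLowerRankAttnLayers=*/4, /*numLocalAttnLayers=*/4);
    for (auto const& [windowSize, layerIds] : groups)
    {
        std::cout << "window " << windowSize << ": layers";
        for (auto const id : layerIds)
        {
            std::cout << ' ' << id;
        }
        std::cout << '\n';
    }
    return 0; // prints: window 1024: layers 4 6 / window 4096: layers 5 7
}
```

Note the key design point the diff encodes: grouping uses *global* layer ids, so each rank looks up the same per-layer window assignment regardless of how layers are sharded across pipeline stages.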
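The second loop then turns each window group into a bytes-per-token figure. Under the assumption (not verified against the real `calculateCacheSizePerTokenForSingleWindowSize`) that a group's per-token cache size is the sum over its layers of `numKvHeads * sizePerHead * kvFactor` elements, a sketch looks like this; `bytesPerKvElem` stands in for `BufferDataType(modelConfig.getKvDataType()).getSize()`, and all names are hypothetical.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using SizeType32 = std::int32_t;

// Assumed formula: per-token KV-cache elements for one window group is the sum
// over its layers of numKvHeads * sizePerHead * kvFactor (kvFactor = 2 when both
// K and V are cached).
SizeType32 cacheSizePerTokenForWindow(std::vector<SizeType32> const& numKvHeadsPerLayer,
    std::vector<SizeType32> const& globalLayerIds, SizeType32 sizePerHead, SizeType32 kvFactor)
{
    SizeType32 elems = 0;
    for (auto const layerId : globalLayerIds)
    {
        elems += numKvHeadsPerLayer.at(layerId) * sizePerHead * kvFactor;
    }
    return elems;
}

std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow(
    std::map<SizeType32, std::vector<SizeType32>> const& uniqueWindowSizeToLayers,
    std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
    SizeType32 kvFactor, SizeType32 bytesPerKvElem)
{
    std::map<SizeType32, SizeType32> result;
    for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
    {
        result[windowSize]
            = cacheSizePerTokenForWindow(numKvHeadsPerLayer, globalLayerIds, sizePerHead, kvFactor)
            * bytesPerKvElem;
    }
    return result;
}

int main()
{
    // The two groups from the earlier example; 4 KV heads per layer, 128-dim heads,
    // K+V cached (kvFactor = 2), fp16 elements (2 bytes).
    std::map<SizeType32, std::vector<SizeType32>> const groups{{1024, {4, 6}}, {4096, {5, 7}}};
    std::vector<SizeType32> const numKvHeadsPerLayer(8, 4);
    for (auto const& [windowSize, bytes] :
        cacheSizeBytesPerTokenPerWindow(groups, numKvHeadsPerLayer, 128, 2, 2))
    {
        std::cout << "window " << windowSize << ": " << bytes << " bytes/token\n";
    }
    return 0;
}
```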