1919#include < queue>
2020#include < unordered_map>
2121
22+ #include " allocator.h"
2223#include " environment.h"
2324#include " sampling_params.h"
2425
@@ -67,7 +68,7 @@ class SequenceIDManager {
6768// The SequenceMeta is one sequence of batch inputs and includes the generated tokens.
6869class SequenceMeta {
6970public:
70- SequenceMeta (std::vector<int32_t > &_promptTokens)
71+ SequenceMeta (const std::vector<int32_t > &_promptTokens)
7172 : sequenceID(SequenceIDManager::getInstance().createSequenceID())
7273 , inputSeqLen(_promptTokens.size())
7374 , pastSeqLen(0 )
@@ -81,6 +82,16 @@ class SequenceMeta {
8182 , promptTokens(_inputSeqLen, 0 )
8283 , step(0 ) {}
8384
85+ SequenceMeta (int32_t _sequenceID, const std::vector<int32_t > &_promptTokens)
86+ : sequenceID(_sequenceID)
87+ , inputSeqLen(_promptTokens.size())
88+ , pastSeqLen(0 )
89+ , promptTokens(_promptTokens)
90+ , step(0 ) {}
91+
92+ SequenceMeta (int32_t _sequenceID, int32_t _inputSeqLen)
93+ : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0 ), promptTokens(_inputSeqLen, 0 ), step(0 ) {}
94+
8495 ~SequenceMeta () {}
8596
8697 int32_t getSequenceID () const { return sequenceID; }
@@ -175,7 +186,8 @@ class SequenceGroupMeta {
175186 groupID = sequences[0 ].getSequenceID ();
176187 }
177188
178- SequenceGroupMeta (std::vector<int32_t > &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
189+ SequenceGroupMeta (const std::vector<int32_t > &_inputTokens, SamplingMeta &samplingMeta_)
190+ : samplingMeta(samplingMeta_) {
179191 sequences.reserve (samplingMeta.config .numBeams );
180192 for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
181193 sequences.emplace_back (SequenceMeta (_inputTokens));
@@ -191,7 +203,7 @@ class SequenceGroupMeta {
191203 groupID = sequences[0 ].getSequenceID ();
192204 }
193205
194- SequenceGroupMeta (std::vector<int32_t > &_inputTokens) {
206+ SequenceGroupMeta (const std::vector<int32_t > &_inputTokens) {
195207 sequences.reserve (samplingMeta.config .numBeams );
196208 for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
197209 sequences.emplace_back (SequenceMeta (_inputTokens));
@@ -207,6 +219,40 @@ class SequenceGroupMeta {
207219 groupID = sequences[0 ].getSequenceID ();
208220 }
209221
222+ SequenceGroupMeta (int32_t _sequenceID, const std::vector<int32_t > &_inputTokens, SamplingMeta &samplingMeta_)
223+ : samplingMeta(samplingMeta_) {
224+ sequences.reserve (samplingMeta.config .numBeams );
225+ for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
226+ sequences.emplace_back (SequenceMeta (_sequenceID, _inputTokens));
227+ }
228+ groupID = sequences[0 ].getSequenceID ();
229+ }
230+
231+ SequenceGroupMeta (int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_)
232+ : samplingMeta(samplingMeta_) {
233+ sequences.reserve (samplingMeta.config .numBeams );
234+ for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
235+ sequences.emplace_back (SequenceMeta (_sequenceID, _inputSeqLen));
236+ }
237+ groupID = sequences[0 ].getSequenceID ();
238+ }
239+
240+ SequenceGroupMeta (int32_t _sequenceID, const std::vector<int32_t > &_inputTokens) {
241+ sequences.reserve (samplingMeta.config .numBeams );
242+ for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
243+ sequences.emplace_back (SequenceMeta (_sequenceID, _inputTokens));
244+ }
245+ groupID = sequences[0 ].getSequenceID ();
246+ }
247+
248+ SequenceGroupMeta (int32_t _sequenceID, int32_t _inputSeqLen) {
249+ sequences.reserve (samplingMeta.config .numBeams );
250+ for (int i = 0 ; i < samplingMeta.config .numBeams ; ++i) {
251+ sequences.emplace_back (SequenceMeta (_sequenceID, _inputSeqLen));
252+ }
253+ groupID = sequences[0 ].getSequenceID ();
254+ }
255+
210256 int32_t getGroupID () { return groupID; }
211257
212258 int32_t getGroupSize () { return samplingMeta.config .numBeams ; }
@@ -272,6 +318,31 @@ class SequencePool {
272318 return group;
273319 }
274320
321+ SequenceGroupMeta *newGroupMeta (
322+ int32_t sequenceID, std::vector<int32_t > &inputTokens, SamplingMeta &samplingMeta_) {
323+ auto *group = new SequenceGroupMeta (sequenceID, inputTokens, samplingMeta_);
324+ this ->add (group);
325+ return group;
326+ }
327+
328+ SequenceGroupMeta *newGroupMeta (int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) {
329+ auto *group = new SequenceGroupMeta (sequenceID, inputSeqLen, samplingMeta_);
330+ this ->add (group);
331+ return group;
332+ }
333+
334+ SequenceGroupMeta *newGroupMeta (int32_t sequenceID, std::vector<int32_t > &inputTokens) {
335+ auto *group = new SequenceGroupMeta (sequenceID, inputTokens);
336+ this ->add (group);
337+ return group;
338+ }
339+
340+ SequenceGroupMeta *newGroupMeta (int32_t sequenceID, int32_t inputSeqLen) {
341+ auto *group = new SequenceGroupMeta (sequenceID, inputSeqLen);
342+ this ->add (group);
343+ return group;
344+ }
345+
275346 bool add (SequenceGroupMeta *sequenceGroup, bool force = false ) {
276347 int32_t groupID = sequenceGroup->getGroupID ();
277348 bool isSuccess = false ;
0 commit comments