diff --git a/examples/cpp/example.cpp b/examples/cpp/example.cpp
index 1f0683cf..3ec646b3 100644
--- a/examples/cpp/example.cpp
+++ b/examples/cpp/example.cpp
@@ -218,9 +218,9 @@ int main(int argc, char **argv) {
     args.add<std::string>("dtype", 'd', "weight data type", false, "fp16",
             cmdline::oneof<std::string>("fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"));
     args.add<int>("input_len", 'l', "input token size", false, -1);
-    args.add<int>("output_len", '\0', "max tokens can generate excluded input.", false, 100, cmdline::range(1, 4096));
+    args.add<int>("output_len", '\0', "max tokens can generate excluded input.", false, 100, cmdline::range(1, 8192));
     args.add<int>("num_beams", 'n', "number of beam size.", false, 1, cmdline::range(1, 32));
-    args.add<int>("batch_size", 'b', "batch size.", false, 1, cmdline::range(1, 32));
+    args.add<int>("batch_size", 'b', "batch size.", false, 1, cmdline::range(1, 512));
     args.add<int>("loop", '\0', "number of loop.", false, 10);
     args.add<int>("topK", '\0', "number of highest probability tokens to keep for top-k-filtering.", false, 50);
     args.add<float>("temperature", '\0', "value used to modulate the next token probabilities.", false, 1.0);
diff --git a/include/models.h b/include/models.h
index e66f076a..ff2faec6 100644
--- a/include/models.h
+++ b/include/models.h
@@ -33,6 +33,8 @@ class Model {
             bool doEarlyStopping_ = false, int eosTokenId_ = -1, int padTokenId_ = -1, bool doSample_ = false,
             float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0);
 
+    void config(SearcherConfig &config_);
+
     bool isDone();
 
     std::vector<int32_t> generate();
diff --git a/src/models/models.cpp b/src/models/models.cpp
index 2d4563b4..52f99611 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -19,15 +19,32 @@
 #include <cstring>
 
 #include "INIReader.h"
+#include "baichuan.h"
 #include "chatglm.h"
 #include "chatglm2.h"
 #include "hybrid_model.h"
 #include "llama.h"
-#include "baichuan.h"
 #include "opt_decoder.h"
 #include "searcher.h"
 
 namespace xft {
+enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE };
+
+GenerationMode getGenerationMode(SearcherConfig &config_) {
+    if (config_.numBeams == 1) {
+        if (config_.doSample) {
+            return GenerationMode::SAMPLE;
+        } else {
+            return GenerationMode::GREEDY_SEARCH;
+        }
+    } else if (config_.numBeams > 1) {
+        return GenerationMode::BEAM_SEARCH;
+    } else {
+        printf("numBeams should greater than or equal to 1.\n");
+        exit(-1);
+    }
+}
+
 Model::~Model() {
     exitSlaves();
     if (decoder != nullptr) { delete decoder; }
@@ -84,6 +101,18 @@ void Model::config(int maxLen_, int numBeams_, int numBeamHypsToKeep_, float len
     createSearcher(configuration);
 }
 
+void Model::config(SearcherConfig &config_) {
+    isNewInput = true;
+    if (decoder->getRank() == 0) { configuration = config_; }
+    Messenger &messenger = decoder->getMessenger();
+    messenger.broadcast((int *)&configuration, sizeof(SearcherConfig) / sizeof(int));
+
+    // Slaves get exit flags and exit directly
+    if (decoder->getRank() > 0 && configuration.numBeams == 0) { exit(0); }
+
+    createSearcher(configuration);
+}
+
 bool Model::isDone() {
     if (searcher == nullptr || inputIds.empty()) {
         printf("Please set input and config first.\n");
@@ -112,17 +141,14 @@ std::vector<int32_t> Model::generate() {
 
 void Model::createSearcher(SearcherConfig &config_) {
     if (searcher != nullptr) { delete searcher; }
-    if (config_.numBeams < 1) {
-        printf("numBeams should greater than or equal to 1.\n");
-        exit(-1);
-    } else if (config_.numBeams == 1) {
-        if (config_.doSample) {
-            searcher = new SampleSearch(*decoder, config_);
-        } else {
-            searcher = new GreedySearch(*decoder, config_);
-        }
-    } else {
+
+    GenerationMode genMode = getGenerationMode(config_);
+    if (genMode == GenerationMode::GREEDY_SEARCH) {
+        searcher = new GreedySearch(*decoder, config_);
+    } else if (genMode == GenerationMode::BEAM_SEARCH) {
         searcher = new BeamSearch(*decoder, config_);
+    } else if (genMode == GenerationMode::SAMPLE) {
+        searcher = new SampleSearch(*decoder, config_);
     }
 }