diff --git a/common/arg.cpp b/common/arg.cpp
index 13020654982..a80f5cc1909 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3329,6 +3329,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        {"--save-logits"},
+        string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
+        [](common_params & params) {
+            params.save_logits = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EVAL_CALLBACK}));
+    add_opt(common_arg(
+        {"--logits-output-dir"}, "PATH",
+        string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.logits_output_dir = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EVAL_CALLBACK}));
 
     // presets
     add_opt(common_arg(
diff --git a/common/common.h b/common/common.h
index 334372073a9..6386f3159ed 100644
--- a/common/common.h
+++ b/common/common.h
@@ -100,6 +100,7 @@ enum llama_example {
     LLAMA_EXAMPLE_DIFFUSION,
     LLAMA_EXAMPLE_FINETUNE,
     LLAMA_EXAMPLE_FIT_PARAMS,
+    LLAMA_EXAMPLE_EVAL_CALLBACK,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -369,6 +370,8 @@ struct common_params {
     std::string lookup_cache_static  = "";     // path of static ngram cache file for lookup decoding  // NOLINT
     std::string lookup_cache_dynamic = "";     // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file          = "";     // file for saving *all* logits                         // NOLINT
+    std::string logits_output_dir    = "data"; // directory for saving logits output files             // NOLINT
+    bool        save_logits          = false;  // whether to save logits to files                      // NOLINT
 
     std::vector<std::string> in_files;   // all input files // NOLINT
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 408338f1afc..0f227c279de 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -5,9 +5,10 @@
 #include "ggml.h"
 
 #include <cstdio>
-#include <cstring>
 #include <string>
 #include <vector>
+#include <fstream>
+#include <filesystem>
 
 /**
  * This the arbitrary data which will be passed to each callback.
@@ -160,6 +161,49 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
     return true;
 }
 
+static void save_logits(llama_context * ctx, const llama_model * model, const common_params & params) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
+
+    // TODO: print tokens and prompt.
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
+
+    const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
+    const int n_logits = llama_vocab_n_tokens(vocab);
+
+    std::filesystem::create_directory(params.logits_output_dir);
+    std::filesystem::path model_path{params.model.path};
+    std::string model_name{model_path.stem().string()};
+    auto base_path = std::filesystem::path{params.logits_output_dir} / ("llamacpp-" + model_name);
+
+    // Save logits to binary file.
+    {
+        std::filesystem::path filepath{base_path.string() + ".bin"};
+        std::ofstream file{filepath, std::ios::binary};
+        if (!file) {
+            LOG_ERR("%s: error: failed to open binary output file\n", __func__);
+            return;
+        }
+        file.write(reinterpret_cast<const char *>(logits), n_logits * sizeof(float));
+        LOG("Logits saved to %s\n", filepath.string().c_str());
+    }
+
+    // Save logits to text file.
+    {
+        std::filesystem::path filepath{base_path.string() + ".txt"};
+        std::ofstream file{filepath};
+        if (!file) {
+            LOG_ERR("%s: error: failed to open text output file\n", __func__);
+            return;
+        }
+        for (int i = 0; i < n_logits; i++) {
+            file << i << ": " << logits[i] << "\n";
+        }
+        LOG("Logits saved to %s\n", filepath.string().c_str());
+    }
+
+}
+
 static bool run(llama_context * ctx, const common_params & params) {
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -186,7 +230,7 @@ int main(int argc, char ** argv) {
 
     common_params params;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EVAL_CALLBACK)) {
         return 1;
     }
 
@@ -224,6 +268,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.save_logits) {
+        save_logits(ctx, model, params);
+    }
+
     LOG("\n");
     llama_perf_context_print(ctx);
 