Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3329,6 +3329,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"--save-logits"},
string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
[](common_params & params) {
params.save_logits = true;
}
).set_examples({LLAMA_EXAMPLE_EVAL_CALLBACK}));
add_opt(common_arg(
{"--logits-output-dir"}, "PATH",
string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
[](common_params & params, const std::string & value) {
params.logits_output_dir = value;
}
).set_examples({LLAMA_EXAMPLE_EVAL_CALLBACK}));

// presets
add_opt(common_arg(
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ enum llama_example {
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_FINETUNE,
LLAMA_EXAMPLE_FIT_PARAMS,
LLAMA_EXAMPLE_EVAL_CALLBACK,

LLAMA_EXAMPLE_COUNT,
};
Expand Down Expand Up @@ -369,6 +370,8 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
bool save_logits = false; // whether to save logits to files // NOLINT

std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
Expand Down
52 changes: 50 additions & 2 deletions examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
#include "ggml.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <filesystem>
#include <fstream>

/**
* This the arbitrary data which will be passed to each callback.
Expand Down Expand Up @@ -160,6 +161,49 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
return true;
}

/**
 * Save the logits of the final prompt token to files for offline verification.
 *
 * Writes two files into params.logits_output_dir, named after the model file
 * stem: "llamacpp-<model>.bin" (raw float array, n_vocab entries) and
 * "llamacpp-<model>.txt" (one "index: value" line per vocab entry).
 *
 * Assumes the prompt in params has already been evaluated on ctx, so the
 * logits for its last token are available via llama_get_logits_ith.
 */
static void save_logits(llama_context * ctx, const llama_model * model, const common_params & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const bool add_bos = llama_vocab_get_add_bos(vocab);

    // Re-tokenize the prompt to find the index of the last evaluated token.
    // TODO: print tokens and prompt.
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
    if (tokens.empty()) {
        // guard: tokens.size() - 1 would underflow below
        LOG_ERR("%s: error: prompt produced no tokens, nothing to save\n", __func__);
        return;
    }

    const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
    if (logits == nullptr) {
        LOG_ERR("%s: error: failed to get logits for the last token\n", __func__);
        return;
    }
    const int n_logits = llama_vocab_n_tokens(vocab);

    // create_directories also succeeds when the directory already exists;
    // use the error_code overload so a failure is reported instead of throwing.
    std::error_code ec;
    std::filesystem::create_directories(params.logits_output_dir, ec);
    if (ec) {
        LOG_ERR("%s: error: failed to create output directory '%s': %s\n",
                __func__, params.logits_output_dir.c_str(), ec.message().c_str());
        return;
    }
    std::filesystem::path model_path{params.model.path};
    std::string model_name{model_path.stem().string()};
    auto base_path = std::filesystem::path{params.logits_output_dir} / ("llamacpp-" + model_name);

    // Save logits to binary file.
    {
        std::filesystem::path filepath{base_path.string() + ".bin"};
        std::ofstream file{filepath, std::ios::binary};
        if (!file) {
            LOG_ERR("%s: error: failed to open binary output file\n", __func__);
            return;
        }
        file.write(reinterpret_cast<const char *>(logits), n_logits * sizeof(float));
        if (!file) {
            LOG_ERR("%s: error: failed to write binary output file\n", __func__);
            return;
        }
        // path::c_str() is wchar_t* on Windows; go through string() for %s
        LOG("Logits saved to %s\n", filepath.string().c_str());
    }

    // Save logits to text file.
    {
        std::filesystem::path filepath{base_path.string() + ".txt"};
        std::ofstream file{filepath};
        if (!file) {
            LOG_ERR("%s: error: failed to open text output file\n", __func__);
            return;
        }
        for (int i = 0; i < n_logits; i++) {
            file << i << ": " << logits[i] << "\n";
        }
        if (!file) {
            LOG_ERR("%s: error: failed to write text output file\n", __func__);
            return;
        }
        LOG("Logits saved to %s\n", filepath.string().c_str());
    }
}

static bool run(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
Expand All @@ -186,7 +230,7 @@ int main(int argc, char ** argv) {

common_params params;

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EVAL_CALLBACK)) {
return 1;
}

Expand Down Expand Up @@ -224,6 +268,10 @@ int main(int argc, char ** argv) {
return 1;
}

if (params.save_logits) {
save_logits(ctx, model, params);
}

LOG("\n");
llama_perf_context_print(ctx);

Expand Down
Loading