Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions examples/debug/debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,21 @@ struct callback_data {
}
};

static bool has_pooling(llama_context * ctx) {
switch (llama_pooling_type(ctx)) {
case LLAMA_POOLING_TYPE_NONE:
case LLAMA_POOLING_TYPE_UNSPECIFIED:
return false;
default:
return true;
}
}

struct output_data {
float * data_ptr = nullptr;
int data_size = 0;
std::string type_suffix;
std::vector<float> storage;
std::vector<float> embd_norm;
std::string prompt;
std::vector<llama_token> tokens;

Expand All @@ -73,24 +83,32 @@ struct output_data {
prompt = params.prompt;

if (params.embedding) {
const int n_embd = llama_model_n_embd_out(model);
const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
const int n_embd_count = pooling_enabled ? 1 : tokens.size();
const int n_embeddings = n_embd * n_embd_count;

float * embeddings;
if (pooling_enabled) {
embeddings = llama_get_embeddings_seq(ctx, 0);
storage.resize(n_embeddings);
common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
embeddings = storage.data();
} else {
embeddings = llama_get_embeddings(ctx);
const int n_embd = llama_model_n_embd_out(model);
const bool pooling = has_pooling(ctx);
const int n_embd_count = pooling ? 1 : tokens.size();
const int n_floats = n_embd * n_embd_count;

float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx);
if (embd_raw == nullptr) {
throw std::runtime_error("failed to get embeddings from the model");
}

data_ptr = embeddings;
data_size = n_embeddings;
LOG_DBG("pooling_enabled: %s\n", pooling ? "true" : "false");
LOG_DBG("n_embd: %d\n", n_embd);
LOG_DBG("n_floats: %d\n", n_floats);
LOG_DBG("n_embd_count: %d\n", n_embd_count);

data_ptr = embd_raw;
data_size = n_floats;
type_suffix = "-embeddings";

if (params.embd_normalize >= 0) {
embd_norm.resize(n_floats);
for (int i = 0; i < n_embd_count; i++) {
common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize);
}
data_ptr = embd_norm.data();
}
} else {
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
Expand Down
Loading