Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,11 +653,12 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto split_u = u->splits[id];
auto split_g = g->splits[id];
auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d));
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
auto cur = input;
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il_cb);
}
Expand Down Expand Up @@ -1088,6 +1089,7 @@ llm_expert_gating_func_type gating_op,
auto cur = input;
if (ffn_norm) {
auto the_ffn_norm = ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[lctx.model.main_gpu] : ffn_norm;
GGML_ASSERT(the_ffn_norm);
cur = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il);
}
Expand All @@ -1109,17 +1111,18 @@ llm_expert_gating_func_type gating_op,
gating_op, cb, il, graph);
cb(routed_out, "routed_out", il);
ggml_build_forward_expand(graph, routed_out);
//printf("Using non-split llm_build_moe_ffn for layer %d. n_before = %d, n_now = %d\n", il, n_before, graph->n_nodes);

if (up_shexp && gate_shexp && down_shexp) {
if (split_up_shexp) {
//printf("Using split ffn for shared experts in layer %d\n", il);
std::vector<ggml_tensor *> results(split_up_shexp->n_device);
std::vector<ggml_tensor *> results; results.reserve(split_up_shexp->n_device);
GGML_ASSERT(!split_up_b_shexp || split_up_b_shexp->n_device == split_up_shexp->n_device);
GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device);
GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device);
for (int id = 0; id < split_up_shexp->n_device; ++id) {
int il_cb = 1000*id + il;
GGML_ASSERT((split_up_shexp->splits[id] && split_gate_shexp->splits[id] && split_down_shexp->splits[id]) ||
(!split_up_shexp->splits[id] && !split_gate_shexp->splits[id] && !split_down_shexp->splits[id]));
if (!split_up_shexp->splits[id]) continue;
auto the_ffn_norm = ffn_norm ? ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[id] : ffn_norm : nullptr;
auto shared_out = llm_build_ffn(ctx, lctx, the_ffn_norm, input,
split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
Expand All @@ -1130,17 +1133,19 @@ llm_expert_gating_func_type gating_op,
if (shared_out->ne[1] > 32) {
shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
}
results[id] = shared_out;
results.push_back(shared_out);
}
cur = ggml_add(ctx, results[0], results[1]);
if (cur->ne[1] > 32) {
// Force a graph split
GGML_ASSERT(!results.empty());
if (results.size() == 1) {
cur = results.front();
} else {
cur = ggml_add(ctx, results[0], results[1]);
cur->op_params[0] = 0xff;
}
cb(cur, "ffn_shared_combined", il);
for (int id = 2; id < int(results.size()); ++id) {
cur = ggml_add(ctx, cur, results[id]);
cb(cur, "ffn_shared_combined", il);
for (int id = 2; id < int(results.size()); ++id) {
cur = ggml_add(ctx, cur, results[id]);
cb(cur, "ffn_shared_combined", il);
}
}
if (routed_out->ne[1] > 32) {
auto routed_out_f16 = ggml_cast(ctx, routed_out, GGML_TYPE_F16);
Expand All @@ -1150,7 +1155,6 @@ llm_expert_gating_func_type gating_op,
}
cb(cur, "ffn_out", il);
} else {
//printf("Using non-split ffn for shared experts in layer %d\n", il);
auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
up_shexp, up_b_shexp, nullptr,
gate_shexp, gate_b_shexp, nullptr,
Expand All @@ -1170,14 +1174,17 @@ llm_expert_gating_func_type gating_op,
}
GGML_ASSERT(split_up_exps && split_gate_exps && split_down_exps);
GGML_ASSERT(split_up_exps->n_device == split_gate_exps->n_device && split_up_exps->n_device == split_down_exps->n_device);
std::vector<ggml_tensor *> results(split_up_exps->n_device);
std::vector<ggml_tensor *> results; results.reserve(split_up_exps->n_device);
GGML_ASSERT((!split_up_shexp && !split_gate_shexp && !split_down_shexp) ||
( split_up_shexp && split_gate_shexp && split_down_shexp));
auto split_gate_inp = (ggml_split_tensor_t *)gate_inp->extra;
GGML_ASSERT(split_gate_inp && split_gate_inp->n_device == split_up_exps->n_device);
auto split_exp_probs_b = exp_probs_b ? (ggml_split_tensor_t *)exp_probs_b->extra : nullptr;
GGML_ASSERT(!split_exp_probs_b || split_exp_probs_b->n_device == split_up_exps->n_device);
for (int id = 0; id < split_up_exps->n_device; ++id) {
GGML_ASSERT((split_up_exps->splits[id] && split_gate_exps->splits[id] && split_down_exps->splits[id]) ||
(!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
if (!split_up_exps->splits[id]) continue;
int il_cb = 1000*(id + 1) + il;
auto cur = input;
if (ffn_norm) {
Expand Down Expand Up @@ -1220,8 +1227,9 @@ llm_expert_gating_func_type gating_op,
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
cb(cur, "ffn_out_f16", il_cb);
}
results[id] = cur;
results.push_back(cur);
}
GGML_ASSERT(!results.empty());
if (results.size() == 1) return results.front();

auto cur = ggml_add(ctx, results[0], results[1]);
Expand Down Expand Up @@ -1660,10 +1668,15 @@ static ggml_tensor * build_output(llama_context & lctx, ggml_context * ctx, ggml
}
cb(o.back(), "output", id);
}
if (o.size() == 1) cur = o.front();
cur = ggml_concat(ctx, o[0], o[1], 0);
for (int id = 2; id < int(o.size()); ++id) {
cur = ggml_concat(ctx, cur, o[id], 0);
GGML_ASSERT(!o.empty());
if (o.size() == 1) {
cur = o.front();
}
else {
cur = ggml_concat(ctx, o[0], o[1], 0);
for (int id = 2; id < int(o.size()); ++id) {
cur = ggml_concat(ctx, cur, o[id], 0);
}
}
} else {
if (output_norm) {
Expand Down Expand Up @@ -9357,6 +9370,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
ggml_build_forward_expand(gf, cur);
attn.push_back(cur);
}
GGML_ASSERT(!attn.empty());
if (attn.size() == 1) return attn.front();
auto cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il);
Expand All @@ -9365,10 +9379,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cur = ggml_add(ctx0, cur, attn[id]);
cb(cur, "combine_attn", il);
}
// TODO: for more than 2 GPUs, do we need to add another forced graph split?
//if (attn.size() > 2) {
// cur->op_params[0] = 0xff;
//}
return cur;
}
}
Expand Down