Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3424,3 +3424,83 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
GGML_ABORT("Function is not implemented.");
}
}

// Helper: number of rows in a tensor, i.e. the product of every dimension
// except ne[0] (same semantics as ggml_nrows()).
// NOTE(review): currently unused — ggml_cann_cross_entropy_loss calls
// ggml_nrows() directly. [[maybe_unused]] keeps -Wunused-function quiet
// until this local fallback is either used or removed.
[[maybe_unused]] static inline int64_t ggml_nrows_safe(const ggml_tensor * t) {
    int64_t n = 1;
    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
        n *= t->ne[i];
    }
    return n;
}

/**
 * @brief Computes the cross-entropy loss on the host as a fallback.
 *
 * dst is the scalar result tensor; dst->src[0] holds the logits
 * (ne[0] = nclasses per row) and dst->src[1] the labels with the same shape.
 * The inputs are copied device -> host, the loss is computed as the mean over
 * rows of sum_i [-label_i * log(softmax_i)] (mirroring the CPU backend), and
 * the scalar is copied back to dst->data on the device.
 *
 * @param ctx CANN backend context (provides the stream the inputs were
 *            produced on).
 * @param dst Destination tensor; only contiguous F32 inputs are supported.
 */
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0]; // logits: [nclasses x nrows]
    const ggml_tensor * src1 = dst->src[1]; // labels: same shape as logits

    // Basic checks (mirror CPU/CUDA assumptions).
    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
                "ggml_cann_cross_entropy_loss: only F32 supported");
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t nclasses = src0->ne[0];
    const int64_t nrows    = ggml_nrows(src0);

    const int64_t n_elems = nclasses * nrows;
    const size_t  nbytes  = (size_t) n_elems * sizeof(float);

    // Host staging buffers for the full inputs.
    std::vector<float> host_logits((size_t) n_elems);
    std::vector<float> host_labels((size_t) n_elems);

    // The producers of src0/src1 may still be queued on the backend stream:
    // synchronize before the blocking D2H copies below read their buffers,
    // otherwise we could copy stale or partially-written data.
    aclError err = aclrtSynchronizeStream(ctx.stream());
    GGML_ASSERT(err == ACL_ERROR_NONE);

    // Device -> host copies (synchronous).
    err = aclrtMemcpy(host_logits.data(), nbytes,
                      src0->data, nbytes,
                      ACL_MEMCPY_DEVICE_TO_HOST);
    GGML_ASSERT(err == ACL_ERROR_NONE);

    err = aclrtMemcpy(host_labels.data(), nbytes,
                      src1->data, nbytes,
                      ACL_MEMCPY_DEVICE_TO_HOST);
    GGML_ASSERT(err == ACL_ERROR_NONE);

    // Cross-entropy: mean over rows of sum_i [-label_i * log(softmax_i)],
    // accumulated in double for accuracy.
    double loss_sum = 0.0;
    for (int64_t r = 0; r < nrows; ++r) {
        const float * logits = host_logits.data() + (size_t) r * (size_t) nclasses;
        const float * labels = host_labels.data() + (size_t) r * (size_t) nclasses;

        // Row max for numerical stability (log-sum-exp trick).
        float maxv = -std::numeric_limits<float>::infinity();
        for (int64_t j = 0; j < nclasses; ++j) {
            maxv = std::max(maxv, logits[j]);
        }

        double sumexp = 0.0;
        for (int64_t j = 0; j < nclasses; ++j) {
            sumexp += std::exp((double) logits[j] - (double) maxv);
        }
        const double logsum = std::log(sumexp) + (double) maxv;

        // log(softmax_j) == logit_j - logsumexp; accumulate -label_j * log(softmax_j).
        for (int64_t j = 0; j < nclasses; ++j) {
            loss_sum -= (double) labels[j] * ((double) logits[j] - logsum);
        }
    }

    const float loss = (float) (loss_sum / (double) nrows);

    // Copy the scalar loss back to the device result buffer.
    err = aclrtMemcpy(dst->data, sizeof(float),
                      &loss, sizeof(float),
                      ACL_MEMCPY_HOST_TO_DEVICE);
    GGML_ASSERT(err == ACL_ERROR_NONE);
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,8 @@ void ggml_cann_op_unary_gated(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Computes the cross-entropy loss for the CANN backend.
 *
 * Reads the logits from dst->src[0] and the labels from dst->src[1]
 * (both F32, same shape) and writes the scalar loss into dst->data.
 *
 * @param ctx The CANN backend context used for the device memory transfers.
 * @param dst The destination tensor; its src[] entries hold the inputs.
 */
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
*
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cann_cross_entropy_loss(ctx, dst);
break;
default:
return false;
}
Expand Down Expand Up @@ -2493,6 +2496,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_CROSS_ENTROPY_LOSS:
return true;
case GGML_OP_SCALE:
float bias;
Expand Down