From fa7fda2e4af5e96a44b6623cb6243d5fba027676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AD=99=E5=B0=91=E5=87=A1?= <2501112072@cninfer03.localdomain> Date: Mon, 24 Nov 2025 21:16:20 +0800 Subject: [PATCH] cross entropy loss --- ggml/src/ggml-cann/aclnn_ops.cpp | 80 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 2 + ggml/src/ggml-cann/ggml-cann.cpp | 4 ++ 3 files changed, 86 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index bc33b99d96e..6ea30bb80b2 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3424,3 +3424,83 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ GGML_ABORT("Function is not implemented."); } } + +// Helper: number of elements in a row (width) and number of rows +static inline int64_t ggml_nrows_safe(const ggml_tensor * t) { + // ggml_nrows typically computes product of ne[1..], keep same semantics. + int64_t n = 1; + for (int i = 1; i < GGML_MAX_DIMS; ++i) n *= t->ne[i]; + return n; +} + +void ggml_cann_cross_entropy_loss(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + // dst: scalar result tensor whose src[0]=logits, src[1]=labels + const ggml_tensor * src0 = dst->src[0]; // logits: [ne00 x nrows] + const ggml_tensor * src1 = dst->src[1]; // labels: same shape as logits + + // Basic checks (mirror CPU/CUDA assumptions) + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32) { + // fallback or error: not supported types + GGML_ASSERT(false && "ggml_cann_cross_entropy_loss: only F32 supported"); + return; + } + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int64_t nclasses = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); // helper from ggml + + const int64_t n_elems = nclasses * nrows; + const size_t nbytes = (size_t)n_elems * sizeof(float); + + // Allocate host buffers + std::vector host_logits((size_t)n_elems); + std::vector host_labels((size_t)n_elems); + + // Device -> Host copies (synchronous) + aclError err; + err = aclrtMemcpy(host_logits.data(), nbytes, + src0->data, nbytes, + ACL_MEMCPY_DEVICE_TO_HOST); + GGML_ASSERT(err == ACL_ERROR_NONE); + + err = aclrtMemcpy(host_labels.data(), nbytes, + src1->data, nbytes, + ACL_MEMCPY_DEVICE_TO_HOST); + GGML_ASSERT(err == ACL_ERROR_NONE); + + // Compute cross-entropy loss: mean over rows of sum_i [-label_i * log(softmax_i)] + double loss_sum = 0.0; + size_t offset = 0; + for (int64_t r = 0; r < nrows; ++r) { + // find max for numerical stability + float maxv = -std::numeric_limits::infinity(); + for (int64_t j = 0; j < nclasses; ++j) { + float v = host_logits[offset + j]; + if (v > maxv) maxv = v; + } + // compute sum exp + double sumexp = 0.0; + for (int64_t j = 0; j < nclasses; ++j) { + double val = std::exp((double)host_logits[offset + j] - (double)maxv); + sumexp += val; + } + double logsum = std::log(sumexp) + (double)maxv; + // accumulate cross-entropy for this row + for (int64_t j = 0; j < nclasses; ++j) { + double lp = (double)host_labels[offset + j]; + double logit = (double)host_logits[offset + j]; + loss_sum += - lp * (logit - logsum); + } + offset += (size_t)nclasses; + } + + float loss = (float)(loss_sum / (double)nrows); + + // copy scalar loss back to device (dst->data) + err = aclrtMemcpy(dst->data, sizeof(float), + &loss, sizeof(float), + ACL_MEMCPY_HOST_TO_DEVICE); + GGML_ASSERT(err == ACL_ERROR_NONE); +} \ No newline at end of file diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 5c510cc9932..f648cc9e738 100755 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1176,6 +1176,8 @@ void ggml_cann_op_unary_gated( std::function unary_op, ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_cross_entropy_loss(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary. * diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cb8af42ebf9..93d3a2e3121 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_CROSS_ENTROPY_LOSS: + ggml_cann_cross_entropy_loss(ctx, dst); + break; default: return false; } @@ -2493,6 +2496,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_MEAN: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: + case GGML_OP_CROSS_ENTROPY_LOSS: return true; case GGML_OP_SCALE: float bias;