diff --git a/src/models/models.cpp b/src/models/models.cpp index 197973bf..165b510a 100644 --- a/src/models/models.cpp +++ b/src/models/models.cpp @@ -783,13 +783,13 @@ std::tuple Model::forward(bool logits_all) { // Sync and gather all logits float *outBuf = std::get<0>(result); - int works = messenger.getSize(); - int splitSize = vocabSize / works; - std::vector recvCount(works); - std::vector splitSizes(works); - for (int i = 0; i < works; i++) { + int workers = messenger.getSize(); + int splitSize = vocabSize / workers; + std::vector recvCount(workers); + std::vector splitSizes(workers); + for (int i = 0; i < workers; i++) { splitSizes[i] = splitSize; - if (i < vocabSize % works) { splitSizes[i]++; } + if (i < vocabSize % workers) { splitSizes[i]++; } recvCount[i] = splitSizes[i] * totalSeqSize; } // warning: vocabSize * totalSeqSize may exceed the range of int when seq and batch size is large. @@ -799,9 +799,9 @@ std::tuple Model::forward(bool logits_all) { // Reorder int offset = 0; - for (int i = 0; i < works; ++i) { + for (int i = 0; i < workers; ++i) { for (int j = 0; j < totalSeqSize; ++j) { - memcpy(logits.data() + (i * offset + j * vocabSize), + memcpy(logits.data() + (offset + j * vocabSize), logitsRecvBuf.data() + offset * totalSeqSize + j * splitSizes[i], splitSizes[i] * sizeof(float)); }