1818#include " INIReader.h"
1919#include " chatglm2.h"
2020
21- template <typename WeiT, typename NormT >
22- ChatGLM2<WeiT, NormT >::ChatGLM2(const std::string &modelPath, const std::string &modelType)
23- : CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT , float , float , float , true >,
24- ChatGLM2MLP<WeiT, float , float , float , NormT , true >>(modelPath, modelType) {
21+ template <typename WeiT>
22+ ChatGLM2<WeiT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
23+ : CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm , float , float , float , true >,
24+ ChatGLM2MLP<WeiT, float , float , float , RmsNorm , true >>(modelPath, modelType) {
2525 this ->positionIds = nullptr ;
2626 this ->posBufSize = 0 ;
2727
@@ -36,15 +36,15 @@ ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string
3636 setFinalLnWeight (modelPath);
3737}
3838
// Destructor: releases `embedding` with delete and the `positionIds`
// buffer with free(). In this change the class loses its NormT template
// parameter, so the definition is now templated on WeiT only.
39- template <typename WeiT, typename NormT >
40- ChatGLM2<WeiT, NormT >::~ChatGLM2 () {
39+ template <typename WeiT>
40+ ChatGLM2<WeiT>::~ChatGLM2 () {
4141 delete embedding;
4242
// positionIds is initialised to nullptr in the constructor and released
// with free() here, so it is presumably obtained from a malloc-family call
// elsewhere (allocation site not visible in this view -- confirm).
// free() accepts a null pointer, so the guard is redundant but harmless.
4343 if (positionIds) { free (positionIds); }
4444}
4545
46- template <typename WeiT, typename NormT >
47- void ChatGLM2<WeiT, NormT >::setEmbeddingWeights(const std::string &modelPath) {
46+ template <typename WeiT>
47+ void ChatGLM2<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
4848 int vocabSize = embedding->getVocabSize ();
4949 int hiddenSize = embedding->getHiddenSize ();
5050
@@ -57,8 +57,8 @@ void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
5757 free (tokenEmb);
5858}
5959
60- template <typename WeiT, typename NormT >
61- void ChatGLM2<WeiT, NormT >::setFinalLnWeight(const std::string &modelPath) {
60+ template <typename WeiT>
61+ void ChatGLM2<WeiT>::setFinalLnWeight(const std::string &modelPath) {
6262 int hiddenSize = embedding->getHiddenSize ();
6363
6464 float *gamma = (float *)malloc (hiddenSize * sizeof (float ));
@@ -85,8 +85,8 @@ void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
8585// attention_mask = (attention_mask < 0.5).bool()
8686//
8787// return attention_mask
88- template <typename WeiT, typename NormT >
89- void ChatGLM2<WeiT, NormT >::prepareAttnMask(int *ids, int step) {
88+ template <typename WeiT>
89+ void ChatGLM2<WeiT>::prepareAttnMask(int *ids, int step) {
9090 DecoderContext *ctx = this ->getContext ();
9191 int seqLen = ctx->inputSeqLen ;
9292 int sizeRequired = ctx->batchSize * seqLen * seqLen;
@@ -127,13 +127,13 @@ void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
127127 }
128128}
129129
// Thin wrapper that passes ids, output, batchSize and seqLen unchanged to
// embedding->forward(). No work is done here beyond that delegation.
// NOTE(review): presumably `ids` holds batchSize * seqLen token ids and
// `output` receives the corresponding embedding vectors -- confirm against
// the embedding class, which is not visible in this view.
130- template <typename WeiT, typename NormT >
131- void ChatGLM2<WeiT, NormT >::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
130+ template <typename WeiT>
131+ void ChatGLM2<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
132132 embedding->forward (ids, output, batchSize, seqLen);
133133}
134134
// Applies the model's final normalisation by passing input, output and rows
// unchanged to finalLN.forward().
// NOTE(review): `rows` is presumably the number of hidden-state rows to
// normalise, and finalLN's weights are presumably loaded by
// setFinalLnWeight() -- neither is verifiable from this view; confirm.
135- template <typename WeiT, typename NormT >
136- void ChatGLM2<WeiT, NormT >::lastLayerNormForward(float *input, float *output, int rows) {
135+ template <typename WeiT>
136+ void ChatGLM2<WeiT>::lastLayerNormForward(float *input, float *output, int rows) {
137137 finalLN.forward (input, output, rows);
138138}
139139
@@ -147,8 +147,8 @@ void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, in
147147// batch_size, seq_length = input_ids.shape
148148// position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
149149// return position_ids
150- template <typename WeiT, typename NormT >
151- int *ChatGLM2<WeiT, NormT >::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
150+ template <typename WeiT>
151+ int *ChatGLM2<WeiT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
152152 // Prepare buffer
153153 int sizeNeeded = (batchSize * seqLen + 63 ) / 64 * 64 ; // position_ids + block_position_ids
154154 if (posBufSize < sizeNeeded) {
0 commit comments