Commit 70c4f1e

Merge pull request PaddlePaddle#13 from bmers/9_19_llama
9 19 llama
2 parents fcddea6 + ad90acb commit 70c4f1e

16 files changed, +364 −139 lines changed

backends/npu/custom_op/llama_decoder_layer_parallel_op.cc

Lines changed: 2 additions & 2 deletions

@@ -186,7 +186,7 @@ std::vector<paddle::Tensor> LlaMaDecoderLayerParallelOp(
           head_dim,
           0,
           0,
-          g_llamadecoderLayerId,
+          0,
           2,
           true,
           g_atbSeqLen.kv_seq_len_param,
@@ -263,7 +263,7 @@ PD_BUILD_OP(llama_decoder_layer_parallel)
              "AttentionMask",
              "Cache_KV",
              "SeqLength"})
-    .Outputs({"Out", "PresentKey", "PresentValue"})
+    .Outputs({"Out", "PresentKV"})
     .Attrs({"rmsNormEps: float",
             "headDim: int",
             "headNum: int"})

backends/npu/custom_op/llama_encoder_layer_parallel_op.cc

Lines changed: 40 additions & 20 deletions

@@ -16,8 +16,12 @@
 #include <hccl/hccl.h>
 #include <hccl/hccl_types.h>
 #include "llama_layer_parallel_op.h"
+#include "llama_layer/llama_fusion_parallel_operation.h"
 #include "llama_layer/llama_encoder_parallel_operation.h"
 #include "paddle/extension.h"
+#include "kernels/funcs/format_utils.h"
+#include "kernels/funcs/npu_funcs.h"
+#include "kernels/funcs/npu_op_runner.h"
 
 std::shared_ptr<PpAtbLlaMaEncoderLayerParallelOp> g_llaMaEncoderLayerParallelOp;
 static int32_t g_llamaEncoderLayerId = 0;
@@ -34,6 +38,10 @@ void PerpareLlaMaEncoderLayerInputs(
     const paddle::Tensor &positionIDs,
     const paddle::Tensor &cos_sin_table,
     const paddle::Tensor &attention_mask,
+    const paddle::Tensor &cache_key_value,
+    const paddle::Tensor &kv_seq_len,
+    const paddle::Tensor &q_seq_len,
+    phi::DenseTensor &layer_id_dense,
     std::vector<const phi::DenseTensor *> &inputs) {
 
   auto hidden_tensor = static_cast<const phi::DenseTensor *>(hidden.impl().get());
@@ -46,6 +54,9 @@ void PerpareLlaMaEncoderLayerInputs(
   auto positionIDs_tensor = static_cast<const phi::DenseTensor *>(positionIDs.impl().get());
   auto cos_sin_table_tensor = static_cast<const phi::DenseTensor *>(cos_sin_table.impl().get());
   auto attention_mask_tensor = static_cast<const phi::DenseTensor *>(attention_mask.impl().get());
+  auto cache_key_value_tensor = static_cast<const phi::DenseTensor *>(cache_key_value.impl().get());
+  auto kv_seq_len_tensor = static_cast<const phi::DenseTensor *>(kv_seq_len.impl().get());
+  auto q_seq_len_tensor = static_cast<const phi::DenseTensor *>(q_seq_len.impl().get());
 
   inputs.push_back(hidden_tensor);
   inputs.push_back(norm_weight_tensor);
@@ -57,6 +68,10 @@ void PerpareLlaMaEncoderLayerInputs(
   inputs.push_back(positionIDs_tensor);
   inputs.push_back(cos_sin_table_tensor);
   inputs.push_back(attention_mask_tensor);
+  inputs.push_back(cache_key_value_tensor);
+  inputs.push_back(kv_seq_len_tensor);
+  inputs.push_back(q_seq_len_tensor);
+  inputs.push_back(&layer_id_dense);
 }
 
 PpAtbLlaMaEncoderLayerParallelOp::PpAtbLlaMaEncoderLayerParallelOp(
@@ -77,6 +92,8 @@ std::vector<paddle::Tensor> LlaMaEncoderLayerParallelOp(
     const paddle::Tensor &positionIDs,
     const paddle::Tensor &cos_sin_table,
     const paddle::Tensor &attention_mask,
+    const paddle::Tensor &cache_key_value,
+    const paddle::Tensor &kv_seq_len,
     float rmsNormEps,
     int headDim,
     int headNum) {
@@ -112,34 +129,30 @@ std::vector<paddle::Tensor> LlaMaEncoderLayerParallelOp(
   layerout_tensor->Resize(phi::make_ddim(hidden.shape()));
   dev_ctx->Alloc(layerout_tensor.get(), data_type);
 
-  std::shared_ptr<phi::DenseTensor> key_tensor =
-      std::make_shared<phi::DenseTensor>();
-  key_tensor->Resize(phi::make_ddim(key_shape));
-  dev_ctx->Alloc(key_tensor.get(), data_type);
-
-  std::shared_ptr<phi::DenseTensor> value_tensor =
-      std::make_shared<phi::DenseTensor>();
-  value_tensor->Resize(phi::make_ddim(value_shape));
-  dev_ctx->Alloc(value_tensor.get(), data_type);
-
   std::vector<const phi::DenseTensor *> outputs;
   outputs.push_back(layerout_tensor.get());
-  outputs.push_back(key_tensor.get());
-  outputs.push_back(value_tensor.get());
 
   if (!g_llaMaEncoderLayerParallelOp) {
     std::cout << "Run In Encoder Parallel layernum: " << layer_num << " head_num: " << head_num << "head_dim: " << head_dim << std::endl;
     g_llaMaEncoderLayerParallelOp.reset(new PpAtbLlaMaEncoderLayerParallelOp("LlaMaEncoderLayerParallelOp", layer_num));
 
     atb::Operation *op = nullptr;
-    LlamaLayerEncoderParallelParam param = {rmsNormEps,
+    LlamaLayerFusionParallelParam param = {rmsNormEps,
                                            head_num,
                                            head_dim,
                                            0,
                                            0,
+                                           0,
+                                           2,
+                                           true,
+                                           {0},
+                                           {0},
                                            comm};
-    CreateLlamaLayerEncoderParallelOperation(param, &op);
+    LlamaLayerFusionParallelOperation(param, &op);
     g_llaMaEncoderLayerParallelOp->operation_.reset(op);
+    std::vector<int32_t> layer_id_vec(1, 0);
+    custom_kernel::TensorFromVector(*dev_ctx, layer_id_vec,
+        *dev_ctx, &(g_llaMaEncoderLayerParallelOp->layerIdTensor));
   }
 
   std::vector<const phi::DenseTensor *> inputs;
@@ -153,6 +166,10 @@ std::vector<paddle::Tensor> LlaMaEncoderLayerParallelOp(
       positionIDs,
       cos_sin_table,
      attention_mask,
+      cache_key_value,
+      kv_seq_len,  // token offset, i.e. kv_seq_len
+      kv_seq_len,  // incremental q_seq_len, always 1
+      g_llaMaEncoderLayerParallelOp->layerIdTensor,
      inputs);
 
   g_llaMaEncoderLayerParallelOp->Execute(stream, inputs, outputs);
@@ -163,8 +180,7 @@ std::vector<paddle::Tensor> LlaMaEncoderLayerParallelOp(
   }
 
   return {paddle::Tensor(layerout_tensor),
-          paddle::Tensor(key_tensor),
-          paddle::Tensor(value_tensor)};
+          paddle::Tensor(cache_key_value)};
 }
 
 std::vector<std::vector<int64_t>> LlaMaEncoderLayerOpInferShape(
@@ -178,8 +194,10 @@ std::vector<std::vector<int64_t>> LlaMaEncoderLayerOpInferShape(
     const std::vector<int64_t> &positionIDs_shape,
     const std::vector<int64_t> &cos_sin_table_shape,
     const std::vector<int64_t> &attention_mask_shape,
+    const std::vector<int64_t> &cacheKV_shape,
+    const std::vector<int64_t> &seq_len_shape,
     float rmsNormEps,
-    int headDim,
+    int headDim,
     int headNum) {
 
   int32_t head_num = headNum; /* TODO: 64 heads total, hard-coded for 8 cards */
@@ -196,7 +214,7 @@ std::vector<std::vector<int64_t>> LlaMaEncoderLayerOpInferShape(
   value_shape.push_back(hidden_shape.at(1));
   value_shape.push_back(head_num);
   value_shape.push_back(head_dim);
-  return {hidden_shape, key_shape, value_shape};
+  return {hidden_shape, cacheKV_shape};
 }
 
 PD_BUILD_OP(llama_encoder_layer_parallel)
@@ -209,8 +227,10 @@ PD_BUILD_OP(llama_encoder_layer_parallel)
              "MlpDownWeight",
              "PositionIDs",
              "CosSinTable",
-             "AttentionMask"})
-    .Outputs({"Out", "PresentKey", "PresentValue"})
+             "AttentionMask",
+             "Cache_KV",
+             "SeqLength"})
+    .Outputs({"Out", "PresentKV"})
     .Attrs({"rmsNormEps: float",
             "headDim: int",
             "headNum: int"})

backends/npu/custom_op/llama_layer/llama_encoder_parallel_operation.cpp

Lines changed: 2 additions & 0 deletions

@@ -103,6 +103,7 @@ atb::Status CreateLlamaLayerEncoderParallelOperation(const LlamaLayerEncoderPara
   CreateLlamaPositionEmbedding1DSplitOperation(positionEmbedding1dSplitQParam, &qPositionEmbeddingNode.operation);
   qPositionEmbeddingNode.inTensorIds = {INTERMIDATE_MIXEDQ, IN_POSITIONIDS, INTERMIDATE_CASTCOS, INTERMIDATE_CASTSIN};
   qPositionEmbeddingNode.outTensorIds = {INTERMIDATE_POSITIONEMBEDQ};
+  qPositionEmbeddingNode.inTensorReshapeFuncs.resize(qPositionEmbeddingNode.inTensorIds.size());
   qPositionEmbeddingNode.inTensorReshapeFuncs.at(2) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
     newShape.dimNum = 4; // dimNum: 4
     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
@@ -123,6 +124,7 @@ atb::Status CreateLlamaLayerEncoderParallelOperation(const LlamaLayerEncoderPara
   CreateLlamaPositionEmbedding1DSplitOperation(positionEmbedding1dSplitKParam, &kPositionEmbeddingNode.operation);
   kPositionEmbeddingNode.inTensorIds = {INTERMIDATE_MIXEDK, IN_POSITIONIDS, INTERMIDATE_CASTCOS, INTERMIDATE_CASTSIN};
   kPositionEmbeddingNode.outTensorIds = {INTERMIDATE_POSITIONEMBEDK};
+  kPositionEmbeddingNode.inTensorReshapeFuncs.resize(kPositionEmbeddingNode.inTensorIds.size());
   kPositionEmbeddingNode.inTensorReshapeFuncs.at(2) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
     newShape.dimNum = 4; // dimNum: 4
     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
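
The two inserted resize() calls are the whole fix in this file: inTensorReshapeFuncs starts out empty, and std::vector::at() range-checks, so assigning through .at(2) without the resize throws std::out_of_range while the graph is being built. A self-contained illustration of why the resize must come first:

#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<std::function<void()>> reshape_funcs;  // empty, like inTensorReshapeFuncs
  try {
    reshape_funcs.at(2) = [] {};  // throws: index 2 is out of range
  } catch (const std::out_of_range &e) {
    std::cout << "out_of_range: " << e.what() << "\n";
  }
  reshape_funcs.resize(4);      // mirrors the fix: size it to inTensorIds.size()
  reshape_funcs.at(2) = [] {};  // now safe
  return 0;
}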

backends/npu/custom_op/llama_layer/llama_fusion_parallel_operation.cpp

Lines changed: 67 additions & 22 deletions

@@ -21,8 +21,8 @@
 
 static const uint64_t IN_TENSOR_COUNT = 14;
 static const uint64_t OUT_TENSOR_COUNT = 1;
-static const uint64_t INTERMEDIATE_TENSOR_COUNT = 16;
-static const uint64_t NODE_COUNT = 12;
+static const uint64_t INTERMEDIATE_TENSOR_COUNT = 17;
+static const uint64_t NODE_COUNT = 13;
 
 atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelParam &param,
                                               atb::Operation **operation)
@@ -36,6 +36,7 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
   size_t nodeId = 0;
   atb::Node &inputNormNode = opGraph.nodes.at(nodeId++);
   atb::Node &mixdQKVLinearNode = opGraph.nodes.at(nodeId++);
+  atb::Node &castInNode = opGraph.nodes.at(nodeId++);
   atb::Node &cosSinSplitNode = opGraph.nodes.at(nodeId++);
   atb::Node &ropeNode = opGraph.nodes.at(nodeId++);
   atb::Node &cacheKVSplitNode = opGraph.nodes.at(nodeId++);
@@ -47,24 +48,45 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
   atb::Node &mlpLinearParallelNode = opGraph.nodes.at(nodeId++);
   atb::Node &mlpResidualAddNode = opGraph.nodes.at(nodeId++);
 
+  // [bs, seq_len, hidden_size]
   atb::infer::RmsNormParam inputNormParam;
   inputNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
   inputNormParam.normParam.epsilon = param.rmsNormEps;
   atb::CreateOperation(inputNormParam, &inputNormNode.operation);
   inputNormNode.inTensorIds = {IN_HIDDENSTATES, IN_NORMWEIGHT};
   inputNormNode.outTensorIds = {INTERMIDATE_INPUTNORMOUT};
+  inputNormNode.inTensorReshapeFuncs.resize(inputNormNode.inTensorIds.size());
+  inputNormNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+    if (oldShape.dimNum == 3) {
+      newShape = oldShape;
+    } else if (oldShape.dimNum == 2) {
+      newShape.dimNum = 3; // incremental (decode) stage
+      newShape.dims[0] = oldShape.dims[0];
+      newShape.dims[1] = 1;
+      newShape.dims[2] = oldShape.dims[1];
+    }
+  };
 
+  // [bs, seq_len, hidden_size] * [3 * hidden_size / card_num, hidden_size] -> [bs, seq_len, hidden_size / card_num]
   MultiLayerLinearParam multiLayerLinearParam;
   multiLayerLinearParam.transpose = param.transpose;
   CreateLlamaMultiLayerLinearOperation(multiLayerLinearParam, &mixdQKVLinearNode.operation);
   mixdQKVLinearNode.inTensorIds = {INTERMIDATE_INPUTNORMOUT, IN_QKVMIXDWEIGHT};
   mixdQKVLinearNode.outTensorIds = {INTERMIDATE_MIXEDQ, INTERMIDATE_MIXEDK, INTERMIDATE_MIXEDV};
 
+  atb::infer::ElewiseParam castParam;
+  castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+  CreateOperation(castParam, &castInNode.operation);
+  castInNode.inTensorIds = {IN_COS_SIN_TABLE};
+  castInNode.outTensorIds = {INTERNAL_CAST_COS_SIN_TABLE};
+
+  // [2, head_dim, 1, seq_len, 1] -> 2 x [1, head_dim, 1, seq_len, 1]
   atb::infer::SplitParam splitParam = {0, 2};
   atb::CreateOperation(splitParam, &cosSinSplitNode.operation);
-  cosSinSplitNode.inTensorIds = {IN_COS_SIN_TABLE};
+  cosSinSplitNode.inTensorIds = {INTERNAL_CAST_COS_SIN_TABLE};
   cosSinSplitNode.outTensorIds = {INTERMIDATE_CASTCOS, INTERMIDATE_CASTSIN};
 
+  // prefill (full) stage:
   llamaPositionEmbedding1DSplitFusionParam positionEmbedding1dFusionParam;
   positionEmbedding1dFusionParam.headNum = param.headNum;
   positionEmbedding1dFusionParam.rotaryCoeff = param.rotaryCoeff;
@@ -96,12 +118,17 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
     newShape.dims[2] = oldShape.dims[3];
     newShape.dims[3] = oldShape.dims[4];
   };
+  ropeNode.inTensorReshapeFuncs.at(5) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+    newShape.dimNum = 1; // dimNum: 1
+    newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+  };
 
+  // [2, 1, head_num / card_num, max_length, head_dim]
   atb::infer::SplitParam splitKVParam = {0, 2};
   atb::CreateOperation(splitKVParam, &cacheKVSplitNode.operation);
   cacheKVSplitNode.inTensorIds = {IN_CACHE_KV};
   cacheKVSplitNode.outTensorIds = {INTERMIDATE_CACHEK, INTERMIDATE_CACHEV};
-
+  // prefill (full) stage: [1, 1, 4, 128] [1, 1, 4, 128] [1, 1, 4, 128] [1, 4, 2048, 128] [1, 4, 2048, 128] [1, 4, 2048, 2048] [1, 1] [1, 1] [1]
   atb::infer::SelfAttentionParam selfAttentionKvCacheParam;
   selfAttentionKvCacheParam.headDim = param.headDim;
   selfAttentionKvCacheParam.headNum = param.headNum;
@@ -120,40 +147,58 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
   selfAttentionKvCacheNode.inTensorReshapeFuncs.resize(selfAttentionKvCacheNode.inTensorIds.size());
   selfAttentionKvCacheNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
     newShape.dimNum = 4; // dimNum: 4
-    newShape.dims[0] = oldShape.dims[0];
+    newShape.dims[0] = 1;
     newShape.dims[1] = oldShape.dims[0];
     newShape.dims[2] = param.headNum;
     newShape.dims[3] = oldShape.dims[1] / param.headNum;
   };
   selfAttentionKvCacheNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
     newShape.dimNum = 4; // dimNum: 4
-    newShape.dims[0] = oldShape.dims[0];
+    newShape.dims[0] = 1;
     newShape.dims[1] = oldShape.dims[0];
     newShape.dims[2] = param.headNum;
     newShape.dims[3] = oldShape.dims[1] / param.headNum;
   };
   selfAttentionKvCacheNode.inTensorReshapeFuncs.at(2) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
     newShape.dimNum = 4; // dimNum: 4
     newShape.dims[0] = oldShape.dims[0];
-    newShape.dims[1] = oldShape.dims[1];
+    newShape.dims[1] = oldShape.dims[1]; // TODO: how to get seq_len
     newShape.dims[2] = param.headNum;
     newShape.dims[3] = oldShape.dims[2] / param.headNum;
   };
   selfAttentionKvCacheNode.inTensorReshapeFuncs.at(3) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-    newShape.dimNum = 4; // dimNum: 4
-    newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
-    newShape.dims[1] = oldShape.dims[2];
-    newShape.dims[2] = oldShape.dims[3];
-    newShape.dims[3] = oldShape.dims[4];
+    // The producer yields [1, max_batch_size, head_num, max_len, head_dim];
+    // the acceleration library expects [layer, max_batch_size, max_len, head_size].
+    // In theory a transpose is required, but reads and writes both stay inside
+    // the acceleration library, so a plain reshape works around it.
+    newShape.dimNum = 4; // dimNum: 4
+    newShape.dims[0] = 1;
+    newShape.dims[1] = oldShape.dims[1];
+    newShape.dims[2] = oldShape.dims[3];
+    newShape.dims[3] = oldShape.dims[2] * oldShape.dims[4];
   };
   selfAttentionKvCacheNode.inTensorReshapeFuncs.at(4) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-    newShape.dimNum = 4; // dimNum: 4
+    // Same workaround as at(3): reshape [1, max_batch_size, head_num, max_len, head_dim]
+    // to the [layer, max_batch_size, max_len, head_size] layout the library expects.
+    newShape.dimNum = 4; // dimNum: 4
+    newShape.dims[0] = 1;
+    newShape.dims[1] = oldShape.dims[1];
+    newShape.dims[2] = oldShape.dims[3];
+    newShape.dims[3] = oldShape.dims[2] * oldShape.dims[4];
+  };
+  selfAttentionKvCacheNode.inTensorReshapeFuncs.at(5) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+    newShape.dimNum = 2; // dimNum: 2
+    newShape.dims[0] = oldShape.dims[2];
+    newShape.dims[1] = oldShape.dims[3];
+  };
+  selfAttentionKvCacheNode.inTensorReshapeFuncs.at(6) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+    newShape.dimNum = 1; // dimNum: 1
+    newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+  };
+  selfAttentionKvCacheNode.inTensorReshapeFuncs.at(7) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+    newShape.dimNum = 1; // dimNum: 1
     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
-    newShape.dims[1] = oldShape.dims[2];
-    newShape.dims[2] = oldShape.dims[3];
-    newShape.dims[3] = oldShape.dims[4];
   };
 
+  // [1, 1, 512] * [512, 4096] -> [1, 1, 4096]
   atb::infer::LinearParallelParam selfOutLinearParallelParam;
   selfOutLinearParallelParam.transWeight = true;
   selfOutLinearParallelParam.rank = param.rank;
@@ -167,18 +212,18 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
   selfOutLinearParallelNode.inTensorIds = {INTERMIDATE_SELFOUT, IN_SELFOUTLINEARWEIGHT};
   selfOutLinearParallelNode.outTensorIds = {INTERMIDATE_SELFLINEAROUT};
 
+  // [bs * seq_len, hidden_size] + [1, 1, 4096]
   atb::infer::ElewiseParam selfResidualAddParam;
   selfResidualAddParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
   atb::CreateOperation(selfResidualAddParam, &selfResidualAddNode.operation);
   selfResidualAddNode.inTensorIds = {IN_HIDDENSTATES, INTERMIDATE_SELFLINEAROUT};
   selfResidualAddNode.outTensorIds = {INTERMIDATE_SELFRESIDUALADDOUT};
   selfResidualAddNode.inTensorReshapeFuncs.resize(selfResidualAddNode.inTensorIds.size());
-  selfResidualAddNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-    newShape.dimNum = 3; // dimNum: 3
-    newShape.dims[0] = oldShape.dims[1];
-    newShape.dims[1] = oldShape.dims[0];
-    newShape.dims[2] = oldShape.dims[2];
-  };
+  // selfResidualAddNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+  //   newShape.dimNum = 2; // dimNum: 2
+  //   newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+  //   newShape.dims[1] = oldShape.dims[2];
+  // };
 
   atb::infer::RmsNormParam selfNormParam;
   selfNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
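
The cache reshapes at(3)/at(4) are the subtle part of this hunk: viewing [1, max_batch_size, head_num, max_len, head_dim] as [1, max_batch_size, max_len, head_num * head_dim] changes which logical element each flat offset addresses, which is why the comment concedes a transpose is theoretically required. A small self-contained check of that claim; all names and sizes here are illustrative.

#include <cassert>
#include <cstdint>

// Flat offset of element (h, s, d) in a row-major [head_num, max_len, head_dim] block.
int64_t OffsetHeadMajor(int64_t h, int64_t s, int64_t d,
                        int64_t max_len, int64_t head_dim) {
  return (h * max_len + s) * head_dim + d;
}

// Flat offset of the same element in the [max_len, head_num * head_dim]
// layout that a true transpose would produce.
int64_t OffsetSeqMajor(int64_t h, int64_t s, int64_t d,
                       int64_t head_num, int64_t head_dim) {
  return (s * head_num + h) * head_dim + d;
}

int main() {
  // With head_num = 4, max_len = 8, head_dim = 2, the same logical element
  // maps to different flat offsets under the two layouts, so the reshape is
  // only safe while the acceleration library is the sole reader and writer
  // of the cache, exactly as the inline comment notes.
  assert(OffsetHeadMajor(1, 3, 0, /*max_len=*/8, /*head_dim=*/2) !=
         OffsetSeqMajor(1, 3, 0, /*head_num=*/4, /*head_dim=*/2));
  return 0;
}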

backends/npu/custom_op/llama_layer/llama_fusion_parallel_operation.h

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ enum LlamaLayerFusionParallelTensorId {
   INTERMIDATE_MIXEDQ,
   INTERMIDATE_MIXEDK,
   INTERMIDATE_MIXEDV,
+  INTERNAL_CAST_COS_SIN_TABLE,
   INTERMIDATE_CASTCOS,
   INTERMIDATE_CASTSIN,
   INTERMIDATE_POSITIONEMBEDQ,
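
ATB indexes the graph's tensors by these enum values, so inserting INTERNAL_CAST_COS_SIN_TABLE shifts every later intermediate id by one; that is what the INTERMEDIATE_TENSOR_COUNT bump from 16 to 17 in the .cpp accounts for. A hedged compile-time guard one could keep next to the enum (not in the original source):

// Illustrative guard: the tensor-id enum is assumed contiguous, so the new
// cast-table id must sit directly between INTERMIDATE_MIXEDV and
// INTERMIDATE_CASTCOS for the graph's tensor indexing to stay consistent.
static_assert(INTERNAL_CAST_COS_SIN_TABLE == INTERMIDATE_MIXEDV + 1 &&
              INTERMIDATE_CASTCOS == INTERNAL_CAST_COS_SIN_TABLE + 1,
              "tensor ids must remain contiguous after inserting the cast id");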
