static const uint64_t IN_TENSOR_COUNT = 14;
static const uint64_t OUT_TENSOR_COUNT = 1;
- static const uint64_t INTERMEDIATE_TENSOR_COUNT = 16;
- static const uint64_t NODE_COUNT = 12;
+ static const uint64_t INTERMEDIATE_TENSOR_COUNT = 17;
+ static const uint64_t NODE_COUNT = 13;

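The intermediate-tensor count goes from 16 to 17 and the node count from 12 to 13 because of the new cast node and its output tensor used below. A minimal sketch of the tensor-id enum this implies, assuming the ids are declared alongside these constants; apart from INTERNAL_CAST_COS_SIN_TABLE and INTERMIDATE_INPUTNORMOUT, the surrounding ids are illustrative:

enum LlamaLayerFusionParallelTensorId {
    IN_HIDDENSTATES = 0,            // first of the 14 input tensors
    // ... remaining input ids ...
    INTERMIDATE_INPUTNORMOUT,       // first of the 17 intermediate tensors
    INTERNAL_CAST_COS_SIN_TABLE,    // new: cast cos/sin table feeding cosSinSplitNode
    // ... remaining intermediate ids and the single output id ...
};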
atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelParam &param,
                                              atb::Operation **operation)
@@ -36,6 +36,7 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
size_t nodeId = 0;
atb::Node &inputNormNode = opGraph.nodes.at(nodeId++);
atb::Node &mixdQKVLinearNode = opGraph.nodes.at(nodeId++);
+ atb::Node &castInNode = opGraph.nodes.at(nodeId++);
atb::Node &cosSinSplitNode = opGraph.nodes.at(nodeId++);
atb::Node &ropeNode = opGraph.nodes.at(nodeId++);
atb::Node &cacheKVSplitNode = opGraph.nodes.at(nodeId++);
@@ -47,24 +48,45 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
atb::Node &mlpLinearParallelNode = opGraph.nodes.at(nodeId++);
atb::Node &mlpResidualAddNode = opGraph.nodes.at(nodeId++);

+ // [bs, seq_len, hidden_size]
atb::infer::RmsNormParam inputNormParam;
inputNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
inputNormParam.normParam.epsilon = param.rmsNormEps;
atb::CreateOperation(inputNormParam, &inputNormNode.operation);
inputNormNode.inTensorIds = {IN_HIDDENSTATES, IN_NORMWEIGHT};
inputNormNode.outTensorIds = {INTERMIDATE_INPUTNORMOUT};
+ inputNormNode.inTensorReshapeFuncs.resize(inputNormNode.inTensorIds.size());
+ inputNormNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+     if (oldShape.dimNum == 3) {
+         newShape = oldShape;
+     } else if (oldShape.dimNum == 2) {
+         newShape.dimNum = 3; // incremental (decode) phase
+         newShape.dims[0] = oldShape.dims[0];
+         newShape.dims[1] = 1;
+         newShape.dims[2] = oldShape.dims[1];
+     }
+ };

+ // [bs, seq_len, hidden_size] * [3 * hidden_size / card_num, hidden_size] -> [bs, seq_len, hidden_size / card_num]
MultiLayerLinearParam multiLayerLinearParam;
multiLayerLinearParam.transpose = param.transpose;
CreateLlamaMultiLayerLinearOperation(multiLayerLinearParam, &mixdQKVLinearNode.operation);
mixdQKVLinearNode.inTensorIds = {INTERMIDATE_INPUTNORMOUT, IN_QKVMIXDWEIGHT};
mixdQKVLinearNode.outTensorIds = {INTERMIDATE_MIXEDQ, INTERMIDATE_MIXEDK, INTERMIDATE_MIXEDV};

+ atb::infer::ElewiseParam castParam;
+ castParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_CAST;
+ CreateOperation(castParam, &castInNode.operation);
+ castInNode.inTensorIds = {IN_COS_SIN_TABLE};
+ castInNode.outTensorIds = {INTERNAL_CAST_COS_SIN_TABLE};
+
+ // [2, head_dim, 1, seq_len, 1] -> 2 x [1, head_dim, 1, seq_len, 1]
atb::infer::SplitParam splitParam = {0, 2};
atb::CreateOperation(splitParam, &cosSinSplitNode.operation);
- cosSinSplitNode.inTensorIds = {IN_COS_SIN_TABLE};
+ cosSinSplitNode.inTensorIds = {INTERNAL_CAST_COS_SIN_TABLE};
cosSinSplitNode.outTensorIds = {INTERMIDATE_CASTCOS, INTERMIDATE_CASTSIN};

+ // prefill (full) phase:
llamaPositionEmbedding1DSplitFusionParam positionEmbedding1dFusionParam;
positionEmbedding1dFusionParam.headNum = param.headNum;
positionEmbedding1dFusionParam.rotaryCoeff = param.rotaryCoeff;
@@ -96,12 +118,17 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
    newShape.dims[2] = oldShape.dims[3];
    newShape.dims[3] = oldShape.dims[4];
};
+ ropeNode.inTensorReshapeFuncs.at(5) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+     newShape.dimNum = 1; // dimNum: 1
+     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+ };

+ // [2, 1, head_num / card_num, max_length, head_dim]
atb::infer::SplitParam splitKVParam = {0, 2};
atb::CreateOperation(splitKVParam, &cacheKVSplitNode.operation);
cacheKVSplitNode.inTensorIds = {IN_CACHE_KV};
cacheKVSplitNode.outTensorIds = {INTERMIDATE_CACHEK, INTERMIDATE_CACHEV};
-
+ // prefill example shapes: [1, 1, 4, 128] [1, 1, 4, 128] [1, 1, 4, 128] [1, 4, 2048, 128] [1, 4, 2048, 128] [1, 4, 2048, 2048] [1, 1] [1, 1] [1]
atb::infer::SelfAttentionParam selfAttentionKvCacheParam;
selfAttentionKvCacheParam.headDim = param.headDim;
selfAttentionKvCacheParam.headNum = param.headNum;
@@ -120,40 +147,58 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
selfAttentionKvCacheNode.inTensorReshapeFuncs.resize(selfAttentionKvCacheNode.inTensorIds.size());
selfAttentionKvCacheNode.inTensorReshapeFuncs.at(0) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
    newShape.dimNum = 4; // dimNum: 4
-     newShape.dims[0] = oldShape.dims[0];
+     newShape.dims[0] = 1;
    newShape.dims[1] = oldShape.dims[0];
    newShape.dims[2] = param.headNum;
    newShape.dims[3] = oldShape.dims[1] / param.headNum;
};
selfAttentionKvCacheNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
    newShape.dimNum = 4; // dimNum: 4
-     newShape.dims[0] = oldShape.dims[0];
+     newShape.dims[0] = 1;
    newShape.dims[1] = oldShape.dims[0];
    newShape.dims[2] = param.headNum;
    newShape.dims[3] = oldShape.dims[1] / param.headNum;
};
selfAttentionKvCacheNode.inTensorReshapeFuncs.at(2) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
    newShape.dimNum = 4; // dimNum: 4
    newShape.dims[0] = oldShape.dims[0];
-     newShape.dims[1] = oldShape.dims[1];
+     newShape.dims[1] = oldShape.dims[1]; // TODO: how to get seq_len
    newShape.dims[2] = param.headNum;
    newShape.dims[3] = oldShape.dims[2] / param.headNum;
};
selfAttentionKvCacheNode.inTensorReshapeFuncs.at(3) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-     newShape.dimNum = 4; // dimNum: 4
-     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
-     newShape.dims[1] = oldShape.dims[2];
-     newShape.dims[2] = oldShape.dims[3];
-     newShape.dims[3] = oldShape.dims[4];
+     // the cache is produced as [1, max_batch_size, head_num, max_len, head_dim]
+     // the acceleration library expects [layer, max_batch_size, max_len, head_size]; strictly this needs a
+     // transpose, but since the cache is only read and written by the acceleration library, a plain reshape
+     // is used to work around it
+     newShape.dimNum = 4; // dimNum: 4
+     newShape.dims[0] = 1;
+     newShape.dims[1] = oldShape.dims[1];
+     newShape.dims[2] = oldShape.dims[3];
+     newShape.dims[3] = oldShape.dims[2] * oldShape.dims[4];
};
selfAttentionKvCacheNode.inTensorReshapeFuncs.at(4) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-     newShape.dimNum = 4; // dimNum: 4
+     // the cache is produced as [1, max_batch_size, head_num, max_len, head_dim]
+     // the acceleration library expects [layer, max_batch_size, max_len, head_size]; strictly this needs a
+     // transpose, but since the cache is only read and written by the acceleration library, a plain reshape
+     // is used to work around it
+     newShape.dimNum = 4; // dimNum: 4
+     newShape.dims[0] = 1;
+     newShape.dims[1] = oldShape.dims[1];
+     newShape.dims[2] = oldShape.dims[3];
+     newShape.dims[3] = oldShape.dims[2] * oldShape.dims[4];
+ };
+ selfAttentionKvCacheNode.inTensorReshapeFuncs.at(5) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+     newShape.dimNum = 2; // dimNum: 2
+     newShape.dims[0] = oldShape.dims[2];
+     newShape.dims[1] = oldShape.dims[3];
+ };
+ selfAttentionKvCacheNode.inTensorReshapeFuncs.at(6) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+     newShape.dimNum = 1; // dimNum: 1
+     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+ };
+ selfAttentionKvCacheNode.inTensorReshapeFuncs.at(7) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+     newShape.dimNum = 1; // dimNum: 1
    newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
-     newShape.dims[1] = oldShape.dims[2];
-     newShape.dims[2] = oldShape.dims[3];
-     newShape.dims[3] = oldShape.dims[4];
};
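To make the cache reshape workaround above concrete, here is a worked trace with illustrative sizes (max_batch_size = 1, head_num = 4, max_len = 2048, head_dim = 128, so head_size = 512):

// cache as produced:         [1, max_batch_size, head_num, max_len, head_dim] = [1, 1, 4, 2048, 128]
// reshape at(3)/at(4) yields {1, oldShape.dims[1], oldShape.dims[3], oldShape.dims[2] * oldShape.dims[4]}
//                          = [1, 1, 2048, 4 * 128] = [1, 1, 2048, 512]
// which matches the [layer, max_batch_size, max_len, head_size] layout the acceleration library expects.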

+ // [1, 1, 512] * [512, 4096] -> [1, 1, 4096]
atb::infer::LinearParallelParam selfOutLinearParallelParam;
selfOutLinearParallelParam.transWeight = true;
selfOutLinearParallelParam.rank = param.rank;
@@ -167,18 +212,18 @@ atb::Status LlamaLayerFusionParallelOperation(const LlamaLayerFusionParallelPara
selfOutLinearParallelNode.inTensorIds = {INTERMIDATE_SELFOUT, IN_SELFOUTLINEARWEIGHT};
selfOutLinearParallelNode.outTensorIds = {INTERMIDATE_SELFLINEAROUT};

+ // [bs * seq_len, hidden_size] + [1, 1, 4096]
atb::infer::ElewiseParam selfResidualAddParam;
selfResidualAddParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(selfResidualAddParam, &selfResidualAddNode.operation);
selfResidualAddNode.inTensorIds = {IN_HIDDENSTATES, INTERMIDATE_SELFLINEAROUT};
selfResidualAddNode.outTensorIds = {INTERMIDATE_SELFRESIDUALADDOUT};
selfResidualAddNode.inTensorReshapeFuncs.resize(selfResidualAddNode.inTensorIds.size());
- selfResidualAddNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
-     newShape.dimNum = 3; // dimNum: 3
-     newShape.dims[0] = oldShape.dims[1];
-     newShape.dims[1] = oldShape.dims[0];
-     newShape.dims[2] = oldShape.dims[2];
- };
+ // selfResidualAddNode.inTensorReshapeFuncs.at(1) = [=](const atb::Dims &oldShape, atb::Dims &newShape) {
+ //     newShape.dimNum = 2; // dimNum: 2
+ //     newShape.dims[0] = oldShape.dims[0] * oldShape.dims[1];
+ //     newShape.dims[1] = oldShape.dims[2];
+ // };

atb::infer::RmsNormParam selfNormParam;
selfNormParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
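For reference, a minimal usage sketch of the factory function above; only parameter fields that appear in this excerpt are set, and every value is illustrative rather than taken from the source:

LlamaLayerFusionParallelParam layerParam;
layerParam.rmsNormEps = 1e-6;   // illustrative
layerParam.headNum = 4;         // heads per card, illustrative
layerParam.headDim = 128;       // illustrative
layerParam.rotaryCoeff = 2;     // illustrative
layerParam.transpose = true;    // illustrative
layerParam.rank = 0;            // illustrative
atb::Operation *layerOp = nullptr;
atb::Status st = LlamaLayerFusionParallelOperation(layerParam, &layerOp);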