
Conversation

zkh2016 (Contributor) commented Aug 24, 2021

PR types

New features

PR changes

OPs

Describe

Fuse the elementwise_add, activation, and dropout operations into one operator.

//before fusion
out1 = elementwise_add(src, bias)
out2 = activation(out1)
out3 = dropout(out2)
//after fusion
out = fused_dropout_act_bias(src, bias, activation_functor)
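
For illustration, a minimal host-side sketch of the fused per-element semantics, assuming a ReLU activation, an already-generated 0/1 dropout mask, and upscale-in-train scaling; the function name and signature are illustrative, not the operator's actual API:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference (unfused-equivalent) computation: out = dropout(relu(src + bias)).
// `mask` holds 0/1 keep flags; kept elements are scaled by 1 / (1 - dropout_prob).
template <typename T>
void FusedDropoutActBiasRef(const std::vector<T>& src, const std::vector<T>& bias,
                            const std::vector<uint8_t>& mask, float dropout_prob,
                            std::vector<T>* out) {
  const size_t cols = bias.size();
  const size_t rows = src.size() / cols;
  const T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  out->resize(src.size());
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const size_t idx = r * cols + c;
      const T added = src[idx] + bias[c];                       // elementwise_add
      const T activated = std::max(added, static_cast<T>(0));   // activation (ReLU here)
      (*out)[idx] = activated * static_cast<T>(mask[idx]) * factor;  // dropout
    }
  }
}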

paddle-bot-old commented:
Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor:

Are the functions defined in this file shared by several places?

Contributor Author (zkh2016):

Yes, it is shared. It was also submitted in another PR, so one of the two has to be merged first.

Contributor:

Is this file shared as well?

Contributor:

Please add some comments describing what these functions do.

zkh2016 force-pushed the fused_dropout_act_bias branch from d643509 to f3a365e on August 25, 2021
zkh2016 force-pushed the fused_dropout_act_bias branch from f3a365e to f8f4e07 on August 30, 2021
zkh2016 force-pushed the fused_dropout_act_bias branch from a21bc90 to 4dba815 on September 8, 2021
return static_cast<T>(casted_dout * (first + second));
}
};

Contributor:

The activation functions above, Relu and Gelu, both already exist under math. Can they be reused directly? The interfaces implemented under math are already quite unified, so if they are reused there should be no need to wrap them again here?

Contributor Author (zkh2016):

Done. The gelu implementation follows gelu_op and differs slightly from the one under math. The functor from math can be passed in directly.
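
As a rough illustration of why no extra wrapping is needed when the activation functor is passed in as a template parameter (the names and signatures below are hypothetical, not Paddle's actual ones):

#include <cstdint>

// Hypothetical stand-in for an activation functor such as the ones under math.
template <typename T>
struct ReluFunctor {
  T operator()(const T x) const { return x > static_cast<T>(0) ? x : static_cast<T>(0); }
};

// The fused element-wise step only needs to call act(x), so any functor with
// that interface can be plugged in without re-wrapping it here.
template <typename T, typename Functor>
T FusedAddActDropoutElement(T src, T bias, uint8_t keep, T factor, Functor act) {
  return act(src + bias) * static_cast<T>(keep) * factor;
}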

Contributor:

Where is the difference?

Contributor Author (zkh2016):

ref: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/gelu_op.h#L96
This mainly follows the gelu_op implementation, which provides two ways of computing GeLU:
an approximate form, the same as the one under math: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
and an exact form: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
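
For reference, a small standalone sketch contrasting the two formulations in plain double precision (the actual functors also handle e.g. half precision, so this is only illustrative):

#include <cmath>

// Exact form: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
double GeluErf(double x) {
  return 0.5 * x * (1.0 + std::erf(x * 0.7071067811865476 /* 1/sqrt(2) */));
}

// Tanh approximation: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
double GeluTanh(double x) {
  const double kAlpha = 0.7978845608028654;  // sqrt(2 / pi)
  return 0.5 * x * (1.0 + std::tanh(kAlpha * (x + 0.044715 * x * x * x)));
}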

namespace paddle {
namespace operators {

typedef platform::float16 fp16;
Contributor:

This kind of abbreviation is not recommended in the code.

Contributor Author (zkh2016):

done

}
// store result to global
platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
Contributor:

It looks like a lot of this code is the same as in the previous PR. Consider wrapping it in functions so the code can be reused.

Contributor Author (zkh2016):

Changed; the fix is submitted in the next PR.

const platform::CUDADeviceContext &ctx) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
PADDLE_ENFORCE_CUDA_SUCCESS(
Contributor:

Since this Memset is called so many times, it could also be wrapped in a helper function.

Contributor Author (zkh2016):

done
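
One possible shape for such a wrapper, sketched with the raw CUDA runtime API; the name SetZero and the error handling are illustrative, and the actual code would likely use PADDLE_ENFORCE_CUDA_SUCCESS instead:

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>

// Zero a device buffer of `count` elements asynchronously on `stream`,
// instead of repeating the raw Memset call at every call site.
template <typename T>
void SetZero(T* dev_ptr, size_t count, cudaStream_t stream) {
  cudaError_t err = cudaMemsetAsync(dev_ptr, 0, count * sizeof(T), stream);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaMemsetAsync failed: %s\n", cudaGetErrorString(err));
  }
}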


StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
Contributor:

Could the Kernel Primitives API functions be used here?

int bias_id = blockIdx.x * blockDim.x * VecSize + x;
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
dbias[bias_id] = sum;
}
Contributor:

L279 - L307 are also the same as in the previous PR.

Contributor Author (zkh2016):

done


{
out.Resize({rows, cols});
out.mutable_data<T>(place);
Contributor:

You could call out.mutable_data<T>(dims, place); directly.

Contributor Author (zkh2016):

done

test.Run();
test.CheckOut(default_diff);
if (!is_fp16) {
// test fp16, For inference, check_grad is not required. ref:
Contributor:

FP16 training needs to be supported.

Contributor Author (zkh2016):

done

Xreki (Contributor) left a comment:

LGTM.


StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T args[2];
Contributor:

With the current approach there is no need to define T args[2];

Contributor Author (zkh2016):

Fixed in the next PR.

#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T args[2];
Contributor:

With the current approach there is no need to define T args[2];

Contributor Author (zkh2016):

Fixed in the next PR.

static void BaseTest(const bool is_fp16 = false) {
const int rows = 16;
std::vector<int> cols_list = {16, 17};
bool has_bias[2] = {true, false};
Contributor:

Lines L271 and L272 are redundant.

Contributor Author (zkh2016):

Fixed in the next PR.

paddle::operators::GeluGradFunctor<double>>();
}

// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
Contributor:

@skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestDropoutOp5(OpTest):
    def setUp(self):
        self.op_type = "dropout"
        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
        self.attrs = {'dropout_prob': 0.75, 'is_test': True}
        self.outputs = {
            'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
        }

    def test_check_output(self):
        self.check_output()

The dropout unit tests also contain fp32 cases that do not check the gradient; it just means that test case is there to verify inference correctness.

Contributor Author (zkh2016):

A gradient unit test has been added in the next PR.

Xreki merged commit cee7043 into PaddlePaddle:develop on Sep 16, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request on Sep 29, 2021
zkh2016 deleted the fused_dropout_act_bias branch on August 19, 2022