Skip to content
Merged
29 changes: 15 additions & 14 deletions paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@

namespace {

inline int getSMVersion() {
int getSMVersion() {
int sm_version = 80;
#if defined(PADDLE_WITH_CUDA)
sm_version = paddle::platform::GetGPUComputeCapability(
paddle::platform::GetCurrentDeviceId());
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"platform::GetGPUComputeCapability is not "
"supported in CPU only version."));
#endif
return sm_version;
}
Expand Down Expand Up @@ -74,15 +78,15 @@ class FusedWeightOnlyLinearPattern
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "weight_only_int8";
});
// int arch = getSMVersion();
const auto &weight_quantize_arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> std::any {
return 80;

const auto &arch_attr =
res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int {
return getSMVersion();
});

const auto &weight_quantize = res.Op(
"pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", weight_quantize_arch_attr}});
const auto &weight_quantize =
res.Op("pd_op.weight_quantize",
{{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});
Expand All @@ -92,12 +96,9 @@ class FusedWeightOnlyLinearPattern
return "int8";
});

const auto &weight_only_linear_arch_attr = res.Attr(
[&](const pir::drr::MatchContext &match_ctx) -> int { return 80; });
const auto &weight_only_linear =
res.Op("pd_op.weight_only_linear",
{{"weight_dtype", weight_dtype_attr},
{"arch", weight_only_linear_arch_attr}});
{{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
Expand All @@ -119,8 +120,8 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {

bool CanApplyOn(pir::Operation *op) const override {
int sm_vesion = getSMVersion();
if (sm_vesion != 70 && sm_vesion != 80 && sm_vesion != 86 &&
sm_vesion != 75) {
// TODO(Wanglongzhi2001): only support sm80 for now
if (sm_vesion != 80) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前develop版本已经支持了70 75 80 86这四个架构

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return false;
}
return op->num_regions() > 0;
Expand Down