@@ -1085,6 +1085,42 @@ struct CastOpTranscriber : public OpTranscriber {
10851085 }
10861086};
10871087
1088+ struct LeakyReLUOpTranscriber : public OpTranscriber {
1089+ pir::AttributeMap TranslateOpAttribute (
1090+ pir::IrContext* ctx,
1091+ const std::string& normalized_op_name,
1092+ const OpAttributeInfoList& op_attr_infos,
1093+ const OpDesc& op_desc) override {
1094+ auto & attribute_translator = AttributeTranslator::instance ();
1095+ auto & op_normalizer = OpNameNormalizer::instance ();
1096+ pir::AttributeMap attribute_map = {};
1097+
1098+ for (const auto & info : op_attr_infos) {
1099+ auto legacy_attr_name =
1100+ op_normalizer.GetLegacyAttrName (op_desc.Type (), info.name );
1101+ VLOG (10 ) << " [op: " << op_desc.Type ()
1102+ << " ][attr] from: " << legacy_attr_name << " to: " << info.name ;
1103+ if (op_desc.HasAttr (legacy_attr_name)) {
1104+ paddle::framework::Attribute legacy_attr =
1105+ op_desc.GetAttr (legacy_attr_name);
1106+ VLOG (10 ) << " attribute in " << op_desc.Type ()
1107+ << " name: " << legacy_attr_name << " " << legacy_attr.index ();
1108+ pir::Attribute new_attr =
1109+ attribute_translator (info.type_name , legacy_attr);
1110+ if (legacy_attr_name == " alpha" ) {
1111+ new_attr = pir::DoubleAttribute::get (
1112+ ctx,
1113+ static_cast <double >(
1114+ new_attr.dyn_cast <pir::FloatAttribute>().data ()));
1115+ }
1116+ attribute_map[info.name ] = new_attr;
1117+ }
1118+ }
1119+
1120+ return attribute_map;
1121+ }
1122+ };
1123+
10881124struct Conv2dOpTranscriber : public OpTranscriber {
10891125 void HandleNonexistentAttribute (pir::IrContext* ctx,
10901126 pir::AttributeMap* attribute_map,
@@ -3921,6 +3957,80 @@ struct SyncCommStreamOpTranscriber : public OpTranscriber {
39213957 }
39223958};
39233959
3960+ struct SoftPlusOpTranscriber : public OpTranscriber {
3961+ pir::AttributeMap TranslateOpAttribute (
3962+ pir::IrContext* ctx,
3963+ const std::string& normalized_op_name,
3964+ const OpAttributeInfoList& op_attr_infos,
3965+ const OpDesc& op_desc) override {
3966+ auto & attribute_translator = AttributeTranslator::instance ();
3967+ auto & op_normalizer = OpNameNormalizer::instance ();
3968+ pir::AttributeMap attribute_map = {};
3969+
3970+ for (const auto & info : op_attr_infos) {
3971+ auto legacy_attr_name =
3972+ op_normalizer.GetLegacyAttrName (op_desc.Type (), info.name );
3973+ VLOG (10 ) << " [op: " << op_desc.Type ()
3974+ << " ][attr] from: " << legacy_attr_name << " to: " << info.name ;
3975+ if (op_desc.HasAttr (legacy_attr_name)) {
3976+ paddle::framework::Attribute legacy_attr =
3977+ op_desc.GetAttr (legacy_attr_name);
3978+ VLOG (10 ) << " attribute in " << op_desc.Type ()
3979+ << " name: " << legacy_attr_name << " " << legacy_attr.index ();
3980+ pir::Attribute new_attr =
3981+ attribute_translator (info.type_name , legacy_attr);
3982+ if (legacy_attr_name == " beta" || legacy_attr_name == " threshold" ) {
3983+ new_attr = pir::DoubleAttribute::get (
3984+ ctx,
3985+ static_cast <double >(
3986+ new_attr.dyn_cast <pir::FloatAttribute>().data ()));
3987+ }
3988+ attribute_map[info.name ] = new_attr;
3989+ } else {
3990+ this ->HandleNonexistentAttribute (ctx, &attribute_map, info);
3991+ }
3992+ }
3993+ return attribute_map;
3994+ }
3995+ };
3996+
3997+ struct LogitOpTranscriber : public OpTranscriber {
3998+ pir::AttributeMap TranslateOpAttribute (
3999+ pir::IrContext* ctx,
4000+ const std::string& normalized_op_name,
4001+ const OpAttributeInfoList& op_attr_infos,
4002+ const OpDesc& op_desc) override {
4003+ auto & attribute_translator = AttributeTranslator::instance ();
4004+ auto & op_normalizer = OpNameNormalizer::instance ();
4005+ pir::AttributeMap attribute_map = {};
4006+
4007+ for (const auto & info : op_attr_infos) {
4008+ auto legacy_attr_name =
4009+ op_normalizer.GetLegacyAttrName (op_desc.Type (), info.name );
4010+ VLOG (10 ) << " [op: " << op_desc.Type ()
4011+ << " ][attr] from: " << legacy_attr_name << " to: " << info.name ;
4012+ if (op_desc.HasAttr (legacy_attr_name)) {
4013+ paddle::framework::Attribute legacy_attr =
4014+ op_desc.GetAttr (legacy_attr_name);
4015+ VLOG (10 ) << " attribute in " << op_desc.Type ()
4016+ << " name: " << legacy_attr_name << " " << legacy_attr.index ();
4017+ pir::Attribute new_attr =
4018+ attribute_translator (info.type_name , legacy_attr);
4019+ if (legacy_attr_name == " eps" ) {
4020+ new_attr = pir::DoubleAttribute::get (
4021+ ctx,
4022+ static_cast <double >(
4023+ new_attr.dyn_cast <pir::FloatAttribute>().data ()));
4024+ }
4025+ attribute_map[info.name ] = new_attr;
4026+ } else {
4027+ this ->HandleNonexistentAttribute (ctx, &attribute_map, info);
4028+ }
4029+ }
4030+ return attribute_map;
4031+ }
4032+ };
4033+
39244034OpTranslator::OpTranslator () {
39254035 pir::IrContext* ctx = pir::IrContext::Instance ();
39264036 ctx->GetOrRegisterDialect <paddle::dialect::OperatorDialect>();
@@ -3933,6 +4043,8 @@ OpTranslator::OpTranslator() {
39334043 special_handlers[" batch_norm" ] = BatchNormOpTranscriber ();
39344044 special_handlers[" range" ] = ArangeOpTranscriber ();
39354045 special_handlers[" cast" ] = CastOpTranscriber ();
4046+ special_handlers[" leaky_relu" ] = LeakyReLUOpTranscriber ();
4047+ special_handlers[" leaky_relu_grad" ] = LeakyReLUOpTranscriber ();
39364048 special_handlers[" conv2d" ] = Conv2dOpTranscriber ();
39374049 special_handlers[" conv3d" ] = Conv3dOpTranscriber ();
39384050 special_handlers[" cross_entropy_with_softmax" ] =
@@ -4033,5 +4145,9 @@ OpTranslator::OpTranslator() {
40334145 WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();
40344146
40354147 special_handlers[" c_sync_comm_stream" ] = SyncCommStreamOpTranscriber ();
4148+ special_handlers[" softplus" ] = SoftPlusOpTranscriber ();
4149+ special_handlers[" softplus_grad" ] = SoftPlusOpTranscriber ();
4150+ special_handlers[" logit" ] = LogitOpTranscriber ();
4151+ special_handlers[" logit_grad" ] = LogitOpTranscriber ();
40364152}
40374153} // namespace paddle::translator
0 commit comments