@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15- #include < memory>
16-
1715#include " paddle/fluid/framework/ir/op_compat_sensible_pass.h"
16+ #include < memory>
17+ #include < mutex>
18+ #include < unordered_map>
19+ #include " paddle/fluid/framework/op_def_api.h"
1820#include " paddle/fluid/framework/op_info.h"
21+
1922namespace paddle {
2023namespace framework {
2124namespace ir {
@@ -50,18 +53,17 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
5053 return *this ;
5154}
5255
53- // ! Todo: append the definition.
5456AttrCompat& AttrCompat::IsLeftDefault () {
5557 const std::string& op_name = op_compat_->Name ();
5658 if (!OpInfoMap::Instance ().Has (op_name)) {
57- VLOG ( 3 ) << " Op (" << op_name << " ) is not registered!" ;
59+ LOG (WARNING ) << " Op (" << op_name << " ) is not registered!" ;
5860 conditions_.emplace_back ([](const Attribute& attr) { return false ; });
5961 return *this ;
6062 }
6163 const OpInfo& op_info = OpInfoMap::Instance ().Get (op_name);
6264 const AttributeMap attrs = op_info.Checker ()->GetAttrsDefaultValuesMap ();
6365 if (attrs.find (attr_name_) == attrs.end ()) {
64- VLOG ( 3 ) << " Op (" << op_name << " ) has no default attr:" << attr_name_;
66+ LOG (WARNING ) << " Op (" << op_name << " ) has no default attr:" << attr_name_;
6567 conditions_.emplace_back ([](const Attribute& attr) { return false ; });
6668 } else {
6769 Attribute default_attr = attrs.at (attr_name_);
@@ -77,6 +79,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
7779 return true ;
7880 }
7981 if (!op_desc.HasAttr (attr_name_)) {
82+ if (!optional_) {
83+ LOG (WARNING) << " The non-optional Attr(" << attr_name_ << " ) of Op ("
84+ << op_compat_->Name () << " ) not find ! " ;
85+ }
8086 return optional_;
8187 }
8288 const Attribute attr = op_desc.GetAttr (attr_name_);
@@ -149,19 +155,35 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
149155}
150156
151157bool OpCompat::Judge (const OpDesc& op_desc) {
158+ if (is_first_judge_) {
159+ is_first_judge_ = false ;
160+ const proto::OpDef& op_def = GetOpDef (op_name_);
161+ if (op_def.has_extra ()) {
162+ for (const proto::OpDef_AttrDef& attr : op_def.extra ().attrs ()) {
163+ extra_attrs_.emplace (attr.name ());
164+ }
165+ }
166+ }
167+
152168 for (auto & attr_map : op_desc.GetAttrMap ()) {
153169 if (attr_compats_.find (attr_map.first ) == attr_compats_.end ()) {
170+ if (extra_attrs_.find (attr_map.first ) != extra_attrs_.end ()) {
171+ continue ;
172+ }
154173 if (!AttrCompat (attr_map.first , this ).IsLeftDefault ()(op_desc)) {
155- VLOG (3 ) << " The Attr(" << attr_map.first << " ) of Op (" << op_name_
156- << " ) not reigistered in OpCompat, not equal to default value!" ;
174+ LOG (WARNING)
175+ << " The Attr(" << attr_map.first << " ) of Op (" << op_name_
176+ << " ) not reigistered in OpCompat, not in extra attribute, not "
177+ " equal to default value!" ;
157178 return false ;
158179 }
159180 }
160181 }
182+
161183 for (auto & attr_compat : attr_compats_) {
162184 if (!attr_compat.second (op_desc)) {
163- VLOG ( 3 ) << " Check the Attr(" << attr_compat.first << " ) of Op("
164- << op_name_ << " ) failed!" ;
185+ LOG (WARNING ) << " Check the Attr(" << attr_compat.first << " ) of Op("
186+ << op_name_ << " ) failed!" ;
165187 return false ;
166188 }
167189 }
@@ -170,23 +192,24 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
170192 for (auto & input_desc : inputs_map) {
171193 if (input_compats_.find (input_desc.first ) == input_compats_.end ()) {
172194 if (!input_desc.second .empty ()) {
173- VLOG ( 3 ) << " The Input (" << input_desc.first << " ) of Operator ("
174- << op_name_ << " ) not reigistered in OpCompat!" ;
195+ LOG (WARNING ) << " The Input (" << input_desc.first << " ) of Operator ("
196+ << op_name_ << " ) not reigistered in OpCompat!" ;
175197 return false ;
176198 }
177199 }
178200 }
179201 for (auto & input_val : input_compats_) {
180202 if (inputs_map.find (input_val.first ) == inputs_map.end ()) {
181203 if (!input_val.second .Optional ()) {
182- VLOG (3 ) << " The No optional Input (" << input_val.first
183- << " ) of Operator (" << op_name_ << " ) not find in op_desc!" ;
204+ LOG (WARNING) << " The No optional Input (" << input_val.first
205+ << " ) of Operator (" << op_name_
206+ << " ) not find in op_desc!" ;
184207 return false ;
185208 }
186209 } else {
187210 if (!input_val.second (inputs_map.at (input_val.first ))) {
188- VLOG ( 3 ) << " The Input (" << input_val.first << " ) of Operator ("
189- << op_name_ << " ) compat check failed!" ;
211+ LOG (WARNING ) << " The Input (" << input_val.first << " ) of Operator ("
212+ << op_name_ << " ) compat check failed!" ;
190213 return false ;
191214 }
192215 }
@@ -196,23 +219,24 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
196219 for (auto & output_desc : outputs_map) {
197220 if (output_compats_.find (output_desc.first ) == output_compats_.end ()) {
198221 if (!output_desc.second .empty ()) {
199- VLOG ( 3 ) << " The Output (" << output_desc.first << " ) of Operator ("
200- << op_name_ << " ) not reigistered in OpCompat!" ;
222+ LOG (WARNING ) << " The Output (" << output_desc.first << " ) of Operator ("
223+ << op_name_ << " ) not reigistered in OpCompat!" ;
201224 return false ;
202225 }
203226 }
204227 }
205228 for (auto & output_val : output_compats_) {
206229 if (outputs_map.find (output_val.first ) == outputs_map.end ()) {
207230 if (!output_val.second .Optional ()) {
208- VLOG (3 ) << " The No optional Output (" << output_val.first
209- << " ) of Operator (" << op_name_ << " ) not find in op_desc!" ;
231+ LOG (WARNING) << " The No optional Output (" << output_val.first
232+ << " ) of Operator (" << op_name_
233+ << " ) not find in op_desc!" ;
210234 return false ;
211235 }
212236 } else {
213237 if (!output_val.second (outputs_map.at (output_val.first ))) {
214- VLOG ( 3 ) << " The Output (" << output_val.first << " ) of Operator ("
215- << op_name_ << " ) compat check failed!" ;
238+ LOG (WARNING ) << " The Output (" << output_val.first << " ) of Operator ("
239+ << op_name_ << " ) compat check failed!" ;
216240 return false ;
217241 }
218242 }
0 commit comments