Skip to content

Commit 4f0e8d4

Browse files
committed
fix bug
1 parent 37b799d commit 4f0e8d4

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

paddle/fluid/pir/dialect/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,13 @@ cc_library(
314314

315315
#Note(risemeup1):compile some *.cc files which depend on primitive_vjp_experimental into op_dialect_vjp.a/lib
316316
set(op_decomp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp.cc)
317-
set(op_decomp_vjp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp_vjp.cc)
317+
# set(op_decomp_vjp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp_vjp.cc)
318318
set(op_dialect_vjp_srcs
319319
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc
320320
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp_vjp.cc
321321
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
322322
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
323323
${op_decomp_source_file}
324-
${op_decomp_vjp_source_file}
325324
${op_vjp_source_file}
326325
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc)
327326

paddle/fluid/pir/dialect/op_generator/op_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
#include <vector>
8989
9090
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
91+
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
9192
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
9293
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
9394
#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"

paddle/fluid/pir/dialect/operator/ir/manual_op_decomp_vjp.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace paddle {
3131
namespace dialect {
3232
using IntArray = paddle::experimental::IntArray;
3333

34-
std::vector<std::vector<pir::Value>> AddGradOp::Decomp(pir::Operation* op) {
34+
std::vector<std::vector<pir::Value>> AddGradOp::DecompVjp(pir::Operation* op) {
3535
VLOG(4) << "Decomp call add_grad's decomp interface begin";
3636

3737
AddGradOp op_obj = op->dyn_cast<AddGradOp>();

0 commit comments

Comments
 (0)