
Commit 33253f2

Merge branch 'PaddlePaddle:develop' into pir7
2 parents: 8fca8f9 + ee3d2fc

File tree

934 files changed: +36313 -10272 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
 paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
 paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
 paddle/fluid/pir/dialect/operator/ir/pd_op.*
+paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
+paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
 paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
 option(WITH_XPU_KP "Compile PaddlePaddle with BAIDU XPU compiler " OFF)
 option(WITH_XPU_XFT "Compile PaddlePaddle with BAIDU XPU-XFT" OFF)
 option(WITH_XPU_PLUGIN "Compile PaddlePaddle with BAIDU XPU plugin" OFF)
-option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library" OFF)
+option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library"
+       ${WITH_XPU})
 option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
 option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
 option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)

cmake/external/xpu.cmake

Lines changed: 3 additions & 3 deletions
@@ -26,12 +26,12 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so")
 set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
 
 if(NOT DEFINED XPU_BASE_DATE)
-  set(XPU_BASE_DATE "20231203")
+  set(XPU_BASE_DATE "20231218")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "20231215")
+  set(XPU_XHPC_BASE_DATE "20231229")
 endif()
-set(XPU_XCCL_BASE_VERSION "1.1.7.1")
+set(XPU_XCCL_BASE_VERSION "1.1.8.1")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
   set(XPU_XFT_BASE_VERSION "20230602")
 endif()

cmake/inference_lib.cmake

Lines changed: 15 additions & 7 deletions
@@ -237,6 +237,16 @@ copy_part_of_thrid_party(inference_lib_dist ${PADDLE_INFERENCE_INSTALL_DIR})
 
 set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
 
+if(WIN32)
+  set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/common.*)
+else()
+  set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
+endif()
+copy(
+  inference_lib_dist
+  SRCS ${paddle_common_lib}
+  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
+
 if(WIN32)
   if(WITH_STATIC_LIB)
     set(paddle_inference_lib
@@ -268,11 +278,6 @@ else()
       SRCS ${paddle_phi_lib}
      DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
   endif()
-  set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
-  copy(
-    inference_lib_dist
-    SRCS ${paddle_common_lib}
-    DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
 endif()
 
 copy(
@@ -323,10 +328,13 @@ copy(
   inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
   DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
+
 copy(
   inference_lib_dist
-  SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h
-  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
+  SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h
+  DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/
+)
+
 copy(
   inference_lib_dist
   SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h

paddle/cinn/backends/codegen_cuda_host.cc

Lines changed: 10 additions & 0 deletions
@@ -198,13 +198,23 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
       [](auto& arg) { return std::addressof(arg); });
   // @}
 
+  // Set local scope table
+  CHECK_EQ(ll_function_args.size(), func->args.size());
+  for (int i = 0; i < ll_function_args.size(); ++i) {
+    SetVar(func->args[i].name(), ll_function_args[i]);
+  }
   llvm::BasicBlock* entry = llvm::BasicBlock::Create(
       /*Context=*/b_->getContext(),
       /*Name=*/"entry",
       /*Parent=*/f_,
       /*InsertBefore=*/nullptr);
   b_->SetInsertPoint(entry);
   CodeGenLLVM::Visit(&func->body);
+
+  // Reset local scope table
+  for (const ir::Argument& func_arg : func->args) {
+    symbol_table_->Erase(func_arg.name());
+  }
   RetVoid();
 
   return f_;
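
Note on the codegen_cuda_host.cc change: the added block binds each lowered LLVM argument to its IR argument name in the codegen symbol table before the function body is visited, then erases those bindings once the body has been emitted so they cannot leak into the next lowered function. Below is a minimal, self-contained C++ sketch of that bind/visit/erase pattern; SymbolTable and ScopedArgBindings are hypothetical names used only for illustration, not CINN APIs.

// Sketch only: a generic RAII helper that mirrors the register-then-erase
// pattern the diff applies around CodeGenLLVM::Visit(&func->body).
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using SymbolTable = std::unordered_map<std::string, void*>;

class ScopedArgBindings {
 public:
  ScopedArgBindings(SymbolTable* table,
                    const std::vector<std::string>& names,
                    const std::vector<void*>& values)
      : table_(table), names_(names) {
    assert(names.size() == values.size());  // mirrors the CHECK_EQ in the diff
    for (size_t i = 0; i < names.size(); ++i) {
      (*table_)[names[i]] = values[i];  // bind each argument by name
    }
  }
  ~ScopedArgBindings() {
    for (const std::string& name : names_) {
      table_->erase(name);  // drop the bindings when the scope ends
    }
  }

 private:
  SymbolTable* table_;
  std::vector<std::string> names_;
};

An RAII guard like this makes the erase step automatic even on early returns; the diff instead performs the two steps explicitly around the call to CodeGenLLVM::Visit.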

paddle/cinn/backends/codegen_cuda_host.h

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
     } else if (op->name == runtime::intrinsic::call_cuda_kernel) {
       return LowerCUDAKernelCall(op);
     } else {
-      CINN_NOT_IMPLEMENTED;
+      return CodeGenLLVM::Visit(op);
     }
   }

paddle/cinn/backends/codegen_cuda_util.h

Lines changed: 23 additions & 1 deletion
@@ -31,6 +31,7 @@ namespace backends {
 #define KERNEL_ARGS "kernel_args"
 #define KERNEL_ARGS_NUM "kernel_args_num"
 #define KERNEL_STREAM "kernel_stream"
+#define TENSOR_SHAPE_ARGS "tensor_shape_args"
 
 /**
  * Split a CINN Module into two separate modules, one cantains the host
@@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor
       : CollectHostFunctionVisitor(module_name),
         kernel_args_(KERNEL_ARGS, type_of<void*>()),
         kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
-        kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}
+        kernel_stream_(KERNEL_STREAM, type_of<void*>()),
+        tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}
 
   std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
     ir::IRMutator<>::Visit(expr, expr);
@@ -181,6 +183,25 @@
         {});
     host_module_builder.AddFunctionWithoutOptim(
         host_func.as_lowered_func_ref());
+
+    // Parse LoweredFunc to infer output tensor's shape
+    std::vector<ir::Expr> infer_shape_func_body_stmts(arg_defs_);
+    infer_shape_func_body_stmts.insert(
+        infer_shape_func_body_stmts.end(),
+        op->infer_shape_func.as_lowered_func()->body);
+
+    std::vector<ir::Argument> infer_shape_arguments = {
+        ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
+        ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
+        ir::Argument(tensor_shape_args_, ir::Argument::IO::kOutput)};
+
+    ir::Expr host_infer_shape_func =
+        ir::_LoweredFunc_::Make(op->infer_shape_func.as_lowered_func()->name,
+                                infer_shape_arguments,
+                                ir::Block::Make(infer_shape_func_body_stmts),
+                                {});
+    host_module_builder.AddFunctionWithoutOptim(
+        host_infer_shape_func.as_lowered_func_ref());
   }
 
   void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
@@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor
   ir::Var kernel_args_;
   ir::Var kernel_args_num_;
   ir::Var kernel_stream_;
+  ir::Var tensor_shape_args_;
 };
 
 }  // namespace detail
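
Note on the codegen_cuda_util.h change: for each bucket-strategy kernel the visitor now emits a second host function alongside the launch wrapper. It reuses the op's infer_shape_func body and takes the same kernel_args and kernel_args_num plus the new tensor_shape_args (an int32_t** table) through which inferred output shapes are reported. The sketch below only illustrates what such a companion function's signature and effect look like on the host side; the function name and the hard-coded shapes are invented for the example and are not generated CINN code.

// Toy illustration of an "infer shape" companion function. A real generated
// body would compute the dims from the tensors packed in kernel_args; here
// they are hard-coded to keep the example self-contained.
#include <cstdint>

extern "C" void toy_kernel_infer_shape(void* kernel_args,
                                       int kernel_args_num,
                                       int32_t** tensor_shape_args) {
  (void)kernel_args;      // unused in this toy example
  (void)kernel_args_num;  // unused in this toy example
  // Pretend the op broadcasts a [N, 1] input against a [1, M] input,
  // so the single output has shape [N, M].
  const int32_t n = 4;
  const int32_t m = 8;
  tensor_shape_args[0][0] = n;  // output tensor 0, dim 0
  tensor_shape_args[0][1] = m;  // output tensor 0, dim 1
}

In this sketch the caller provides the storage behind tensor_shape_args; the function only writes the dimension values into it.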

paddle/cinn/backends/llvm/codegen_llvm.cc

Lines changed: 2 additions & 1 deletion
@@ -818,7 +818,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
   // TODO(fc500110) hard coding
   if (LLVM_WillVarLowerAsPointer(op->name)) {
     result = value;
-  } else if (value->getType()->isPointerTy()) {
+  } else if (value->getType()->isPointerTy() &&
+             !value->getType()->getPointerElementType()->isPointerTy()) {
     result = Load(value, op->name + "_load");
   } else {
     result = value;
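
Note on the codegen_llvm.cc change: once a variable can be lowered as a pointer-to-pointer value (such as an int32_t** shape-argument table), the old isPointerTy() test would emit a load through it and hand the body a single-level pointer instead of the table itself, so the guard now loads only single-level pointers. A small sketch of the same guard in isolation follows; it assumes an LLVM version (LLVM 14 or earlier) where typed pointers and Type::getPointerElementType() are still available, and ShouldEmitLoad is a hypothetical helper name, not a CINN function.

// Sketch of the pointer-depth guard from the diff, in isolation.
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

static bool ShouldEmitLoad(llvm::Value* value) {
  llvm::Type* ty = value->getType();
  // Load single-level pointers (e.g. float*), but pass pointer-to-pointer
  // values (e.g. int32_t**) through untouched.
  return ty->isPointerTy() && !ty->getPointerElementType()->isPointerTy();
}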

paddle/cinn/common/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
@@ -22,7 +22,10 @@ gather_srcs(
   python_interpreter_guard.cc
   nvgpu_dev_info.cc
   integer_set.cc
-  dim_expr_simplify.cc)
+  dim_expr_simplify.cc
+  dim_expr_converter.cc
+  broadcast_tree.cc
+  dim_expr_util.cc)
 
 cinn_cc_test(test_equation_graph_topo_walker SRCS
              equation_graph_topo_walker_test.cc DEPS gtest glog)
@@ -47,6 +50,10 @@ if(WITH_CUDA)
                gtest glog)
 endif()
 if(NOT CINN_ONLY)
+  cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore)
   cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
                cinncore)
+  cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
+               cinncore)
+  cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore)
 endif()
