Skip to content

Commit 27f18a3

Browse files
GGBond8488 and "Your Name"
authored and committed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_varmap_save
2 parents 1e6a7d3 + 5bb661d commit 27f18a3

767 files changed

Lines changed: 30256 additions & 8253 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
108108
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
109109
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
110110
paddle/fluid/pir/dialect/operator/ir/pd_op.*
111+
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
112+
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
111113
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
112114
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
113115
paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*

cmake/external/xpu.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so")
2626
set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
2727

2828
if(NOT DEFINED XPU_BASE_DATE)
29-
set(XPU_BASE_DATE "20231203")
29+
set(XPU_BASE_DATE "20231218")
3030
endif()
3131
if(NOT DEFINED XPU_XHPC_BASE_DATE)
32-
set(XPU_XHPC_BASE_DATE "20231226")
32+
set(XPU_XHPC_BASE_DATE "20231229")
3333
endif()
3434
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
3535
if(NOT DEFINED XPU_XFT_BASE_VERSION)

cmake/inference_lib.cmake

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,19 @@ copy(
328328
inference_lib_dist
329329
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
330330
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
331+
331332
copy(
332333
inference_lib_dist
333-
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h
334-
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
334+
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h
335+
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/
336+
)
337+
338+
copy(
339+
inference_lib_dist
340+
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h
341+
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/
342+
)
343+
335344
copy(
336345
inference_lib_dist
337346
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h

paddle/cinn/backends/codegen_cuda_host.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,23 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) {
198198
[](auto& arg) { return std::addressof(arg); });
199199
// @}
200200

201+
// Set local scope table
202+
CHECK_EQ(ll_function_args.size(), func->args.size());
203+
for (int i = 0; i < ll_function_args.size(); ++i) {
204+
SetVar(func->args[i].name(), ll_function_args[i]);
205+
}
201206
llvm::BasicBlock* entry = llvm::BasicBlock::Create(
202207
/*Context=*/b_->getContext(),
203208
/*Name=*/"entry",
204209
/*Parent=*/f_,
205210
/*InsertBefore=*/nullptr);
206211
b_->SetInsertPoint(entry);
207212
CodeGenLLVM::Visit(&func->body);
213+
214+
// Reset local scope table
215+
for (const ir::Argument& func_arg : func->args) {
216+
symbol_table_->Erase(func_arg.name());
217+
}
208218
RetVoid();
209219

210220
return f_;

paddle/cinn/backends/codegen_cuda_host.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
5353
} else if (op->name == runtime::intrinsic::call_cuda_kernel) {
5454
return LowerCUDAKernelCall(op);
5555
} else {
56-
CINN_NOT_IMPLEMENTED;
56+
return CodeGenLLVM::Visit(op);
5757
}
5858
}
5959

paddle/cinn/backends/codegen_cuda_util.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace backends {
3131
#define KERNEL_ARGS "kernel_args"
3232
#define KERNEL_ARGS_NUM "kernel_args_num"
3333
#define KERNEL_STREAM "kernel_stream"
34+
#define TENSOR_SHAPE_ARGS "tensor_shape_args"
3435

3536
/**
3637
* Split a CINN Module into two separate modules, one cantains the host
@@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor
150151
: CollectHostFunctionVisitor(module_name),
151152
kernel_args_(KERNEL_ARGS, type_of<void*>()),
152153
kernel_args_num_(KERNEL_ARGS_NUM, type_of<int>()),
153-
kernel_stream_(KERNEL_STREAM, type_of<void*>()) {}
154+
kernel_stream_(KERNEL_STREAM, type_of<void*>()),
155+
tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of<int32_t**>()) {}
154156

155157
std::tuple<ir::Module, ir::Module> operator()(Expr* expr) {
156158
ir::IRMutator<>::Visit(expr, expr);
@@ -181,6 +183,25 @@ struct CollectBucketStrategyHostFunctionVisitor
181183
{});
182184
host_module_builder.AddFunctionWithoutOptim(
183185
host_func.as_lowered_func_ref());
186+
187+
// Parse LoweredFunc to infer output tensor's shape
188+
std::vector<ir::Expr> infer_shape_func_body_stmts(arg_defs_);
189+
infer_shape_func_body_stmts.insert(
190+
infer_shape_func_body_stmts.end(),
191+
op->infer_shape_func.as_lowered_func()->body);
192+
193+
std::vector<ir::Argument> infer_shape_arguments = {
194+
ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
195+
ir::Argument(kernel_args_num_, ir::Argument::IO::kInput),
196+
ir::Argument(tensor_shape_args_, ir::Argument::IO::kOutput)};
197+
198+
ir::Expr host_infer_shape_func =
199+
ir::_LoweredFunc_::Make(op->infer_shape_func.as_lowered_func()->name,
200+
infer_shape_arguments,
201+
ir::Block::Make(infer_shape_func_body_stmts),
202+
{});
203+
host_module_builder.AddFunctionWithoutOptim(
204+
host_infer_shape_func.as_lowered_func_ref());
184205
}
185206

186207
void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate);
@@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor
199220
ir::Var kernel_args_;
200221
ir::Var kernel_args_num_;
201222
ir::Var kernel_stream_;
223+
ir::Var tensor_shape_args_;
202224
};
203225

204226
} // namespace detail

paddle/cinn/backends/llvm/codegen_llvm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
818818
// TODO(fc500110) hard coding
819819
if (LLVM_WillVarLowerAsPointer(op->name)) {
820820
result = value;
821-
} else if (value->getType()->isPointerTy()) {
821+
} else if (value->getType()->isPointerTy() &&
822+
!value->getType()->getPointerElementType()->isPointerTy()) {
822823
result = Load(value, op->name + "_load");
823824
} else {
824825
result = value;

paddle/cinn/common/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ gather_srcs(
2323
nvgpu_dev_info.cc
2424
integer_set.cc
2525
dim_expr_simplify.cc
26-
dim_expr_converter.cc)
26+
dim_expr_converter.cc
27+
broadcast_tree.cc
28+
dim_expr_util.cc)
2729

2830
cinn_cc_test(test_equation_graph_topo_walker SRCS
2931
equation_graph_topo_walker_test.cc DEPS gtest glog)
@@ -48,8 +50,10 @@ if(WITH_CUDA)
4850
gtest glog)
4951
endif()
5052
if(NOT CINN_ONLY)
53+
cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore)
5154
cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
5255
cinncore)
5356
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
5457
cinncore)
58+
cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore)
5559
endif()

0 commit comments

Comments (0)