Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2746,7 +2746,7 @@ struct RandIntOpTranscriber : public OpTranscriber {
paddle::dialect::DenseTensorTypeStorage::Dim dim =
common::make_ddim(var->GetShape());
paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED;
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
Expand Down
89 changes: 43 additions & 46 deletions paddle/fluid/ir_adaptor/translator/type_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,48 @@ using DenseTensorType = paddle::dialect::DenseTensorType;
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
using SelectedRowsType = paddle::dialect::SelectedRowsType;
using SelectedRowsTypeStorage = paddle::dialect::SelectedRowsTypeStorage;
using DataLayout = DenseTensorTypeStorage::DataLayout;
using LoD = DenseTensorTypeStorage::LoD;

TypeTranslator::TypeTranslator() {
const auto& HandleTensor = [&](pir::IrContext* ctx,
const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from LOD_TENSOR";
const pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
const auto dim = common::make_ddim(var_desc.GetShape());
const auto layout = DataLayout::NCHW;
const LoD lod = {};
const size_t offset = 0;
return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
};
const auto& HandleTensorArray = [&](pir::IrContext* ctx,
const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";
const pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
const auto dims = common::make_ddim(var_desc.GetShape());
const auto layout = DataLayout::NCHW;
return paddle::dialect::DenseTensorArrayType::get(ctx, dtype, dims, layout);
};

const auto& HandleSelectedRows = [&](pir::IrContext* ctx,
const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from SELECTED_ROWS";
const pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
const auto dim = common::make_ddim(var_desc.GetShape());
const auto layout = DataLayout::NCHW;
const LoD lod = {};
const size_t offset = 0;
pir::Type SelectedRows =
SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
return SelectedRows;
};

handlers = {
{VarType::BOOL,
[&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
Expand Down Expand Up @@ -81,52 +121,9 @@ TypeTranslator::TypeTranslator() {
[&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
return pir::Complex128Type::get(ctx);
}},
{VarType::LOD_TENSOR,
[&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from LOD_TENSOR";

pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
DenseTensorTypeStorage::Dim dim =
common::make_ddim(var_desc.GetShape());
DenseTensorTypeStorage::DataLayout layout =
DenseTensorTypeStorage::DataLayout::UNDEFINED;
DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
}},
{VarType::LOD_TENSOR_ARRAY,
[&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";
pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
phi::DDim dims = common::make_ddim(var_desc.GetShape());
DenseTensorTypeStorage::DataLayout layout =
DenseTensorTypeStorage::DataLayout::UNDEFINED;

return paddle::dialect::DenseTensorArrayType::get(
ctx, dtype, dims, layout);
}},
{VarType::SELECTED_ROWS,
[&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from SELECTED_ROWS";

pir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);

SelectedRowsTypeStorage::Dim dim =
common::make_ddim(var_desc.GetShape());
SelectedRowsTypeStorage::DataLayout layout =
SelectedRowsTypeStorage::DataLayout::UNDEFINED;
SelectedRowsTypeStorage::LoD lod = {};
size_t offset = 0;
pir::Type SelectedRows =
SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
return SelectedRows;
}},
{VarType::LOD_TENSOR, HandleTensor},
{VarType::LOD_TENSOR_ARRAY, HandleTensorArray},
{VarType::SELECTED_ROWS, HandleSelectedRows},
};
}

Expand Down