Skip to content

Commit 5caa6fc

Browse files
MingMingShangTianchenwhqlYuanRishengShixiaowei02
authored
[PTen] Add variable transform to/from ptenTensor and add cast kernel (#36916)
* add cast kernel * add cast cuda kernel * add cast kernel * make cast kernel output dtype undefined * get cast dtype from vardesc * move cast to manipulation and add test case * add castinfershape * avoid reinitilaze variable * InitializeVariable support datatype * merge develop branch * fix merge bug * revert modify initializeVariable * revert modify on InitializeVariable * revert modify on InitializeVariable * mutable support reset dtype * enable make pten tensor from variable when def_arg.type is undefined * fix build pten ctx start_idx error * copy pten out tensor to variable * merge develop branch * fix non pten kernel cast failed * add reset allocation place for remake tensor * fix inplace realloc error * add mutable on pten kernles and remove unused cast files * rename function names * fix output type error * fix conflict with develop branch * set data type to variable with pten's dtype * fix test_cast_api type mismatch * densorTensro mutable_data support 0 bytes value * fix the inplace bug of reshape kernel * fix pten.backend != variable.place when moving storage, palce mismatch bug * fix conflict with develop branch * Fix bug of paddle::experimental::MovesStorage * fix ReMakePtenDenseTensor place mismatch bug * Revert "fix ReMakePtenDenseTensor place mismatch bug" This reverts commit 8633603. * fix ReMakePtenDenseTensor place mismatch bug * reverts the set_lod interface, test=develop * modify by the review options * modify error message * add & for const input arguments * add reference in params * elementwise_sub add mutable_data * fix ResetHolderWithType check size bug * add dependence pten_tensor to test_cast_api object * remove unused code to pass ci coverage Co-authored-by: Chen Weihang <[email protected]> Co-authored-by: YuanRisheng <[email protected]> Co-authored-by: shixiaowei02 <[email protected]>
1 parent 075c22f commit 5caa6fc

40 files changed

+837
-149
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 101 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
11831183
}
11841184
BuildPtenKernelContext(*runtime_ctx, dev_ctx);
11851185
(*pt_kernel_)(pt_kernel_context_.get());
1186+
1187+
WriteBackToOutputs(runtime_ctx);
1188+
11861189
pt_kernel_context_->ClearData();
11871190
} else {
11881191
(*kernel_func_)(
@@ -1808,50 +1811,98 @@ void OperatorWithKernel::BuildPtenKernelContext(
18081811
for (size_t i = 0; i < input_names.size(); ++i) {
18091812
auto& in_def = input_defs.at(i);
18101813
auto& ins_vector = ctx.inputs.at(input_names[i]);
1811-
if (pt_kernel_context_->InputsSize() <= i) {
1814+
1815+
// calculate the start and end index of the input tensors
1816+
size_t start_idx =
1817+
(i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second);
1818+
size_t end_idx = start_idx + ins_vector.size();
1819+
1820+
// The current size of inputs/outputs in pt_kernel_context_ is at least equal
1821+
// to start_idx. Because tensors already allocated for inputs or outputs in
1822+
// pt_kernel_context_ are reused, the current size can be greater than the
1823+
// index at which the tensor is to be set; in that case
1824+
// ReMakePtenDenseTensorFromVar is used to remake the pten tensor.
1825+
if (pt_kernel_context_->InputsSize() == start_idx) {
18121826
paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
18131827
for (auto* var : ins_vector) {
18141828
tmp_inputs.emplace_back(
18151829
experimental::MakePtenTensorBaseFromVar(*var, in_def));
18161830
}
18171831
pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs));
1818-
} else {
1832+
} else if (pt_kernel_context_->InputsSize() > start_idx) {
18191833
size_t input_size = pt_kernel_context_->InputsSize();
18201834
for (size_t j = 0; j < ins_vector.size(); ++j) {
1821-
if (input_size > i + j) {
1835+
if (input_size > start_idx + j) {
18221836
experimental::ReMakePtenDenseTensorFromVar(
18231837
*ins_vector[j], in_def,
1824-
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(i + j));
1838+
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(start_idx +
1839+
j));
1840+
// TODO(chentianyu03): Enable this code when multi-input kernels are supported
1841+
/*
1842+
} else {
1843+
pt_kernel_context_->EmplaceBackInputWithoutSetRange(
1844+
experimental::MakePtenTensorBaseFromVar(*ins_vector[j],
1845+
in_def));
1846+
*/
18251847
}
1826-
// TODO(chenweihang): adapt multi-input case later
18271848
}
18281849
pt_kernel_context_->MutableInputRangeAt(i) =
1829-
std::make_pair(i, i + ins_vector.size());
1850+
std::make_pair(start_idx, end_idx);
1851+
} else {
1852+
PADDLE_THROW(platform::errors::PreconditionNotMet(
1853+
"Error start index when trying to set new tensor to inputs, start "
1854+
"index is `%d`, but current pt_kernel_context_.inputs.size() is "
1855+
"`%d`.",
1856+
start_idx, pt_kernel_context_->InputsSize()));
18301857
}
18311858
}
18321859

18331860
for (size_t i = 0; i < output_names.size(); ++i) {
18341861
auto& out_def = output_defs.at(i);
18351862
auto& outs_vector = ctx.outputs.at(output_names[i]);
1836-
if (pt_kernel_context_->OutputsSize() <= i) {
1863+
1864+
size_t start_idx =
1865+
(i == 0 ? 0 : pt_kernel_context_->OutputRangeAt(i - 1).second);
1866+
size_t end_idx = start_idx + outs_vector.size();
1867+
1868+
// The current size of inputs/outputs in pt_kernel_context_ is at least equal
1869+
// to start_idx. Because tensors already allocated for inputs or outputs in
1870+
// pt_kernel_context_ are reused, the current size can be greater than the
1871+
// index at which the tensor is to be set; in that case
1872+
// ReMakePtenDenseTensorFromVar is used to remake the pten tensor.
1873+
if (pt_kernel_context_->OutputsSize() == start_idx) {
18371874
paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
18381875
for (auto* var : outs_vector) {
18391876
tmp_outputs.emplace_back(
18401877
experimental::MakePtenTensorBaseFromVar(var, out_def));
18411878
}
18421879
pt_kernel_context_->EmplaceBackOutputs(std::move(tmp_outputs));
1843-
} else {
1880+
} else if (pt_kernel_context_->OutputsSize() > start_idx) {
18441881
size_t output_size = pt_kernel_context_->OutputsSize();
18451882
for (size_t j = 0; j < outs_vector.size(); ++j) {
1846-
if (output_size > i + j) {
1883+
if (output_size > start_idx + j) {
18471884
experimental::ReMakePtenDenseTensorFromVar(
18481885
outs_vector[j], out_def,
1849-
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(i + j));
1886+
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
1887+
j));
1888+
1889+
// TODO(chentianyu03): Enable this code when multi-output kernels are supported
1890+
/*
1891+
} else {
1892+
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
1893+
experimental::MakePtenTensorBaseFromVar(outs_vector[j],
1894+
out_def));
1895+
*/
18501896
}
1851-
// TODO(chenweihang): adapt multi-output case later
18521897
}
18531898
pt_kernel_context_->MutableOutputRangeAt(i) =
1854-
std::make_pair(i, i + outs_vector.size());
1899+
std::make_pair(start_idx, end_idx);
1900+
} else {
1901+
PADDLE_THROW(platform::errors::PreconditionNotMet(
1902+
"Error start index when trying to set new tensor to inputs, start "
1903+
"index is `%d`, but current pt_kernel_context_.outputs.size() is "
1904+
"`%d`.",
1905+
start_idx, pt_kernel_context_->OutputsSize()));
18551906
}
18561907
}
18571908

@@ -1883,14 +1934,23 @@ void OperatorWithKernel::BuildPtenKernelContext(
18831934
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
18841935
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
18851936
} else if (attr_defs[i].type_index ==
1886-
std::type_index(typeid(std::vector<int64_t>)) &&
1887-
std::type_index(attr.type()) ==
1888-
std::type_index(typeid(std::vector<int>))) {
1889-
// Emplace Back Attr according to the type of Pten_Kernel args.
1890-
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
1891-
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
1892-
vector_int_attr.end());
1893-
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
1937+
std::type_index(typeid(pten::DataType))) {
1938+
auto data_type = pten::TransToPtenDataType(
1939+
static_cast<framework::proto::VarType::Type>(
1940+
BOOST_GET_CONST(int, attr)));
1941+
pt_kernel_context_->EmplaceBackAttr(data_type);
1942+
} else if (attr_defs[i].type_index ==
1943+
std::type_index(typeid(std::vector<int64_t>))) {
1944+
if (std::type_index(attr.type()) ==
1945+
std::type_index(typeid(std::vector<int>))) {
1946+
// Emplace Back Attr according to the type of Pten_Kernel args.
1947+
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
1948+
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
1949+
vector_int_attr.end());
1950+
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
1951+
}
1952+
// TODO(YuanRisheng): Need to support vector<int64_t> attr
1953+
18941954
} else {
18951955
PADDLE_THROW(platform::errors::Unimplemented(
18961956
"unsupported cast op attribute `%s` when construct "
@@ -1901,5 +1961,26 @@ void OperatorWithKernel::BuildPtenKernelContext(
19011961
}
19021962
}
19031963

1964+
void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
1965+
// auto& input_names = std::get<0>(pt_kernel_signature_->args);
1966+
// auto& attr_names = std::get<1>(pt_kernel_signature_->args);
1967+
auto& output_names = std::get<2>(pt_kernel_signature_->args);
1968+
1969+
// pt_kernel_context_
1970+
1971+
for (size_t i = 0; i < output_names.size(); ++i) {
1972+
auto& outs_vector = ctx->outputs.at(output_names[i]);
1973+
1974+
auto& range_pair = pt_kernel_context_->OutputRangeAt(i);
1975+
auto pten_outs =
1976+
pt_kernel_context_->MutableOutputBetween<pten::DenseTensor>(
1977+
range_pair.first, range_pair.second);
1978+
1979+
for (size_t j = 0; j < pten_outs.size(); ++j) {
1980+
experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
1981+
}
1982+
}
1983+
}
1984+
19041985
} // namespace framework
19051986
} // namespace paddle

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,8 @@ class OperatorWithKernel : public OperatorBase {
589589
void BuildPtenKernelContext(const RuntimeContext& ctx,
590590
platform::DeviceContext* dev_ctx) const;
591591

592+
void WriteBackToOutputs(RuntimeContext* ctx) const;
593+
592594
protected:
593595
mutable std::unique_ptr<OpKernelType> kernel_type_;
594596
mutable std::unique_ptr<OpKernelFunc> kernel_func_;

paddle/fluid/framework/tensor.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,12 @@ void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
204204
}
205205

206206
void Tensor::ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
207-
const proto::VarType::Type type) {
208-
ResetHolder(holder);
207+
const proto::VarType::Type& type) {
209208
type_ = type;
209+
ResetHolder(holder);
210210
}
211211

212+
void Tensor::set_type(const proto::VarType::Type& type) { type_ = type; }
213+
212214
} // namespace framework
213215
} // namespace paddle

paddle/fluid/framework/tensor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ class Tensor {
271271
void ResetHolder(std::shared_ptr<memory::Allocation> holder);
272272

273273
void ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
274-
const proto::VarType::Type type);
274+
const proto::VarType::Type& type);
275+
276+
void set_type(const proto::VarType::Type& type);
275277

276278
TensorInplaceVersion& InplaceVersionCounter() {
277279
return *inplace_version_counter_;

0 commit comments

Comments
 (0)