129 changes: 129 additions & 0 deletions paddle/fluid/operators/lookup_table_v2_op_mlu.cc
@@ -0,0 +1,129 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
constexpr int64_t kNoPadding = -1;

template <typename T>
class LookupTableV2MLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids"); // int tensor
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
auto *table_t = ctx.Input<framework::LoDTensor>("W");

auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(
table_var->IsType<framework::LoDTensor>(), true,
        platform::errors::InvalidArgument("MLU only accepts LoDTensor"));
output_t->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc ids_desc(*ids_t);
MLUCnnlTensorDesc table_desc(*table_t);
MLUCnnlTensorDesc output_desc(*output_t);

int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
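    // Without a padding index, the lookup reduces to gathering table rows by
    // ids along axis 0.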
if (padding_idx == kNoPadding) {
MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
table_desc.get(), GetBasePtr(table_t),
ids_desc.get(), GetBasePtr(ids_t),
output_desc.get(), GetBasePtr(output_t));
} else {
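      // A padding index is given: write a zero row at padding_idx into a
      // temporary copy of the table (Fill + ScatterNd), then gather from that
      // copy so ids equal to padding_idx come back as zero vectors.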
Tensor tmp_table_t(table_t->type());
tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());

Tensor index;
index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
auto idx_value = static_cast<int32_t>(padding_idx);
MLUCnnlTensorDesc index_desc(index);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &idx_value, index_desc.get(),
GetBasePtr(&index));

auto update_dim = phi::make_ddim({1, table_t->dims()[1]});
Tensor update;
update.mutable_data<T>(update_dim, ctx.GetPlace());

auto update_value = static_cast<T>(0);
MLUCnnlTensorDesc update_desc(update);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &update_value,
update_desc.get(), GetBasePtr(&update));

MLUCnnlTensorDesc tmp_table_desc(tmp_table_t);
MLUCnnl::ScatterNd(
ctx, CNNL_SCATTERND_UPDATE, index_desc.get(), GetBasePtr(&index),
update_desc.get(), GetBasePtr(&update), table_desc.get(),
GetBasePtr(table_t), tmp_table_desc.get(), GetBasePtr(&tmp_table_t));

MLUCnnl::GatherFunctor(ctx, /*axis=*/0, /*batch_dims=*/0,
tmp_table_desc.get(), GetBasePtr(&tmp_table_t),
ids_desc.get(), GetBasePtr(ids_t),
output_desc.get(), GetBasePtr(output_t));
}
}
};

template <typename T>
class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
auto *output_grad_t =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *table_grad_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
table_grad_t->mutable_data<T>(ctx.GetPlace());

int padding_idx = static_cast<int>(ctx.Attr<int64_t>("padding_idx"));

    // The embedding-backward call below works on int32 indices, so cast Ids
    // when they arrive in another integer type.
    Tensor ids_int32(ids_t->dtype());
if (ids_t->dtype() != DataType::INT32) {
ids_int32.mutable_data<int>(ids_t->dims(), ctx.GetPlace());
MLUCnnlTensorDesc ids_desc(*ids_t);
MLUCnnlTensorDesc ids_int32_desc(ids_int32);
auto cast_type = GetCastDataType(ids_t->dtype(), DataType::INT32);
MLUCnnl::Cast(ctx, cast_type, ids_desc.get(), GetBasePtr(ids_t),
ids_int32_desc.get(), GetBasePtr(&ids_int32));
} else {
ids_int32 = *ids_t;
}

MLUCnnlTensorDesc ids_int32_desc(ids_int32);
MLUCnnlTensorDesc output_grad_desc(*output_grad_t);
MLUCnnlTensorDesc table_grad_desc(*table_grad_t);

    // Accumulate the output gradients into the corresponding rows of the
    // table gradient; padding_idx identifies the row whose gradient is skipped.
    MLUCnnl::EmbeddingBackward(ctx, padding_idx, false, ids_int32_desc.get(),
GetBasePtr(&ids_int32), output_grad_desc.get(),
GetBasePtr(output_grad_t), table_grad_desc.get(),
GetBasePtr(table_grad_t));
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(lookup_table_v2, ops::LookupTableV2MLUKernel<float>,
ops::LookupTableV2MLUKernel<int>,
ops::LookupTableV2MLUKernel<plat::float16>);

REGISTER_OP_MLU_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradMLUKernel<float>,
ops::LookupTableV2GradMLUKernel<int>,
ops::LookupTableV2GradMLUKernel<plat::float16>);
46 changes: 36 additions & 10 deletions paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -34,6 +34,12 @@ cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
return cast_type;
}

// Overload for phi::DataType: translate both types to the legacy proto var
// types and reuse the VT::Type overload above.
cnnlCastDataType_t GetCastDataType(const DataType& src_type,
                                   const DataType& dst_type) {
return GetCastDataType(framework::TransToProtoVarType(src_type),
framework::TransToProtoVarType(dst_type));
}

bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type) {
for (auto it = MLU_SUPPORTED_CAST_TYPE.begin();
it != MLU_SUPPORTED_CAST_TYPE.end(); ++it) {
@@ -2713,17 +2719,16 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_desc, output));
}

/* static */ void MLUCnnl::ScatterNd(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t indices_desc,
const void* indices,
const cnnlTensorDescriptor_t updates_desc,
const void* updates,
const cnnlTensorDescriptor_t output_desc,
void* output) {
/* static */ void MLUCnnl::ScatterNd(
const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
const cnnlTensorDescriptor_t indices_desc, const void* indices,
const cnnlTensorDescriptor_t updates_desc, const void* updates,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t output_desc, void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlScatterNd(handle, indices_desc, indices,
updates_desc, updates, output_desc,
output));
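  // cnnlScatterNd_v2 additionally takes an update mode and an explicit source
  // tensor (`input`) that the updates are applied against.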
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlScatterNd_v2(handle, mode, indices_desc, indices, updates_desc,
updates, input_desc, input, output_desc, output));
}

/* static */ void MLUCnnl::BitWise(
@@ -2777,5 +2782,26 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
cnnlReciprocal(handle, input_desc, input, output_desc, output));
}

/* static */ void MLUCnnl::EmbeddingBackward(
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
const cnnlTensorDescriptor_t indices_desc, const void* indices,
const cnnlTensorDescriptor_t diff_desc, const void* diff,
const cnnlTensorDescriptor_t output_desc, void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);

  // Query the scratch-space requirement of cnnlEmbeddingBackward and allocate
  // it as a temporary tensor.
  size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetEmbeddingBackwardWorkspaceSize(
handle, diff_desc, output_desc, scale_grad_by_freq, &workspace_size));

auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());

PADDLE_ENFORCE_MLU_SUCCESS(cnnlEmbeddingBackward(
handle, padding_idx, scale_grad_by_freq, indices_desc, indices, diff_desc,
diff, workspace_ptr, workspace_size, output_desc, output));
}

} // namespace operators
} // namespace paddle
14 changes: 13 additions & 1 deletion paddle/fluid/operators/mlu/mlu_baseop.h
@@ -175,6 +175,10 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>

cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
const VT::Type& dst_type);

cnnlCastDataType_t GetCastDataType(const DataType& src_type,
const DataType& dst_type);

bool MLUSupportsCast(const VT::Type& src_type, const VT::Type& dst_type);

cnnlDeviceType_t GetCnnlDev(int dev_ordinal);
@@ -1202,11 +1206,13 @@ class MLUCnnl {
const void* k, const int k_int,
const cnnlTensorDescriptor_t output_desc, void* output);

static void ScatterNd(const ExecutionContext& ctx,
static void ScatterNd(const ExecutionContext& ctx, cnnlScatterNdMode_t mode,
const cnnlTensorDescriptor_t indices_desc,
const void* indices,
const cnnlTensorDescriptor_t updates_desc,
const void* updates,
const cnnlTensorDescriptor_t input_desc,
const void* input,
const cnnlTensorDescriptor_t output_desc, void* output);

static void BitWise(const ExecutionContext& ctx,
@@ -1227,6 +1233,12 @@ class MLUCnnl {
const void* input,
const cnnlTensorDescriptor_t output_desc,
void* output);

static void EmbeddingBackward(
const ExecutionContext& ctx, int padding_idx, bool scale_grad_by_freq,
const cnnlTensorDescriptor_t indices_desc, const void* indices,
const cnnlTensorDescriptor_t diff_desc, const void* diff,
const cnnlTensorDescriptor_t output_desc, void* output);
};

template <typename T>
95 changes: 95 additions & 0 deletions paddle/fluid/operators/unstack_op_mlu.cc
@@ -0,0 +1,95 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

template <typename T>
class UnStackMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto out = ctx.MultiOutput<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += x->dims().size();
int num = x->dims()[axis];
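    // UnStack maps to a Split along `axis` into `num` outputs, each with
    // extent 1 on that axis.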

std::vector<MLUCnnlTensorDesc> out_descs;
std::vector<cnnlTensorDescriptor_t> out_raw_descs;
std::vector<void *> out_ptrs;
std::vector<int64_t> new_dims = phi::vectorize(x->dims());
new_dims[axis] = 1;
for (int i = 0; i < num; i++) {
out[i]->mutable_data<T>(ctx.GetPlace());
out_descs.emplace_back(MLUCnnlTensorDesc(new_dims.size(), new_dims.data(),
ToCnnlDataType<T>()));
out_raw_descs.push_back(out_descs.back().get());
out_ptrs.push_back(GetBasePtr(out[i]));
}

MLUCnnlTensorDesc x_desc(*x);
MLUCnnl::Split(ctx, num, axis, x_desc.get(), GetBasePtr(x),
out_raw_descs.data(), out_ptrs.data());
}
};

template <typename T>
class UnStackGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
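    // The gradient of UnStack is a Concat of the incoming slice gradients
    // along `axis`.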

std::vector<MLUCnnlTensorDesc> x_descs;
std::vector<cnnlTensorDescriptor_t> x_raw_descs;
std::vector<const void *> x_ptrs;
for (int i = 0; i < num; i++) {
if (x[i]->dims().size() != 0) {
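        // Re-insert the unstacked axis as a size-1 dimension so every slice
        // descriptor has the same rank as the concatenated result.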
std::vector<int64_t> in_dims = phi::vectorize(x[i]->dims());
in_dims.insert(in_dims.begin() + axis, 1);
x_descs.emplace_back(MLUCnnlTensorDesc(in_dims.size(), in_dims.data(),
ToCnnlDataType<T>()));
} else {
int input_dims = 1;
x_descs.emplace_back(
MLUCnnlTensorDesc(1, &input_dims, ToCnnlDataType<T>()));
}
x_raw_descs.push_back(x_descs.back().get());
x_ptrs.push_back(GetBasePtr(x[i]));
}
y->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc y_desc(*y);
MLUCnnl::Concat(ctx, num, axis, x_raw_descs.data(), x_ptrs.data(),
y_desc.get(), GetBasePtr(y));
}
};

} // namespace operators
} // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

REGISTER_OP_MLU_KERNEL(unstack, ops::UnStackMLUKernel<float>,
ops::UnStackMLUKernel<plat::float16>);

REGISTER_OP_MLU_KERNEL(unstack_grad, ops::UnStackGradMLUKernel<float>,
ops::UnStackGradMLUKernel<plat::float16>);