Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions paddle/fluid/lite/api/light_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ class LightPredictor {

// Create the kernels of the target places, and filter out the specific
// kernel with the target alias.
for (auto& op : program.ops_) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
for (auto& op : program.ops()) {
auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
std::string op_type, alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
Expand All @@ -89,8 +88,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope_);
program_->set_exec_scope(program.exec_scope_);
CHECK(program.exec_scope());
program_->set_exec_scope(program.exec_scope());
}

private:
Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/lite/kernels/x86/sgd_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
using param_t = operators::ActivationParam;

void Run() override {
auto &context = context_->As<X86Context>();
auto &context = ctx_->As<X86Context>();
auto &sgd_param = *param_.get_mutable<operators::SGDParam>();
CHECK(context.x86_device_context);
CHECK(context.x86_device_context());

// param.Out->template mutable_data<T>();

Expand All @@ -45,12 +45,12 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
PADDLE_ENFORCE_EQ(grad->numel(), sz);

paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>();
const T *param_data = param->data<T>();
const T *grad_data = grad->data<T>();
const T *lr = learning_rate->template data<T>();
const T *param_data = param->template data<T>();
const T *grad_data = grad->template data<T>();
int64_t rows_idx = 0;
T *out_data =
param_out->mutable_data<T>(context.x86_device_context->GetPlace());
T *out_data = param_out->template mutable_data<T>(
context.x86_device_context()->GetPlace());

auto sgd =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
Expand Down