Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lite/api/cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class LITE_API Predictor {
std::map<TargetType, std::shared_ptr<void>> target_configs_;
std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_;
Scope* exec_scope_;
Scope* exec_scope_{nullptr};
std::shared_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::vector<std::string> input_names_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1906,7 +1906,7 @@ class XPUMultiEncoderFuser {
if (is_qkv_already_fusion_) {
end = i + 1;
}
scope->NewTensor(update_tag);
scope->MutableParent()->NewTensor(update_tag);
// Update weight, including tranpose\convert type\fuse qkv
// weight\findmax.
update_weight(scope,
Expand Down
20 changes: 16 additions & 4 deletions lite/core/optimizer/mir/type_target_cast_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,15 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
// So there will be a new Argument node and a new IoCopy Statement Node.

CHECK(in->IsArg());
auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_output_name;
if (in_persist) {
io_copy_output_name =
string_format("%s/target_trans_persistable", in->AsArg().name.c_str());
} else {
io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
}

if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
Expand All @@ -292,7 +299,13 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Create the new var manually.
auto new_var = inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
Variable* new_var = nullptr;
if (in_persist) {
new_var = inst_node->AsStmt().op()->scope()->MutableParent()->Var(
io_copy_output_name);
} else {
new_var = inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
}
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(),
Expand All @@ -316,7 +329,6 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
}
auto* io_copy_inst = graph->NewInstructNode();

bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
io_copy_output_arg->AsArg().is_persist = in_persist;
// create Op and kernels.
Expand Down
7 changes: 6 additions & 1 deletion lite/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,12 @@ void Program::PrepareWorkspace(
// Create tensors or weights from variable description.
if (!var_desc->Persistable()) {
vars_.push_back(var_name);
auto* var = exec_scope_->Var(var_name);
Variable* var = nullptr;
if (var_name.find("/target_trans_persistable") != std::string::npos) {
var = scope_->Var(var_name);
} else {
var = exec_scope_->Var(var_name);
}
if (var_type == lite::VarDescAPI::Type::LOD_TENSOR) {
const auto& var_data_type =
VarDescType2PrecisionType(var_desc->GetDataType());
Expand Down