Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions oneflow/core/job_rewriter/multi_tensor_model_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,6 @@ Maybe<void> MultiTensorModelUpdatePass::Apply(const OpGraph& op_graph,
}
const user_op::UserOpConfWrapper model_update_user_conf(
find_model_update_update_node->op().op_conf());
// Multi tensor update pass only support for CUDA currently.
if (find_model_update_update_node->parallel_desc().device_type() != DeviceType::kCUDA) {
continue;
}

// Multi tensor update pass only support Data Parallel.
bool if_data_parallel = true;
Expand Down
4 changes: 0 additions & 4 deletions python/oneflow/nn/optimizer/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,6 @@ def __init__(
warnings.warn("Fused Adamw is not supported when amsgrad=True.")
param_group["fused"] = False

if param_group["fused"] and not param.is_cuda:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个去掉,是不是cpu device会有影响,是否需要用param.is_cpu判断一下?

warnings.warn("Fused Adamw only support cuda parameters.")
param_group["fused"] = False

self._op_with_amsgrad = (
flow.stateful_op("adam_update")
.Input("model")
Expand Down