Skip to content

Commit 4e85793

Browse files
authored
修复sigmoid_focal_loss的npu适配层 (#3323)
1 parent 1d8f928 commit 4e85793

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
1717
int64_t weight_size = weight.size(0);
1818
at::Tensor weight_y = at::ones_like(input);
1919
if (weight_size > 0) {
20-
weight_y = at::broadcast_to(weight, input.sizes());
20+
at::Tensor weight_selected = weight.gather(0, target);
21+
weight_selected = weight_selected.unsqueeze(1);
22+
weight_y = weight_selected.expand_as(input);
2123
}
2224
OpCommand cmd;
2325
string reduction = "none";

0 commit comments

Comments
 (0)