
Speedup roi_perspective_transform op by caching the information of linear interpolation in forward #17090

Merged
wanghaoshuang merged 2 commits into PaddlePaddle:develop from wanghaoshuang:fix_roi_per_trans
Apr 25, 2019

Conversation

@wanghaoshuang
Contributor

test=develop
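For context, a rough host-side sketch of the caching idea in the title. This is simplified and hypothetical, not the op's actual kernels: the real implementation works on 2D ROIs and runs as CUDA kernels; the names follow the PR's Out2InIdx / Out2InWeights intermediate outputs, both shaped [out.numel(), 4].

#include <vector>

// Forward: besides computing the output, record for every output element the
// four source offsets (Out2InIdx) and bilinear weights (Out2InWeights).
void ForwardWithCache(const std::vector<float>& in, std::vector<float>* out,
                      std::vector<int>* out2in_idx,
                      std::vector<float>* out2in_w) {
  out2in_idx->assign(out->size() * 4, -1);  // -1 = "no source pixel cached"
  out2in_w->resize(out->size() * 4);        // filled only where idx != -1
  // ... for each output element i, write the k-th (k = 0..3) source offset to
  //     (*out2in_idx)[i * 4 + k] and its weight to (*out2in_w)[i * 4 + k] ...
}

// Backward: scatter gradients through the cache instead of recomputing the
// perspective transform and interpolation coefficients.
void BackwardFromCache(const std::vector<float>& out_grad,
                       const std::vector<int>& out2in_idx,
                       const std::vector<float>& out2in_w,
                       std::vector<float>* in_grad) {
  for (size_t index = 0; index < out2in_idx.size(); ++index) {
    int in_idx = out2in_idx[index];
    if (in_idx >= 0) {
      (*in_grad)[in_idx] += out_grad[index / 4] * out2in_w[index];
    }
  }
}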

T* in_grad_data) {
  CUDA_1D_KERNEL_LOOP(index, out_size * 4) {
    int in_idx = out2in_idx_data[index];
    if (in_idx > 0) {
Contributor

should be >= here?

Contributor Author

Thx. Fixed.
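For reference, the corrected guard would read roughly like this (suggestion-style sketch: offset 0 is a valid input position, so only the -1 sentinel should be skipped):

int in_idx = out2in_idx_data[index];
if (in_idx >= 0) {  // was `> 0`; 0 is a valid offset, -1 marks "not cached"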

out2in_w->mutable_data<T>({out->numel(), 4}, ctx.GetPlace());

math::SetConstant<platform::CUDADeviceContext, int> init;
init(ctx.cuda_device_context(), out2in_idx, static_cast<int>(-1));
Contributor

Does out2in_w need to be initialized?

Contributor Author

Based on the condition in this line: if out2in_idx[i] == -1, then out2in_w[i] is never used.

So it is enough to initialize the elements of out2in_idx to -1.
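As a compact sketch of that invariant (out2in_w_data is an illustrative name for the raw weight pointer, not necessarily the one used in the op):

// Only the index buffer gets the -1 fill; the weight buffer can stay
// uninitialized because a weight is read only after its index passes the check.
init(ctx.cuda_device_context(), out2in_idx, static_cast<int>(-1));

// later, in the backward kernel:
int in_idx = out2in_idx_data[index];
if (in_idx >= 0) {
  T w = out2in_w_data[index];  // never an uninitialized read
  // ... accumulate w * out_grad into in_grad_data[in_idx] ...
}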

self.outputs['Out2InIdx'] = np.zeros(
    [np.product(self.outputs['Out'].shape), 4]).astype("int32")
self.outputs['Out2InWeights'] = np.zeros(
    [np.product(self.outputs['Out'].shape), 4]).astype("float32")
Contributor

In test_roi_pooling, test_checkout() also checks the Intermediate output; could this be updated here as well?

Contributor Author

roi_pooling's infer shape has an ENFORCE check on the Intermediate output, so its unit test's test_checkout() also needs the Intermediate output.

This PR only changes the CUDA kernel; the CPU kernel's computation does not use the Intermediate output, so the unit test does not check it.

test=develop
Contributor

@heavengate left a comment

LGTM

@wanghaoshuang merged commit 55ce36e into PaddlePaddle:develop Apr 25, 2019
sneaxiy pushed a commit to sneaxiy/Paddle that referenced this pull request Apr 28, 2019
# The first commit's message is:
remove ut test_dist_word2vec in mac ci, will fix it in private, test=develop (PaddlePaddle#17066)

# This is the 2nd commit message:

Fleet unify distributed training (PaddlePaddle#16791)

* implement distributed transpiler with fleet
# This is the 3rd commit message:

ParallelDyGraph with GPU collective mode (PaddlePaddle#16827)

implement dygraph.parallel.DataParallel to hook reduce op.

# This is the 4th commit message:

Init mixed precision training interface (PaddlePaddle#16856)

* Init mixed precision training interface

* Add fp16 test script

test=develop

* All initializers support float16

test=develop

* Code cleanup & add more code annotations

test=develop

* Update API spec

test=develop

* Add usage example in doc

test=develop

# This is the 5th commit message:

fix reference_count_pass,test=develop (PaddlePaddle#17060)

test=develop
# This is the 6th commit message:

Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (PaddlePaddle#17090)

* Cache the information of linear interpolation in forward and use it in backward.
test=develop

* Fix cuda kernel.
test=develop

# This is the 7th commit message:

remove unnecessary prepare_data (PaddlePaddle#17080)

test=develop
# This is the 8th commit message:

fix interpolate cu. test=develop (PaddlePaddle#17101)

# This is the 9th commit message:

test=develop, double backward leaky_relu (PaddlePaddle#17067)

backward of backward: leaky_relu
# This is the 10th commit message:

fix fuse optimizer ops (PaddlePaddle#17102)

test=develop
# This is the 11th commit message:

truncated_gaussian_random supported in distributed training, test=develop (PaddlePaddle#17091)

# This is the 12th commit message:

 Detailed coordinate description for yolov3 loss (PaddlePaddle#17007)

* Detailed coordinate description for yolov3 loss

test=develop

* modified api.spec

test=develop

* modified loss name

* fix api.spec

test=develop

* polish description

test=develop

* modified api.spec

test=develop

# This is the 13th commit message:

fix test_weight_decay (PaddlePaddle#17109)

test=develop
# This is the 14th commit message:

Path flag (PaddlePaddle#17105)

* fix python/paddle/fluid/__init__.py detecting problems
sneaxiy added a commit that referenced this pull request Apr 28, 2019
* refine_dropout_mem,test=develop

@wanghaoshuang deleted the fix_roi_per_trans branch May 20, 2022 03:55