Skip to content

supports distributed classification#18690

Merged
gavin1332 merged 4 commits intoPaddlePaddle:developfrom
gavin1332:distfc
Jul 23, 2019
Merged

supports distributed classification#18690
gavin1332 merged 4 commits intoPaddlePaddle:developfrom
gavin1332:distfc

Conversation

@gavin1332
Copy link
Collaborator

@gavin1332 gavin1332 commented Jul 19, 2019

minimum functional code support for distributed classification training
design doc: http://agroup.baidu.com/paddlepaddle/md/article/1924796

test=develop
test=document_preview
test=develop
test=document_preview
@gavin1332 gavin1332 requested a review from xsrobin July 19, 2019 03:35
Copy link
Member

@guru4elephant guru4elephant left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think range is a concept with at least two values, e.g. [0, 10], how about to change index_range to max_index so that we can understand it easily.

AUTHORS.md Outdated
| lcy-seso | Ying Cao |
| cjld | Dun Liang |
| lipeng-unisound | Peng Li |
| liuyi05 | Yi Liu |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you gavin?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

囧,忙晕了

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int index_range = context.Attr<int>("index_range");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you use max_index or some other specific name for this variable? index_range seems to be a value of width, better to use a detailed variable name.

Copy link
Collaborator Author

@gavin1332 gavin1332 Jul 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the attribute name "index_range" is ambiguous indeed, and we need another proper name. In most of cases, users have the variable preserving number of indices, so "max_index" attributed requires user manually subtract 1 from the variable and we also have to recover it for "shard_size" calculation later. So I change the attribute "index_range" to "index_num" as a detailed name, which denotes the number of indices precisely.

PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
"shard_id(%d) is not in range [0, %d)", shard_id, nshards);

int shard_range = index_range / nshards;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above. shard_width can be better? shard_range can be width of a shared, it also can be how many shards we have.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have been renamed to "shard_size".

@gavin1332 gavin1332 requested a review from chengduoZH July 22, 2019 01:52
Copy link
Member

@guru4elephant guru4elephant left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, please write some technique documents so that it can be easy to use.

Copy link
Collaborator

@sneaxiy sneaxiy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link

@sandyhouse sandyhouse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

int shard_size = index_num / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_num);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you use assert here? do you check whether it works?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want just make sure the input is in the valid range

@gavin1332
Copy link
Collaborator Author

TODO:replace raw assert by PADDLE_ASSERT_MSG

@gavin1332 gavin1332 requested a review from hutuxian July 22, 2019 09:35
@gavin1332 gavin1332 requested a review from kuke July 22, 2019 09:56
Copy link
Contributor

@guoshengCS guoshengCS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@gavin1332 gavin1332 merged commit 157211c into PaddlePaddle:develop Jul 23, 2019
@gavin1332 gavin1332 deleted the distfc branch July 31, 2019 08:51
@gavin1332 gavin1332 restored the distfc branch July 31, 2019 08:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants