supports distributed classification#18690
Conversation
test=develop
test=develop test=document_preview
test=develop test=document_preview
guru4elephant
left a comment
There was a problem hiding this comment.
I think range is a concept with at least two values, e.g. [0, 10], how about to change index_range to max_index so that we can understand it easily.
AUTHORS.md
Outdated
| | lcy-seso | Ying Cao | | ||
| | cjld | Dun Liang | | ||
| | lipeng-unisound | Peng Li | | ||
| | liuyi05 | Yi Liu | |
| void Compute(const framework::ExecutionContext& context) const override { | ||
| auto* in = context.Input<LoDTensor>("X"); | ||
| auto* out = context.Output<LoDTensor>("Out"); | ||
| int index_range = context.Attr<int>("index_range"); |
There was a problem hiding this comment.
could you use max_index or some other specific name for this variable? index_range seems to be a value of width, better to use a detailed variable name.
There was a problem hiding this comment.
the attribute name "index_range" is ambiguous indeed, and we need another proper name. In most of cases, users have the variable preserving number of indices, so "max_index" attributed requires user manually subtract 1 from the variable and we also have to recover it for "shard_size" calculation later. So I change the attribute "index_range" to "index_num" as a detailed name, which denotes the number of indices precisely.
| PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, | ||
| "shard_id(%d) is not in range [0, %d)", shard_id, nshards); | ||
|
|
||
| int shard_range = index_range / nshards; |
There was a problem hiding this comment.
same as above. shard_width can be better? shard_range can be width of a shared, it also can be how many shards we have.
There was a problem hiding this comment.
have been renamed to "shard_size".
test=document_preview test=develop
guru4elephant
left a comment
There was a problem hiding this comment.
LGTM, please write some technique documents so that it can be easy to use.
| int shard_size = index_num / nshards; | ||
| int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (idx < numel) { | ||
| assert(in_data[idx] >= 0 && in_data[idx] < index_num); |
There was a problem hiding this comment.
Why do you use assert here? do you check whether it works?
There was a problem hiding this comment.
I want just make sure the input is in the valid range
|
TODO:replace raw |
minimum functional code support for distributed classification training
design doc: http://agroup.baidu.com/paddlepaddle/md/article/1924796