File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed
Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -30,19 +30,16 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
3030 PADDLE_THROW (" SplitIds do not support GPU kernel" );
3131 }
3232
33- const auto * ids_t = ctx.Input <framework::LoDTensor>(" Ids" );
34- auto & ids_dims = ids_t -> dims ();
33+ auto & ids_dims = ctx.Input <framework::LoDTensor>(" Ids" )-> dims ( );
34+ const T* ids = ctx. Input <framework::LoDTensor>( " Ids " )-> data <T> ();
3535 auto outs = ctx.MultiOutput <framework::LoDTensor>(" Out" );
36-
37- const T* ids = ids_t ->data <T>();
38-
3936 const size_t shard_num = outs.size ();
4037
4138 std::vector<std::vector<T>> out_ids;
4239 out_ids.resize (outs.size ());
4340
4441 // split id by their shard_num.
45- for (size_t i = 0 ; i < ids_dims[0 ]; ++i) {
42+ for (int i = 0 ; i < ids_dims[0 ]; ++i) {
4643 T id = ids[i];
4744 size_t shard_id = static_cast <size_t >(id) % shard_num;
4845 out_ids[shard_id].push_back (id);
You can’t perform that action at this time.
0 commit comments