@@ -25,7 +25,7 @@ void KeyPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
2525 int num_keys = bottom[1 ]->shape (0 );
2626 has_keys_.clear ();
2727 key_start_.clear ();
28- key_end_ .clear ();
28+ key_len_ .clear ();
2929
3030 vector<Blob<Dtype>*> pooling_top;
3131 pooling_top.push_back (top[0 ]);
@@ -42,27 +42,21 @@ void KeyPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
4242 has_keys_.push_back (current_key);
4343 key_start_.push_back (0 );
4444
45- int j = 0 ;
4645 for (int i = 1 ; i < num_keys; ++i) {
4746 if (keys[i] != current_key) {
48- j++;
49- key_end_.push_back (i);
50-
51- largest_key_set_ =
52- std::max (key_end_[j - 1 ] - key_start_[j - 1 ], largest_key_set_);
47+ key_len_.push_back (i - key_start_[key_start_.size ()-1 ]);
5348
5449 current_key = keys[i];
5550 has_keys_.push_back (current_key);
5651 key_start_.push_back (i);
5752 }
5853 }
59- key_end_.push_back (num_keys);
60-
61- largest_key_set_ = std::max (
62- key_end_[num_keys - 1 ] - key_start_[num_keys - 1 ], largest_key_set_);
54+ key_len_.push_back (num_keys - key_start_[key_start_.size ()-1 ]);
6355 }
6456
6557 CHECK_LE (has_keys_.size (), num_keys);
58+ CHECK_EQ (has_keys_.size (), key_start_.size ());
59+ CHECK_EQ (has_keys_.size (), key_len_.size ());
6660
6761 // Resize the tops to match the keys.
6862 vector<int > required_shape (top[0 ]->shape ());
@@ -91,7 +85,7 @@ void KeyPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
9185 pooling_tops.push_back (&key_top);
9286
9387 vector<int > bottom_shape = bottom[0 ]->shape ();
94- bottom_shape[0 ] = key_end_[i] - key_start_ [i];
88+ bottom_shape[0 ] = key_len_ [i];
9589 key_bottom.Reshape (bottom_shape);
9690
9791 // Set the bottom as a view into the alocated blob.
0 commit comments