Skip to content

Commit d6edcc0

Browse files
author
Evan Lezar
committed
Store the length of a key run instead of the end index.
1 parent 282c601 commit d6edcc0

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

include/caffe/layers/key_pooling_layer.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,7 @@ class KeyPoolingLayer : public Layer<Dtype> {
5252
vector<Dtype> has_keys_;
5353
// Store the start and end indices for each key in the input array.
5454
vector<int> key_start_;
55-
vector<int> key_end_;
56-
57-
int largest_key_set_;
58-
55+
vector<int> key_len_;
5956

6057
};
6158

src/caffe/layers/key_pooling_layer.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ void KeyPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
2525
int num_keys = bottom[1]->shape(0);
2626
has_keys_.clear();
2727
key_start_.clear();
28-
key_end_.clear();
28+
key_len_.clear();
2929

3030
vector<Blob<Dtype>*> pooling_top;
3131
pooling_top.push_back(top[0]);
@@ -42,27 +42,21 @@ void KeyPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
4242
has_keys_.push_back(current_key);
4343
key_start_.push_back(0);
4444

45-
int j = 0;
4645
for (int i = 1; i < num_keys; ++i) {
4746
if (keys[i] != current_key) {
48-
j++;
49-
key_end_.push_back(i);
50-
51-
largest_key_set_ =
52-
std::max(key_end_[j - 1] - key_start_[j - 1], largest_key_set_);
47+
key_len_.push_back(i - key_start_[key_start_.size()-1]);
5348

5449
current_key = keys[i];
5550
has_keys_.push_back(current_key);
5651
key_start_.push_back(i);
5752
}
5853
}
59-
key_end_.push_back(num_keys);
60-
61-
largest_key_set_ = std::max(
62-
key_end_[num_keys - 1] - key_start_[num_keys - 1], largest_key_set_);
54+
key_len_.push_back(num_keys - key_start_[key_start_.size()-1]);
6355
}
6456

6557
CHECK_LE(has_keys_.size(), num_keys);
58+
CHECK_EQ(has_keys_.size(), key_start_.size());
59+
CHECK_EQ(has_keys_.size(), key_len_.size());
6660

6761
// Resize the tops to match the keys.
6862
vector<int> required_shape(top[0]->shape());
@@ -91,7 +85,7 @@ void KeyPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
9185
pooling_tops.push_back(&key_top);
9286

9387
vector<int> bottom_shape = bottom[0]->shape();
94-
bottom_shape[0] = key_end_[i] - key_start_[i];
88+
bottom_shape[0] = key_len_[i];
9589
key_bottom.Reshape(bottom_shape);
9690

9791
// Set the bottom as a view into the alocated blob.

0 commit comments

Comments
 (0)