Skip to content

Commit 2965df5

Browse files
authored
Merge pull request #960 from pengli09/chunk_evaluator
Add excluded_chunk_types to ChunkEvaluator
2 parents 8a42a54 + 6e405a1 commit 2965df5

File tree

4 files changed

+50
-22
lines changed

4 files changed

+50
-22
lines changed

paddle/gserver/evaluators/ChunkEvaluator.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <set>
1516
#include <vector>
1617

1718
#include "paddle/math/Vector.h"
@@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
7273

7374
std::vector<Segment> labelSegments_;
7475
std::vector<Segment> outputSegments_;
76+
std::set<int> excludedChunkTypes_;
7577

7678
public:
7779
virtual void init(const EvaluatorConfig& config) {
@@ -105,6 +107,10 @@ class ChunkEvaluator : public Evaluator {
105107
}
106108
CHECK(config.has_num_chunk_types()) << "Missing num_chunk_types in config";
107109
otherChunkType_ = numChunkTypes_ = config.num_chunk_types();
110+
111+
// chunks whose type is in excludedChunkTypes_ are not counted
112+
auto& tmp = config.excluded_chunk_types();
113+
excludedChunkTypes_.insert(tmp.begin(), tmp.end());
108114
}
109115

110116
virtual void start() {
@@ -156,7 +162,8 @@ class ChunkEvaluator : public Evaluator {
156162
getSegments(label, length, labelSegments_);
157163
size_t i = 0, j = 0;
158164
while (i < outputSegments_.size() && j < labelSegments_.size()) {
159-
if (outputSegments_[i] == labelSegments_[j]) {
165+
if (outputSegments_[i] == labelSegments_[j] &&
166+
excludedChunkTypes_.count(outputSegments_[i].type) != 1) {
160167
++numCorrect_;
161168
}
162169
if (outputSegments_[i].end < labelSegments_[j].end) {
@@ -168,8 +175,12 @@ class ChunkEvaluator : public Evaluator {
168175
++j;
169176
}
170177
}
171-
numLabelSegments_ += labelSegments_.size();
172-
numOutputSegments_ += outputSegments_.size();
178+
for (auto& segment : labelSegments_) {
179+
if (excludedChunkTypes_.count(segment.type) != 1) ++numLabelSegments_;
180+
}
181+
for (auto& segment : outputSegments_) {
182+
if (excludedChunkTypes_.count(segment.type) != 1) ++numOutputSegments_;
183+
}
173184
}
174185

175186
void getSegments(int* label, int length, std::vector<Segment>& segments) {

proto/ModelConfig.proto

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,10 @@ message EvaluatorConfig {
433433
repeated string input_layers = 3;
434434

435435
// Used by ChunkEvaluator
436-
optional string chunk_scheme = 4; // one of "IOB", "IOE", "IOBES"
437-
optional int32 num_chunk_types = 5; // number of chunk types other than "other"
436+
// one of "IOB", "IOE", "IOBES"
437+
optional string chunk_scheme = 4;
438+
// number of chunk types other than "other"
439+
optional int32 num_chunk_types = 5;
438440

439441
// Used by PrecisionRecallEvaluator and ClassificationErrorEvaluator
440442
// For multi binary labels: true if output > classification_threshold
@@ -453,6 +455,10 @@ message EvaluatorConfig {
453455

454456
// whether to delimit the sequence in the seq_text_printer
455457
optional bool delimited = 11 [default = true];
458+
459+
// Used by ChunkEvaluator
460+
// chunks of these types are not counted
461+
repeated int32 excluded_chunk_types = 12;
456462
}
457463

458464
message LinkConfig {

python/paddle/trainer/config_parser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,8 @@ def Evaluator(
12401240
dict_file=None,
12411241
result_file=None,
12421242
num_results=None,
1243-
delimited=None, ):
1243+
delimited=None,
1244+
excluded_chunk_types=None, ):
12441245
evaluator = g_config.model_config.evaluators.add()
12451246
evaluator.type = type
12461247
evaluator.name = MakeLayerNameInSubmodel(name)
@@ -1269,6 +1270,9 @@ def Evaluator(
12691270
if delimited is not None:
12701271
evaluator.delimited = delimited
12711272

1273+
if excluded_chunk_types:
1274+
evaluator.excluded_chunk_types.extend(excluded_chunk_types)
1275+
12721276

12731277
class LayerBase(object):
12741278
def __init__(

python/paddle/trainer_config_helpers/evaluators.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,21 @@ def impl(method):
5757
return impl
5858

5959

60-
def evaluator_base(input,
61-
type,
62-
label=None,
63-
weight=None,
64-
name=None,
65-
chunk_scheme=None,
66-
num_chunk_types=None,
67-
classification_threshold=None,
68-
positive_label=None,
69-
dict_file=None,
70-
result_file=None,
71-
num_results=None,
72-
delimited=None):
60+
def evaluator_base(
61+
input,
62+
type,
63+
label=None,
64+
weight=None,
65+
name=None,
66+
chunk_scheme=None,
67+
num_chunk_types=None,
68+
classification_threshold=None,
69+
positive_label=None,
70+
dict_file=None,
71+
result_file=None,
72+
num_results=None,
73+
delimited=None,
74+
excluded_chunk_types=None, ):
7375
"""
7476
Evaluator will evaluate the network status while training/testing.
7577
@@ -127,7 +129,8 @@ def evaluator_base(input,
127129
positive_label=positive_label,
128130
dict_file=dict_file,
129131
result_file=result_file,
130-
delimited=delimited)
132+
delimited=delimited,
133+
excluded_chunk_types=excluded_chunk_types, )
131134

132135

133136
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
@@ -330,7 +333,8 @@ def chunk_evaluator(
330333
label,
331334
chunk_scheme,
332335
num_chunk_types,
333-
name=None, ):
336+
name=None,
337+
excluded_chunk_types=None, ):
334338
"""
335339
Chunk evaluator is used to evaluate segment labelling accuracy for a
336340
sequence. It calculates the chunk detection F1 score.
@@ -376,14 +380,17 @@ def chunk_evaluator(
376380
:param num_chunk_types: number of chunk types other than "other"
377381
:param name: The Evaluator name, it is optional.
378382
:type name: basename|None
383+
:param excluded_chunk_types: chunks of these types are not considered
384+
:type excluded_chunk_types: list of integers|None
379385
"""
380386
evaluator_base(
381387
name=name,
382388
type="chunk",
383389
input=input,
384390
label=label,
385391
chunk_scheme=chunk_scheme,
386-
num_chunk_types=num_chunk_types)
392+
num_chunk_types=num_chunk_types,
393+
excluded_chunk_types=excluded_chunk_types, )
387394

388395

389396
@evaluator(EvaluatorAttribute.FOR_UTILS)

0 commit comments

Comments
 (0)