Skip to content

Commit d40d28d

Browse files
authored
Merge pull request #6515 from guoshengCS/add-multiBatch-chunkEval
Add ChunkEvaluator for Multi-batches
2 parents 78c20e3 + a7fa205 commit d40d28d

File tree

6 files changed

+151
-31
lines changed

6 files changed

+151
-31
lines changed

paddle/operators/chunk_eval_op.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
3232
"Output(Recall) of ChunkEvalOp should not be null.");
3333
PADDLE_ENFORCE(ctx->HasOutput("F1-Score"),
3434
"Output(F1-Score) of ChunkEvalOp should not be null.");
35+
PADDLE_ENFORCE(ctx->HasOutput("NumInferChunks"),
36+
"Output(NumInferChunks) of ChunkEvalOp should not be null.");
37+
PADDLE_ENFORCE(ctx->HasOutput("NumLabelChunks"),
38+
"Output(NumLabelChunks) of ChunkEvalOp should not be null.");
39+
PADDLE_ENFORCE(
40+
ctx->HasOutput("NumCorrectChunks"),
41+
"Output(NumCorrectChunks) of ChunkEvalOp should not be null.");
3542

3643
auto inference_dim = ctx->GetInputDim("Inference");
3744
auto label_dim = ctx->GetInputDim("Label");
@@ -42,6 +49,9 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
4249
ctx->SetOutputDim("Precision", {1});
4350
ctx->SetOutputDim("Recall", {1});
4451
ctx->SetOutputDim("F1-Score", {1});
52+
ctx->SetOutputDim("NumInferChunks", {1});
53+
ctx->SetOutputDim("NumLabelChunks", {1});
54+
ctx->SetOutputDim("NumCorrectChunks", {1});
4555
}
4656

4757
protected:
@@ -70,6 +80,16 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
7080
"sensitivity) of chunks on the given mini-batch.");
7181
AddOutput("F1-Score",
7282
"(float). The evaluated F1-Score on the given mini-batch.");
83+
AddOutput("NumInferChunks",
84+
"(int64_t). The number of chunks in Inference on the given "
85+
"mini-batch.");
86+
AddOutput(
87+
"NumLabelChunks",
88+
"(int64_t). The number of chunks in Label on the given mini-batch.");
89+
AddOutput(
90+
"NumCorrectChunks",
91+
"(int64_t). The number of chunks both in Inference and Label on the "
92+
"given mini-batch.");
7393
AddAttr<int>("num_chunk_types",
7494
"(int). The number of chunk type. See below for details.");
7595
AddAttr<std::string>(

paddle/operators/chunk_eval_op.h

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
111111
std::vector<Segment> label_segments;
112112
std::vector<Segment> output_segments;
113113
std::set<int> excluded_chunk_types;
114-
int64_t num_output_segments = 0;
115-
int64_t num_label_segments = 0;
116-
int64_t num_correct = 0;
114+
117115
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
118116
num_tag_types = 2;
119117
tag_begin = 0;
@@ -151,12 +149,24 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
151149
auto* precision = context.Output<Tensor>("Precision");
152150
auto* recall = context.Output<Tensor>("Recall");
153151
auto* f1 = context.Output<Tensor>("F1-Score");
152+
auto* num_infer_chunks = context.Output<Tensor>("NumInferChunks");
153+
auto* num_label_chunks = context.Output<Tensor>("NumLabelChunks");
154+
auto* num_correct_chunks = context.Output<Tensor>("NumCorrectChunks");
154155

155156
const int64_t* inference_data = inference->data<int64_t>();
156157
const int64_t* label_data = label->data<int64_t>();
157158
T* precision_data = precision->mutable_data<T>(context.GetPlace());
158159
T* racall_data = recall->mutable_data<T>(context.GetPlace());
159160
T* f1_data = f1->mutable_data<T>(context.GetPlace());
161+
int64_t* num_infer_chunks_data =
162+
num_infer_chunks->mutable_data<int64_t>(context.GetPlace());
163+
int64_t* num_label_chunks_data =
164+
num_label_chunks->mutable_data<int64_t>(context.GetPlace());
165+
int64_t* num_correct_chunks_data =
166+
num_correct_chunks->mutable_data<int64_t>(context.GetPlace());
167+
*num_infer_chunks_data = 0;
168+
*num_label_chunks_data = 0;
169+
*num_correct_chunks_data = 0;
160170

161171
auto lod = label->lod();
162172
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
@@ -166,17 +176,23 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
166176
for (int i = 0; i < num_sequences; ++i) {
167177
int seq_length = lod[0][i + 1] - lod[0][i];
168178
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
169-
output_segments, label_segments, num_output_segments,
170-
num_label_segments, num_correct, num_chunk_types,
171-
num_tag_types, other_chunk_type, tag_begin, tag_inside,
172-
tag_end, tag_single, excluded_chunk_types);
179+
output_segments, label_segments, *num_infer_chunks_data,
180+
*num_label_chunks_data, *num_correct_chunks_data,
181+
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
182+
tag_inside, tag_end, tag_single, excluded_chunk_types);
173183
}
174-
*precision_data = !num_output_segments ? 0 : static_cast<T>(num_correct) /
175-
num_output_segments;
176-
*racall_data = !num_label_segments ? 0 : static_cast<T>(num_correct) /
177-
num_label_segments;
178-
*f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) /
179-
((*precision_data) + (*racall_data));
184+
*precision_data = !(*num_infer_chunks_data)
185+
? 0
186+
: static_cast<T>(*num_correct_chunks_data) /
187+
(*num_infer_chunks_data);
188+
*racall_data = !(*num_label_chunks_data)
189+
? 0
190+
: static_cast<T>(*num_correct_chunks_data) /
191+
(*num_label_chunks_data);
192+
*f1_data = !(*num_correct_chunks_data)
193+
? 0
194+
: 2 * (*precision_data) * (*racall_data) /
195+
((*precision_data) + (*racall_data));
180196
}
181197

182198
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,

python/paddle/v2/fluid/evaluator.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from framework import Program, unique_name, Variable
55
from layer_helper import LayerHelper
66

7-
__all__ = ['Accuracy']
7+
__all__ = ['Accuracy', 'ChunkEvaluator']
88

99

1010
def _clone_var_(block, var):
@@ -132,3 +132,74 @@ def eval(self, executor, eval_program=None):
132132
correct = layers.cast(correct, dtype='float32', **kwargs)
133133
out = layers.elementwise_div(x=correct, y=total, **kwargs)
134134
return np.array(executor.run(eval_program, fetch_list=[out])[0])
135+
136+
137+
class ChunkEvaluator(Evaluator):
138+
"""
139+
Accumulate counter numbers output by chunk_eval from mini-batches and
140+
compute the precision recall and F1-score using the accumulated counter
141+
numbers.
142+
"""
143+
144+
def __init__(self,
145+
input,
146+
label,
147+
chunk_scheme,
148+
num_chunk_types,
149+
excluded_chunk_types=None,
150+
**kwargs):
151+
super(ChunkEvaluator, self).__init__("chunk_eval", **kwargs)
152+
main_program = self.helper.main_program
153+
if main_program.current_block().idx != 0:
154+
raise ValueError("You can only invoke Evaluator in root block")
155+
156+
self.num_infer_chunks = self.create_state(
157+
dtype='int64', shape=[1], suffix='num_infer_chunks')
158+
self.num_label_chunks = self.create_state(
159+
dtype='int64', shape=[1], suffix='num_label_chunks')
160+
self.num_correct_chunks = self.create_state(
161+
dtype='int64', shape=[1], suffix='num_correct_chunks')
162+
kwargs = {'main_program': main_program}
163+
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval(
164+
input=input,
165+
label=label,
166+
chunk_scheme=chunk_scheme,
167+
num_chunk_types=num_chunk_types,
168+
excluded_chunk_types=excluded_chunk_types,
169+
**kwargs)
170+
layers.sums(
171+
input=[self.num_infer_chunks, num_infer_chunks],
172+
out=self.num_infer_chunks,
173+
**kwargs)
174+
layers.sums(
175+
input=[self.num_label_chunks, num_label_chunks],
176+
out=self.num_label_chunks,
177+
**kwargs)
178+
layers.sums(
179+
input=[self.num_correct_chunks, num_correct_chunks],
180+
out=self.num_correct_chunks,
181+
**kwargs)
182+
183+
self.metrics.extend([precision, recall, f1_score])
184+
185+
def eval(self, executor, eval_program=None):
186+
if eval_program is None:
187+
eval_program = Program()
188+
block = eval_program.current_block()
189+
kwargs = {'main_program': eval_program}
190+
num_infer_chunks, num_label_chunks, num_correct_chunks = executor.run(
191+
eval_program,
192+
fetch_list=[_clone_var_(block, state) for state in self.states])
193+
num_infer_chunks = num_infer_chunks[0]
194+
num_label_chunks = num_label_chunks[0]
195+
num_correct_chunks = num_correct_chunks[0]
196+
precision = float(
197+
num_correct_chunks) / num_infer_chunks if num_infer_chunks else 0
198+
recall = float(
199+
num_correct_chunks) / num_label_chunks if num_label_chunks else 0
200+
f1_score = float(2 * precision * recall) / (
201+
precision + recall) if num_correct_chunks else 0
202+
return np.array(
203+
[precision], dtype='float32'), np.array(
204+
[recall], dtype='float32'), np.array(
205+
[f1_score], dtype='float32')

python/paddle/v2/fluid/layers/nn.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,15 +392,18 @@ def chunk_eval(input,
392392
excluded_chunk_types=None,
393393
**kwargs):
394394
"""
395-
This function computes the accuracy using the input and label.
396-
The output is the top_k inputs and their indices.
395+
This function computes and outputs the precision, recall and
396+
F1-score of chunk detection.
397397
"""
398398
helper = LayerHelper("chunk_eval", **kwargs)
399399

400400
# prepare output
401401
precision = helper.create_tmp_variable(dtype="float32")
402402
recall = helper.create_tmp_variable(dtype="float32")
403403
f1_score = helper.create_tmp_variable(dtype="float32")
404+
num_infer_chunks = helper.create_tmp_variable(dtype="int64")
405+
num_label_chunks = helper.create_tmp_variable(dtype="int64")
406+
num_correct_chunks = helper.create_tmp_variable(dtype="int64")
404407

405408
helper.append_op(
406409
type="chunk_eval",
@@ -409,14 +412,17 @@ def chunk_eval(input,
409412
outputs={
410413
"Precision": [precision],
411414
"Recall": [recall],
412-
"F1-Score": [f1_score]
415+
"F1-Score": [f1_score],
416+
"NumInferChunks": [num_infer_chunks],
417+
"NumLabelChunks": [num_label_chunks],
418+
"NumCorrectChunks": [num_correct_chunks]
413419
},
414420
attrs={
415421
"num_chunk_types": num_chunk_types,
416422
'chunk_scheme': chunk_scheme,
417423
'excluded_chunk_types': excluded_chunk_types or []
418424
})
419-
return precision, recall, f1_score
425+
return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
420426

421427

422428
def sequence_conv(input,

python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def main():
150150
crf_decode = fluid.layers.crf_decoding(
151151
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
152152

153-
precision, recall, f1_score = fluid.layers.chunk_eval(
153+
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
154154
input=crf_decode,
155155
label=target,
156156
chunk_scheme="IOB",
@@ -176,20 +176,21 @@ def main():
176176

177177
batch_id = 0
178178
for pass_id in xrange(PASS_NUM):
179+
chunk_evaluator.reset(exe)
179180
for data in train_data():
180-
outs = exe.run(fluid.default_main_program(),
181-
feed=feeder.feed(data),
182-
fetch_list=[avg_cost, precision, recall, f1_score])
183-
avg_cost_val = np.array(outs[0])
184-
precision_val = np.array(outs[1])
185-
recall_val = np.array(outs[2])
186-
f1_score_val = np.array(outs[3])
181+
cost, precision, recall, f1_score = exe.run(
182+
fluid.default_main_program(),
183+
feed=feeder.feed(data),
184+
fetch_list=[avg_cost] + chunk_evaluator.metrics)
185+
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
186+
exe)
187187

188188
if batch_id % 10 == 0:
189-
print("avg_cost=" + str(avg_cost_val))
190-
print("precision_val=" + str(precision_val))
191-
print("recall_val:" + str(recall_val))
192-
print("f1_score_val:" + str(f1_score_val))
189+
print("avg_cost:" + str(cost) + " precision:" + str(
190+
precision) + " recall:" + str(recall) + " f1_score:" + str(
191+
f1_score) + " pass_precision:" + str(
192+
pass_precision) + " pass_recall:" + str(pass_recall)
193+
+ " pass_f1_score:" + str(pass_f1_score))
193194

194195
# exit early for CI
195196
exit(0)

python/paddle/v2/fluid/tests/test_chunk_eval_op.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,13 @@ def set_data(self):
147147
'Recall': np.asarray(
148148
[recall], dtype='float32'),
149149
'F1-Score': np.asarray(
150-
[f1], dtype='float32')
150+
[f1], dtype='float32'),
151+
'NumInferChunks': np.asarray(
152+
[self.num_infer_chunks], dtype='int64'),
153+
'NumLabelChunks': np.asarray(
154+
[self.num_label_chunks], dtype='int64'),
155+
'NumCorrectChunks': np.asarray(
156+
[self.num_correct_chunks], dtype='int64')
151157
}
152158

153159
def setUp(self):

0 commit comments

Comments
 (0)