Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
100796f
add 'text_classification_cnn.py' and 'text_classification_dnn.py'
JiayiFeng May 2, 2017
dfc100a
Merge pull request #1 from Canpio/dev
JiayiFeng May 2, 2017
a651dc4
remove 'import sqlite3'
JiayiFeng May 2, 2017
31ef20f
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
JiayiFeng May 3, 2017
046d0a1
update readme file
JiayiFeng May 4, 2017
868d4cb
update readme file
JiayiFeng May 4, 2017
dd827a2
update README.md and add dnn_net.png
JiayiFeng May 4, 2017
20c2ec9
update README.md
JiayiFeng May 4, 2017
77abc3c
change dnn_net.png
JiayiFeng May 4, 2017
0619165
update dnn_net.png
JiayiFeng May 4, 2017
da0eb4e
add .gitignore
JiayiFeng May 4, 2017
5a4ad99
update README.md
JiayiFeng May 5, 2017
0a39b52
update README.md
JiayiFeng May 5, 2017
1b01f6a
update README.md
JiayiFeng May 5, 2017
78531d8
update README.md
JiayiFeng May 5, 2017
55dc1a3
add cnn_net.png and update README.md
JiayiFeng May 5, 2017
ac67d5a
add commits in cnn net
JiayiFeng May 5, 2017
947e70d
add comments in cnn net
JiayiFeng May 5, 2017
763615e
finish README.md
JiayiFeng May 9, 2017
304df71
finish README.md
JiayiFeng May 9, 2017
225979c
fix problems finded in review
JiayiFeng May 9, 2017
3a712e0
change README.md
JiayiFeng May 9, 2017
f18b572
add and test auc evaluator
JiayiFeng May 10, 2017
ae6a883
Merge pull request #2 from Canpio/dev
JiayiFeng May 10, 2017
18a4bd9
remove copyright statement
JiayiFeng May 10, 2017
344a9a4
add section of 'self-define data reader' into README.md
JiayiFeng May 12, 2017
5a59fcd
update README.md
JiayiFeng May 12, 2017
55a88ad
update README.md and fine-tune network hyper-params to avoid over-fit…
JiayiFeng May 12, 2017
b0071ff
update README.md
JiayiFeng May 15, 2017
631ffdd
remove .gitignore
JiayiFeng May 16, 2017
e6cfcf1
update README.md and renew dnn_net.png
JiayiFeng May 16, 2017
f46dcb1
read->red
JiayiFeng May 17, 2017
6ff3565
read->red
JiayiFeng May 17, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions text_classification/text_classification_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import paddle.v2 as paddle
import gzip


def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128):
    """Build a text-classification network with two parallel conv-pool branches.

    The word-id sequence is embedded, fed through two
    ``sequence_conv_pool`` branches (context lengths 3 and 4), and both
    branch outputs go into one softmax fully-connected layer.

    Args:
        input_dim: Value range of the "word" input ids (callers in this file
            pass the dictionary size).
        class_dim: Number of target classes. Used both for the label layer
            and for the size of the softmax output.
        emb_dim: Size of the embedding layer.
        hid_dim: ``hidden_size`` of each ``sequence_conv_pool`` branch.

    Returns:
        A ``(cost, output)`` tuple: the classification cost layer and the
        softmax output layer.
    """
    data = paddle.layer.data(
        "word", paddle.data_type.integer_value_sequence(input_dim))
    # The label range must follow class_dim; it used to be hard-coded to 2,
    # which silently disagreed with the output layer for class_dim != 2.
    lbl = paddle.layer.data(
        "label", paddle.data_type.integer_value(class_dim))

    emb = paddle.layer.embedding(input=data, size=emb_dim)
    conv_3 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=3, hidden_size=hid_dim)
    conv_4 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=4, hidden_size=hid_dim)
    output = paddle.layer.fc(
        input=[conv_3, conv_4],
        size=class_dim,
        act=paddle.activation.Softmax())

    cost = paddle.layer.classification_cost(input=output, label=lbl)
    # NOTE(review): a reviewer asked for more than one evaluator on this
    # configuration (e.g. precision-recall / AUC next to the error rate).
    # No additional evaluator is attached in this function.

    return cost, output


def train_cnn_model(num_pass):
    """Train the convolution network on the IMDB dataset shipped with Paddle.

    After every pass the model is evaluated on the test reader and the
    parameters are saved twice: to a per-pass snapshot
    ``cnn_params_pass_<pass_id>.tar.gz`` (so earlier passes stay available
    for model selection) and to ``cnn_params.tar.gz`` (the latest model).

    Args:
        num_pass: Number of training passes handed to ``trainer.train``.
    """
    # load word dictionary
    print('load dictionary...')
    word_dict = paddle.dataset.imdb.word_dict()

    dict_dim = len(word_dict)
    class_dim = 2

    # define data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
        batch_size=100)
    test_reader = paddle.batch(
        lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)

    # network config
    [cost, _] = convolution_net(dict_dim, class_dim=class_dim)
    # create parameters
    parameters = paddle.parameters.create(cost)
    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=2e-3,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # create trainer
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=adam_optimizer)

    # Maps data-layer names to positions in each sample tuple. Defined before
    # the event handler because the handler reads it when testing.
    feeding = {'word': 0, 'label': 1}

    # Define end batch and end pass event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
            else:
                # progress dot for the batches in between
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
            # Keep one snapshot per pass and also refresh the fixed-name
            # file, which is the path the inference code in this file reads.
            snapshot_paths = (
                "cnn_params_pass_%d.tar.gz" % event.pass_id,
                "cnn_params.tar.gz",
            )
            for path in snapshot_paths:
                with gzip.open(path, 'w') as f:
                    parameters.to_tar(f)

    # begin training network
    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_pass)

    print("Training finished.")


def cnn_infer():
    """Run the trained CNN on the first test samples and print the results.

    Rebuilds the network, loads parameters from ``cnn_params.tar.gz``, and
    prints one line per sample: the predicted value followed by the label
    taken from the dataset.
    """
    print("Begin to predict...")

    word_dict = paddle.dataset.imdb.word_dict()
    dict_dim = len(word_dict)
    class_dim = 2

    [_, output] = convolution_net(dict_dim, class_dim=class_dim)
    # Close the archive once the parameters are loaded; the handle used to
    # be left open.
    with gzip.open("cnn_params.tar.gz") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    infer_data = []
    infer_data_label = []
    # NOTE(review): a reviewer suggested dropping this cap on the number of
    # test samples; it is kept here so the printed output stays unchanged.
    infer_data_num = 100
    for item in paddle.dataset.imdb.test(word_dict):
        # item[0] is fed as the network input, item[1] is kept as the label
        infer_data.append([item[0]])
        infer_data_label.append(item[1])
        if len(infer_data) == infer_data_num:
            break

    predictions = paddle.infer(
        output_layer=output,
        parameters=parameters,
        input=infer_data,
        field=['value'])
    for prob, label in zip(predictions, infer_data_label):
        # Same text as the former two-item print statement.
        print("%s %s" % (prob, label))


if __name__ == "__main__":
    # Script entry point: initialise Paddle on CPU, train for ten passes,
    # then run inference with the parameters saved by training.
    paddle.init(
        use_gpu=False,
        trainer_count=10)
    train_cnn_model(
        num_pass=10)
    cnn_infer()
138 changes: 138 additions & 0 deletions text_classification/text_classification_dnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import paddle.v2 as paddle
import gzip


def fc_net(input_dim, class_dim=2, emb_dim=256):
    """Build a fully-connected text-classification network.

    The word-id sequence is embedded, max-pooled over the sequence, passed
    through two tanh hidden layers (128 and 32 units), and classified by a
    softmax layer.

    Args:
        input_dim: Value range of the "word" input ids (callers in this file
            pass the dictionary size).
        class_dim: Number of target classes; sets the label range and the
            size of the softmax output.
        emb_dim: Size of the embedding layer.

    Returns:
        A ``(cost, output)`` tuple: the classification cost layer and the
        softmax output layer.
    """
    word_type = paddle.data_type.integer_value_sequence(input_dim)
    words = paddle.layer.data("word", word_type)
    label_type = paddle.data_type.integer_value(class_dim)
    label = paddle.layer.data("label", label_type)

    embedded = paddle.layer.embedding(input=words, size=emb_dim)
    # NOTE(review): a reviewer asked whether max pooling is actually better
    # than average or sum pooling here; left as max pooling.
    hidden = paddle.layer.pooling(
        input=embedded, pooling_type=paddle.pooling.Max())

    # Two tanh hidden layers, created in the same order as before.
    # NOTE(review): a reviewer asked whether initialising the parameter
    # matrix according to layer size would work better than a fixed std.
    for width in (128, 32):
        hidden = paddle.layer.fc(
            input=hidden,
            size=width,
            act=paddle.activation.Tanh(),
            param_attr=paddle.attr.Param(initial_std=0.01))

    prediction = paddle.layer.fc(
        input=hidden,
        size=class_dim,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(initial_std=0.1))

    loss = paddle.layer.classification_cost(input=prediction, label=label)
    return loss, prediction


def train_dnn_model(num_pass):
    """Train the fully-connected network on the IMDB dataset shipped with Paddle.

    After every pass the model is evaluated on the test reader and the
    parameters are saved twice: to a per-pass snapshot
    ``dnn_params_pass_<pass_id>.tar.gz`` (so earlier passes stay available
    for model selection) and to ``dnn_params.tar.gz`` (the latest model).

    Args:
        num_pass: Number of training passes handed to ``trainer.train``.
    """
    # load word dictionary
    print('load dictionary...')
    word_dict = paddle.dataset.imdb.word_dict()

    dict_dim = len(word_dict)
    class_dim = 2

    # define data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
        batch_size=100)
    test_reader = paddle.batch(
        lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)

    # network config
    [cost, _] = fc_net(dict_dim, class_dim=class_dim)
    # create parameters
    parameters = paddle.parameters.create(cost)
    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=2e-3,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # create trainer
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=adam_optimizer)

    # Maps data-layer names to positions in each sample tuple. Defined before
    # the event handler because the handler reads it when testing.
    feeding = {'word': 0, 'label': 1}

    # Define end batch and end pass event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
            else:
                # progress dot for the batches in between
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
            # Keep one snapshot per pass and also refresh the fixed-name
            # file, which is the path the inference code in this file reads.
            snapshot_paths = (
                "dnn_params_pass_%d.tar.gz" % event.pass_id,
                "dnn_params.tar.gz",
            )
            for path in snapshot_paths:
                with gzip.open(path, 'w') as f:
                    parameters.to_tar(f)

    # begin training network
    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_pass)

    print("Training finished.")


def dnn_infer():
    """Run the trained DNN on the first test samples and print the results.

    Rebuilds the network, loads parameters from ``dnn_params.tar.gz``, and
    prints one line per sample: the predicted value followed by the label
    taken from the dataset.
    """
    print("Begin to predict...")

    word_dict = paddle.dataset.imdb.word_dict()
    dict_dim = len(word_dict)
    class_dim = 2

    [_, output] = fc_net(dict_dim, class_dim=class_dim)
    # Close the archive once the parameters are loaded; the handle used to
    # be left open.
    # NOTE(review): a reviewer asked for models to be saved (and therefore
    # selected) by pass number; this function still loads the fixed-name
    # file holding the latest model.
    with gzip.open("dnn_params.tar.gz") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    infer_data = []
    infer_data_label = []
    # Only the first infer_data_num test samples are used.
    infer_data_num = 100
    for item in paddle.dataset.imdb.test(word_dict):
        # item[0] is fed as the network input, item[1] is kept as the label
        infer_data.append([item[0]])
        infer_data_label.append(item[1])
        if len(infer_data) == infer_data_num:
            break

    predictions = paddle.infer(
        output_layer=output,
        parameters=parameters,
        input=infer_data,
        field=['value'])
    for prob, label in zip(predictions, infer_data_label):
        # Same text as the former two-item print statement.
        print("%s %s" % (prob, label))


if __name__ == "__main__":
    # Script entry point: initialise Paddle on CPU, train for two passes,
    # then run inference with the parameters saved by training.
    paddle.init(
        use_gpu=False,
        trainer_count=4)
    train_dnn_model(num_pass=2)
    dnn_infer()