Skip to content

Commit f884a61

Browse files
committed
rewrite the text classification demo.
1 parent 4f0d8ac commit f884a61

File tree

14 files changed

+764
-694
lines changed

14 files changed

+764
-694
lines changed

image_classification/reader.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License
14-
151
import random
162
from paddle.v2.image import load_and_transform
173

image_classification/train.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License
14-
151
import gzip
162

173
import paddle.v2 as paddle
@@ -51,10 +37,10 @@ def main():
5137
learning_rate_schedule="discexp", )
5238

5339
train_reader = paddle.batch(
54-
paddle.reader.shuffle(reader.test_reader("train.list"), buf_size=1000),
40+
paddle.reader.shuffle(reader.train_reader("train.list"), buf_size=1000),
5541
batch_size=BATCH_SIZE)
5642
test_reader = paddle.batch(
57-
reader.train_reader("test.list"), batch_size=BATCH_SIZE)
43+
reader.test_reader("test.list"), batch_size=BATCH_SIZE)
5844

5945
# End batch and end pass event handler
6046
def event_handler(event):

image_classification/vgg.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
151
import paddle.v2 as paddle
162

173
__all__ = ['vgg13', 'vgg16', 'vgg19']

text_classification/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
data
2+
*.log
3+
*.pyc

text_classification/README.md

Lines changed: 125 additions & 172 deletions
Large diffs are not rendered by default.

text_classification/index.html

Lines changed: 125 additions & 172 deletions
Large diffs are not rendered by default.

text_classification/infer.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
import sys
4+
import os
5+
import gzip
6+
7+
import paddle.v2 as paddle
8+
from paddle.v2.layer import parse_network
9+
10+
import network_conf
11+
import reader
12+
from utils import *
13+
14+
15+
def infer(topology, data_dir, word_dict_path, model_path, batch_size=50):
    """Load a trained model and print class probabilities for test data.

    :param topology: network configuration function (e.g. fc_net) that
        returns the output probability layer when called with is_infer=True
    :type topology: callable
    :param data_dir: directory holding the test data files; when None the
        built-in IMDB test set is used instead
    :type data_dir: str|None
    :param word_dict_path: path of the word dictionary file (required when
        data_dir is given)
    :type word_dict_path: str
    :param model_path: path of the gzipped trained parameter tarball
    :type model_path: str
    :param batch_size: number of samples predicted per inference call
    :type batch_size: int
    """

    def _infer_a_batch(inferer, test_batch):
        # print the probability vector of each sample in the batch
        probs = inferer.infer(input=test_batch, field=['value'])
        for i, prob in enumerate(probs):
            print(prob)

    print("Begin to predict...")
    use_default_data = (data_dir is None)

    if use_default_data:
        word_dict = paddle.dataset.imdb.word_dict()
        test_reader = paddle.dataset.imdb.test(word_dict)
    else:
        assert os.path.exists(
            word_dict_path), "word dictionary file does not exist"
        word_dict = load_dict(word_dict_path)
        test_reader = reader.test_reader(data_dir, word_dict)()

    dict_dim = len(word_dict)
    # NOTE(review): class_num is hard-coded to 6 here while the __main__
    # driver sets class_num = 2 -- confirm against the trained model.
    prob = topology(dict_dim, class_num=6, is_infer=True)

    # initialize PaddlePaddle
    paddle.init(use_gpu=False, trainer_count=1)

    # load the trained models
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(model_path, "r"))
    inferer = paddle.inference.Inference(
        output_layer=prob, parameters=parameters)

    test_batch = []
    for idx, item in enumerate(test_reader):
        test_batch.append([item[0]])
        if idx and (not (idx + 1) % batch_size):
            _infer_a_batch(inferer, test_batch)
            test_batch = []

    # Bug fix: the original called the undefined names ``infer_a_batch`` and
    # ``test_data`` here (NameError on the final partial batch); flush the
    # remainder only when something is left.
    if test_batch:
        _infer_a_batch(inferer, test_batch)
        test_batch = []
55+
56+
if __name__ == "__main__":
    # Hard-coded inference settings; edit these to try another model.
    nn_type = "dnn"
    class_num = 2
    model_path = "dnn_params_pass_00000.tar.gz"
    test_dir = None
    word_dict = None

    # Pick the topology matching the architecture the model was trained with.
    if nn_type == "dnn":
        topology = network_conf.fc_net
    elif nn_type == "cnn":
        topology = network_conf.convolution_net

    infer(
        topology=topology,
        model_path=model_path,
        data_dir=test_dir,
        word_dict_path=word_dict)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import sys
2+
import math
3+
import gzip
4+
5+
from paddle.v2.layer import parse_network
6+
import paddle.v2 as paddle
7+
8+
__all__ = ["fc_net", "convolution_net"]
9+
10+
11+
def fc_net(dict_dim,
           class_num,
           emb_dim=28,
           hidden_layer_sizes=(28, 8),
           is_infer=False):
    """
    define the topology of the dnn network

    :param dict_dim: size of the word dictionary
    :type dict_dim: int
    :param class_num: number of instance classes
    :type class_num: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hidden_layer_sizes: width of each stacked hidden fc layer
    :type hidden_layer_sizes: sequence of int
    :param is_infer: return only the probability layer when True,
        otherwise return (cost, prob, label) for training
    :type is_infer: bool
    """
    # NOTE: the default above is a tuple, not a list, to avoid the shared
    # mutable-default-argument pitfall of the original signature.

    # define the input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    if not is_infer:
        lbl = paddle.layer.data("label",
                                paddle.data_type.integer_value(class_num))

    # define the embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)
    # max pooling to reduce the input sequence into a vector (non-sequence)
    seq_pool = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())

    # Stack the hidden fc layers. Seeding ``hidden`` with the pooled vector
    # also fixes the original NameError when hidden_layer_sizes is empty.
    hidden = seq_pool
    for hidden_size in hidden_layer_sizes:
        # 1/sqrt(fan) init keeps tanh activations well scaled
        hidden_init_std = 1.0 / math.sqrt(hidden_size)
        hidden = paddle.layer.fc(
            input=hidden,
            size=hidden_size,
            act=paddle.activation.Tanh(),
            param_attr=paddle.attr.Param(initial_std=hidden_init_std))

    prob = paddle.layer.fc(
        input=hidden,
        size=class_num,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(initial_std=1.0 / math.sqrt(class_num)))

    if is_infer:
        return prob
    else:
        return paddle.layer.classification_cost(
            input=prob, label=lbl), prob, lbl
59+
60+
61+
def convolution_net(dict_dim, class_dim=2, emb_dim=28, hid_dim=128):
    """
    cnn network definition

    :param dict_dim: size of the word dictionary
    :type dict_dim: int
    :param class_dim: number of instance classes
    :type class_dim: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hid_dim: number of same size convolution kernels
    :type hid_dim: int
    """

    # input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    # Bug fix: the label layer previously hard-coded 2 classes, silently
    # ignoring class_dim; use the parameter so class_dim != 2 works.
    lbl = paddle.layer.data("label", paddle.data_type.integer_value(class_dim))

    # embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    # convolution layers with max pooling (context windows of 3 and 4 words)
    conv_3 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=3, hidden_size=hid_dim)
    conv_4 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=4, hidden_size=hid_dim)

    # fc and output layer
    output = paddle.layer.fc(
        input=[conv_3, conv_4], size=class_dim, act=paddle.activation.Softmax())

    cost = paddle.layer.classification_cost(input=output, label=lbl)

    return cost, output, lbl

text_classification/reader.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
import os
4+
5+
6+
def train_reader(data_dir, word_dict, label_dict):
    """
    Reader interface for training data

    :param data_dir: data directory; every file in it is read as
        tab-separated "label<TAB>text" lines
    :type data_dir: str
    :param word_dict: word dictionary mapping word -> id,
        the dictionary must have a "<UNK>" entry in it
    :type word_dict: Python dict
    :param label_dict: label dictionary mapping label string -> id
    :type label_dict: Python dict
    """

    def reader():
        UNK_ID = word_dict["<UNK>"]
        word_col = 1
        lbl_col = 0

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    # Robustness fix: skip malformed lines that lack the
                    # text column (the original raised IndexError here);
                    # this also mirrors test_reader's behavior.
                    if len(line_split) <= word_col:
                        continue
                    word_ids = [
                        word_dict.get(w, UNK_ID)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, label_dict[line_split[lbl_col]]

    return reader
35+
36+
37+
def test_reader(data_dir, word_dict):
38+
"""
39+
Reader interface for testing data
40+
41+
:param data_dir: data directory.
42+
:type data_dir: str
43+
:param word_dict: path of word dictionary,
44+
the dictionary must has a "UNK" in it.
45+
:type word_dict: Python dict
46+
"""
47+
48+
def reader():
49+
UNK_ID = word_dict["<UNK>"]
50+
word_col = 1
51+
52+
for file_name in os.listdir(data_dir):
53+
with open(os.path.join(data_dir, file_name), "r") as f:
54+
for line in f:
55+
line_split = line.strip().split("\t")
56+
if len(line_split) < word_col: continue
57+
word_ids = [
58+
word_dict.get(w, UNK_ID)
59+
for w in line_split[word_col].split()
60+
]
61+
yield word_ids, line_split[word_col]
62+
63+
return reader

text_classification/run.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/bin/sh
2+
3+
python train.py \
4+
--nn_type="dnn" \
5+
--batch_size=64 \
6+
--num_passes=10 \
7+
2>&1 | tee train.log

0 commit comments

Comments
 (0)