
Commit 8a42a54

Merge pull request #915 from reyoung/feature/add_unittest
Add unittest to cover SgdThreadUpdater's enableBufType
2 parents: 80b45ad + 7aad9f5

6 files changed, +58 -2 lines changed


paddle/trainer/ThreadParameterUpdater.cpp

Lines changed: 3 additions & 0 deletions

@@ -55,6 +55,9 @@ void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
       // not create parameter buf for PARAMETER_GRADIENT for sparse update in
       // Parameter::enableType(). But gradient parameter buf is still used
       // in SgdThreadUpdater. We need to explicitly create it.
+      //
+      // The AverageOptimizer::restore/apply method will use PARAMETER_GRADIENT
+      // as a temp buffer.
       para->enableBufType(PARAMETER_GRADIENT);
     }
   }
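To make the intent of that comment concrete, here is a minimal, self-contained C++ sketch of the pattern, not Paddle's actual implementation: only the enableBufType(PARAMETER_GRADIENT) call and the reasoning about sparse updates come from the diff above, while the Parameter struct, the sparseRemoteUpdate flag, and the ensureGradientBuffers helper are hypothetical stand-ins for illustration.

// Sketch only: for sparse-remote-update parameters the gradient buffer is not
// created by Parameter::enableType(), yet the thread updater (and
// AverageOptimizer::restore/apply, which uses PARAMETER_GRADIENT as a temp
// buffer) still needs it, so it is enabled explicitly at init time.
#include <memory>
#include <vector>

enum ParameterType { PARAMETER_VALUE, PARAMETER_GRADIENT };

struct Parameter {
  bool sparseRemoteUpdate = false;  // hypothetical flag for illustration
  void enableBufType(ParameterType /*type*/) {
    // In the real class this allocates the requested buffer.
  }
};
using ParameterPtr = std::shared_ptr<Parameter>;

void ensureGradientBuffers(std::vector<ParameterPtr>& parameters) {
  for (auto& para : parameters) {
    if (para->sparseRemoteUpdate) {
      para->enableBufType(PARAMETER_GRADIENT);
    }
  }
}

The design point is simply that whoever owns the update step must guarantee the gradient buffer exists, because enableType() deliberately skips it for sparse parameters.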

paddle/trainer/tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -27,7 +27,8 @@ add_test(NAME test_Trainer
 add_unittest_without_exec(test_TrainerOnePass
     test_TrainerOnePass.cpp)
 add_test(NAME test_TrainerOnePass
-    COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
+    COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d
+        ${PROJ_ROOT}/python/:${PROJ_ROOT}/paddle/trainer/tests
         ${PROJ_ROOT}/paddle/.set_port.sh -p port ${CMAKE_CURRENT_BINARY_DIR}/test_TrainerOnePass
     WORKING_DIRECTORY ${PROJ_ROOT}/paddle/)

paddle/trainer/tests/fake_file_list.list

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+do_not_matter.txt
paddle/trainer/tests/simple_sparse_neural_network.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=128, learning_method=AdaGradOptimizer(), learning_rate=1e-4)
+
+file_list = 'trainer/tests/fake_file_list.list'
+
+define_py_data_sources2(
+    train_list=file_list,
+    test_list=file_list,
+    module="simple_sparse_neural_network_dp",
+    obj="process")
+
+embedding = embedding_layer(
+    input=data_layer(
+        name="word_ids", size=65536),
+    size=128,
+    param_attr=ParamAttr(sparse_update=True))
+prediction = fc_layer(input=embedding, size=10, act=SoftmaxActivation())
+
+outputs(
+    classification_cost(
+        input=prediction, label=data_layer(
+            name='label', size=10)))
paddle/trainer/tests/simple_sparse_neural_network_dp.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from paddle.trainer.PyDataProvider2 import provider, integer_sequence, integer_value
+import random
+
+
+def init_hook(settings, is_train, **kwargs):
+    settings.is_train = is_train
+
+
+@provider(
+    input_types={'word_ids': integer_value(65536),
+                 'label': integer_value(10)},
+    min_pool_size=0,
+    init_hook=init_hook)
+def process(settings, filename):
+    if settings.is_train:
+        data_size = 2**20
+    else:
+        data_size = 2**10
+
+    for _ in xrange(data_size):
+        yield random.randint(0, 65535), random.randint(0, 9)

paddle/trainer/tests/test_TrainerOnePass.cpp

Lines changed: 8 additions & 1 deletion

@@ -27,6 +27,9 @@ static const string& configFile1 = "trainer/tests/sample_trainer_config.conf";
 static const string& configFile2 =
     "trainer/tests/sample_trainer_config_parallel.conf";
 
+static const string& configFileSimpleSparse =
+    "trainer/tests/simple_sparse_neural_network.py";
+
 DECLARE_bool(use_gpu);
 DECLARE_string(config);
 DECLARE_int32(gpu_id);

@@ -298,11 +301,15 @@ TEST(checkRemoteUpdater, cpuDeltaTrainerOldUpdater) {
   checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true, 10);
 }
 
+TEST(SgdThreadUpdater, simpleSparseNN) {
+  trainerOnePassTest(configFileSimpleSparse, false, false, 1, 0.5, true);
+}
+
 int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
   initPython(argc, argv);
   gNumDevices = hl_get_device_count();
-  testing::InitGoogleTest(&argc, argv);
 
   FLAGS_num_passes = 1;  // train one pass
   FLAGS_saving_period = 100000;  // do not save parameteres
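Two things happen in the second hunk: a new gtest case drives one training pass over the sparse config through trainerOnePassTest, and testing::InitGoogleTest is moved ahead of initMain, presumably so that gtest consumes its own command-line flags before Paddle's initialization parses the rest. The sketch below shows how main() reads with the change applied; initMain, initPython, gNumDevices, and the FLAGS_* variables are declared elsewhere in the test file, and the closing return RUN_ALL_TESTS(); is the usual gtest idiom, assumed here because it falls outside the hunk.

// Reordered main() after this change (sketch; see caveats above).
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);  // let gtest strip its own flags first
  initMain(argc, argv);
  initPython(argc, argv);
  gNumDevices = hl_get_device_count();

  FLAGS_num_passes = 1;          // train one pass
  FLAGS_saving_period = 100000;  // do not save parameters

  return RUN_ALL_TESTS();  // assumed closing line, not shown in the hunk
}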
