Skip to content

Commit cbddad3

Browse files
committed
Try to add a unit test for the SGD local updater
1 parent b0c6331 commit cbddad3

9 files changed

Lines changed: 149 additions & 7 deletions

File tree

paddle/trainer/ThreadParameterUpdater.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,17 @@ void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
5151
size_t numRows = para->isGradSparseUpdate() ? para->getConfig().dims(0) : 0;
5252
optimizers_[pid]->init(numRows, &para->getConfig());
5353
if (para->isGradSparseUpdate() && FLAGS_trainer_count == 1) {
54-
// For trainer_count=1, the gradient machine is NeuralNetwork, which does
55-
// not create parameter buf for PARAMETER_GRADIENT for sparse update in
56-
// Parameter::enableType(). But gradient parameter buf is still used
57-
// in SgdThreadUpdater. We need to explicitly create it.
58-
para->enableBufType(PARAMETER_GRADIENT);
54+
LOG(INFO) << "I'm here";
55+
// // For trainer_count=1, the gradient machine is NeuralNetwork,
56+
// which
57+
// does
58+
// // not create parameter buf for PARAMETER_GRADIENT for sparse
59+
// update
60+
// in
61+
// // Parameter::enableType(). But gradient parameter buf is still
62+
// used
63+
// // in SgdThreadUpdater. We need to explicitly create it.
64+
// para->enableBufType(PARAMETER_GRADIENT);
5965
}
6066
}
6167
}

paddle/trainer/TrainerConfigHelper.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ std::shared_ptr<TrainerConfigHelper> TrainerConfigHelper::createFromFlags() {
193193
std::shared_ptr<TrainerConfigHelper>
194194
TrainerConfigHelper::createFromFlagConfig() {
195195
CHECK(!FLAGS_config.empty());
196-
return std::make_shared<TrainerConfigHelper>(FLAGS_config);
196+
return create(FLAGS_config);
197+
}
198+
199+
std::shared_ptr<TrainerConfigHelper> TrainerConfigHelper::create(
200+
const std::string &configFilename) {
201+
return std::make_shared<TrainerConfigHelper>(configFilename);
197202
}
198203

199204
} // namespace paddle

paddle/trainer/TrainerConfigHelper.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ class TrainerConfigHelper /*final*/ {
193193
*/
194194
static std::shared_ptr<TrainerConfigHelper> createFromFlagConfig();
195195

196+
/**
197+
* @brief Create TrainerConfigHelper from config file.
198+
* @param configFilename config file path.
199+
* @return a TrainerConfigHelper constructed from the given config file.
200+
*/
201+
static std::shared_ptr<TrainerConfigHelper> create(
202+
const std::string& configFilename);
203+
196204
private:
197205
static std::string getConfigNameFromPassId(int passId,
198206
const std::string& modelPath);

paddle/trainer/tests/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,14 @@ add_test(NAME test_config_parser
8383
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
8484
python ${PROJ_ROOT}/paddle/trainer/tests/config_parser_test.py
8585
WORKING_DIRECTORY ${PROJ_ROOT}/paddle/)
86+
87+
88+
############# test_SgdLocalUpdaterForSparseNetwork ###########
89+
add_unittest_without_exec(test_SgdLocalUpdaterForSparseNetwork
90+
test_SgdLocalUpdaterForSparseNetwork.cpp)
91+
92+
add_test(NAME test_SgdLocalUpdaterForSparseNetwork
93+
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d
94+
${PROJ_ROOT}/python/
95+
${CMAKE_CURRENT_BINARY_DIR}/test_SgdLocalUpdaterForSparseNetwork
96+
WORKING_DIRECTORY ${PROJ_ROOT}/paddle/trainer/tests/sgd_local_updater_sparse_network/)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
train.list
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from paddle.trainer_config_helpers import *
2+
3+
define_py_data_sources2(
4+
train_list=["do_not_matter.txt"],
5+
test_list=None,
6+
module='sparse_updated_network_provider',
7+
obj='process')
8+
9+
settings(batch_size=100, learning_rate=1e-4)
10+
11+
outputs(
12+
classification_cost(
13+
input=fc_layer(
14+
size=10,
15+
act=SoftmaxActivation(),
16+
input=embedding_layer(
17+
size=64,
18+
input=data_layer(
19+
name='word_id', size=600000),
20+
param_attr=ParamAttr(sparse_update=True))),
21+
label=data_layer(
22+
name='label', size=10)))
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from paddle.trainer.PyDataProvider2 import *
2+
import random
3+
4+
5+
@provider(
6+
input_types={"word_id": integer_value(600000),
7+
"label": integer_value(10)},
8+
min_pool_size=0)
9+
def process(settings, filename):
10+
for _ in xrange(1000):
11+
yield random.randint(0, 600000 - 1), random.randint(0, 9)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/pserver/ParameterServer2.h"
17+
#include "paddle/trainer/Trainer.h"
18+
#include "paddle/utils/PythonUtil.h"
19+
#include "paddle/utils/Util.h"
20+
21+
P_DECLARE_bool(local);
22+
23+
static std::unique_ptr<paddle::Trainer> createTrainer(
24+
bool useGpu, size_t trainerCount, const std::string& configFilename) {
25+
FLAGS_use_gpu = useGpu;
26+
FLAGS_trainer_count = trainerCount;
27+
paddle::Trainer* trainer = new paddle::Trainer();
28+
29+
trainer->init(paddle::TrainerConfigHelper::create(configFilename));
30+
return std::unique_ptr<paddle::Trainer>(trainer);
31+
}
32+
33+
TEST(SgdLocalUpdater, RemoteSparseNNCpu) {
34+
FLAGS_ports_num_for_sparse = 1;
35+
FLAGS_num_passes = 1;
36+
FLAGS_local = false;
37+
std::vector<std::shared_ptr<paddle::ParameterServer2>> pservers;
38+
39+
for (int i = 0; i < FLAGS_ports_num + FLAGS_ports_num_for_sparse; ++i) {
40+
auto pserver =
41+
std::make_shared<paddle::ParameterServer2>("127.0.0.1", FLAGS_port + i);
42+
pserver->init();
43+
pserver->start();
44+
pservers.push_back(pserver);
45+
}
46+
47+
auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py");
48+
ASSERT_TRUE(trainerPtr != nullptr);
49+
paddle::Trainer& trainer = *trainerPtr;
50+
trainer.startTrain();
51+
trainer.train(1);
52+
trainer.finishTrain();
53+
}
54+
55+
TEST(SgdLocalUpdater, LocalSparseNNCpu) {
56+
FLAGS_local = true;
57+
auto trainerPtr = createTrainer(false, 1, "sparse_updated_network.py");
58+
ASSERT_TRUE(trainerPtr != nullptr);
59+
paddle::Trainer& trainer = *trainerPtr;
60+
trainer.startTrain();
61+
trainer.train(1);
62+
trainer.finishTrain();
63+
}
64+
// TEST(SgdLocalUpdater, SparseNNGpu) {
65+
// auto trainerPtr = createTrainer(true, 1, "sparse_updated_network.py");
66+
// ASSERT_TRUE(trainerPtr != nullptr);
67+
// paddle::Trainer& trainer = *trainerPtr;
68+
// trainer.startTrain();
69+
// trainer.train(1);
70+
// trainer.finishTrain();
71+
//}
72+
73+
int main(int argc, char** argv) {
74+
testing::InitGoogleTest(&argc, argv);
75+
paddle::initMain(argc, argv);
76+
paddle::initPython(argc, argv);
77+
return RUN_ALL_TESTS();
78+
}

python/paddle/trainer_config_helpers/data_sources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def define_py_data_source(file_list,
6969
"""
7070
if isinstance(file_list, list):
7171
file_list_name = 'train.list'
72-
if isinstance(cls, TestData):
72+
if cls == TestData:
7373
file_list_name = 'test.list'
7474
with open(file_list_name, 'w') as f:
7575
f.writelines(file_list)

0 commit comments

Comments
 (0)