Skip to content

Commit 3304de7

Browse files
authored
Merge pull request #48 from reyoung/master
Merge Baidu Changes into github
2 parents c3c76d6 + dbaabc9 commit 3304de7

22 files changed

+171
-168
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
33
project(paddle CXX C)
44
set(PADDLE_MAJOR_VERSION 0)
55
set(PADDLE_MINOR_VERSION 8)
6-
set(PADDLE_PATCH_VERSION 0b)
6+
set(PADDLE_PATCH_VERSION 0b0)
77
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})
88

99
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")

doc/build/docker_install.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ Docker is a tool designed to make it easier to create, deploy, and run applicati
88
### PaddlePaddle Docker images
99
There are six Docker images:
1010

11-
- paddledev/paddle:latest-cpu: PaddlePaddle CPU binary image.
12-
- paddledev/paddle:latest-gpu: PaddlePaddle GPU binary image.
13-
- paddledev/paddle:latest-cpu-devel: PaddlePaddle CPU binary image plus source code.
14-
- paddledev/paddle:latest-gpu-devel: PaddlePaddle GPU binary image plus source code.
15-
- paddledev/paddle:latest-cpu-demo: PaddlePaddle CPU binary image plus source code and demo
16-
- paddledev/paddle:latest-gpu-demo: PaddlePaddle GPU binary image plus source code and demo
11+
- paddledev/paddle:cpu-latest: PaddlePaddle CPU binary image.
12+
- paddledev/paddle:gpu-latest: PaddlePaddle GPU binary image.
13+
- paddledev/paddle:cpu-devel-latest: PaddlePaddle CPU binary image plus source code.
14+
- paddledev/paddle:gpu-devel-latest: PaddlePaddle GPU binary image plus source code.
15+
- paddledev/paddle:cpu-demo-latest: PaddlePaddle CPU binary image plus source code and demo
16+
- paddledev/paddle:gpu-demo-latest: PaddlePaddle GPU binary image plus source code and demo
1717

1818
Tags with latest will be replaced by a released version.
1919

@@ -23,15 +23,15 @@ You have to install Docker in your machine which has linux kernel version 3.10+
2323

2424
You can use ```docker pull``` to download images first, or just launch a container with ```docker run```:
2525
```bash
26-
docker run -it paddledev/paddle:lastest-cpu
26+
docker run -it paddledev/paddle:cpu-latest
2727
```
2828

2929
If you want to launch container with GPU support, you need to set some environment variables at the same time:
3030

3131
```bash
3232
export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
3333
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
34-
docker run -it paddledev/paddle:latest-gpu
34+
docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest
3535
```
3636
3737
### Notice

doc/demo/imagenet_model/resnet_model.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa
165165

166166
### C++ Interface
167167

168-
First, specify image data list in `define_py_data_sources` in the config, see example `demo/model_zoo/resnet/resnet.py`.
168+
First, specify image data list in `define_py_data_sources2` in the config, see example `demo/model_zoo/resnet/resnet.py`.
169169

170170
```
171171
train_list = 'train.list' if not is_test else None

doc/demo/rec/ml_regression.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
257257
* Text Convolution Pooling Layer, `text_conv_pool
258258
<../../ui/api/trainer_config_helpers/networks.html
259259
#trainer_config_helpers.networks.text_conv_pool>`_
260-
* Declare Python Data Sources, `define_py_data_sources
260+
* Declare Python Data Sources, `define_py_data_sources2
261261
<../../ui/api/trainer_config_helpers/data_sources.html>`_
262262

263263
Data Provider

doc/ui/predict/predict_sample.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from py_paddle import swig_paddle, DataProviderWrapperConverter
16-
from paddle.trainer.PyDataProviderWrapper import DenseSlot
15+
from py_paddle import swig_paddle, DataProviderConverter
16+
from paddle.trainer.PyDataProvider2 import dense_vector
1717
from paddle.trainer.config_parser import parse_config
1818

1919
TEST_DATA = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -89,12 +89,12 @@
8989

9090

9191
def main():
92-
conf = parse_config("./mnist_model/trainer_config.conf.norm", "")
92+
conf = parse_config("./mnist_model/trainer_config.py", "")
9393
print conf.data_config.load_data_args
9494
network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
9595
assert isinstance(network, swig_paddle.GradientMachine) # For code hint.
9696
network.loadParameters("./mnist_model/")
97-
converter = DataProviderWrapperConverter(False, [DenseSlot(784)])
97+
converter = DataProviderConverter([dense_vector(784)])
9898
inArg = converter(TEST_DATA)
9999
print network.forwardTest(inArg)
100100

doc/ui/predict/swig_py_paddle_en.rst

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,35 @@ SWIG. The main steps of predict values in python are:
1010
* Predict
1111

1212
Here is a sample python script that shows the typical prediction process for the
13-
MNIST classification problem.
13+
MNIST classification problem. The complete sample code can be found at
14+
:code:`src_root/doc/ui/predict/predict_sample.py`.
1415

1516
.. literalinclude:: ./predict_sample.py
1617
:language: python
17-
:linenos:
18+
:lines: 15-18,90-100,101-104
1819

1920
The module that does most of the job is py_paddle.swig_paddle; it is
2021
generated by SWIG and has complete documentation. For more details, you can use
2122
python's :code:`help()` function. Let's walk through the above python script:
2223

23-
* At the beginning, initialize PaddlePaddle with command line arguments(line 90).
24-
* Parse the configuration file that is used in training(line 93).
25-
* Create a neural network at line 95 according the parsed configuration, then
26-
load the trained parameters from model at line 97.
27-
* A utility class for data transformation is created at line 98.
24+
* At the beginning, use :code:`swig_paddle.initPaddle()` to initialize
25+
PaddlePaddle with command line arguments. For more about command line arguments,
26+
see `Command Line Arguments <../cmd_argument/detail_introduction.html>`_.
27+
* Parse the configuration file that is used in training with :code:`parse_config()`.
28+
Because the data to predict usually has no label, and the output of prediction
29+
is normally the output layer rather than the cost layer, you should modify
30+
the configuration file accordingly before using it for prediction.
31+
* Create a neural network with
32+
:code:`swig_paddle.GradientMachine.createFromConfigProto()`, which takes the
33+
parsed configuration :code:`conf.model_config` as argument. Then load the
34+
trained parameters from the model with :code:`network.loadParameters()`.
35+
* Create a data converter object of utility class :code:`DataProviderConverter`.
2836
- Note: As swig_paddle can only accept C++ matrices, we offer a utility
29-
class DataProviderWraaperConverter that can accept the same input data with
30-
PyDataProviderWrapper, for more information please refer to document
37+
class DataProviderConverter that can accept the same input data as
38+
PyDataProvider2. For more information, please refer to the documentation
3139
of `PyDataProvider2 <../data_provider/pydataprovider2.html>`_.
32-
* Do the prediction and output the result at line 100, forwardTest is another
33-
utility class that directly takes the activations of the output layer.
40+
* Do the prediction with :code:`forwardTest()`, which takes the converted
41+
input data and outputs the activations of the output layer.
3442

3543
Here is a typical output:
3644

paddle/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ add_subdirectory(pserver)
77
add_subdirectory(trainer)
88
add_subdirectory(scripts)
99

10+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
11+
${CMAKE_CURRENT_SOURCE_DIR}/setup.py)
12+
1013
if(WITH_PREDICT_SDK)
1114
add_subdirectory(predict)
1215
endif()

paddle/cuda/src/hl_cuda_matrix.cu

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -266,25 +266,21 @@ template<int blockSize>
266266
__global__ void KeMatrixClassificationError(real* in_A,
267267
int* in_B,
268268
real* out_C,
269-
int dimM,
270269
int dimN) {
271270
__shared__ real max_s[blockSize];
272271
__shared__ int max_l[blockSize];
273-
int cnt = (dimN + blockSize -1) / blockSize;
274-
int tid = threadIdx.x;
275-
int lmt = tid;
276-
int index = 0;
277-
real t;
272+
const int tid = threadIdx.x;
273+
const int rowId = blockIdx.x;
278274

279275
max_s[tid] = -1e30f;
280-
for (int ii = 0; ii < cnt && lmt < dimN; ii++) {
281-
index = blockIdx.y*dimN + lmt;
282-
t = in_A[index];
283-
if (max_s[tid] < t) {
284-
max_s[tid] = t;
285-
max_l[tid] = lmt;
276+
in_A += rowId * dimN;
277+
real tmp;
278+
for (int colId = tid; colId < dimN; colId += blockSize) {
279+
tmp = in_A[colId];
280+
if (max_s[tid] < tmp) {
281+
max_s[tid] = tmp;
282+
max_l[tid] = colId;
286283
}
287-
lmt += blockSize;
288284
}
289285
__syncthreads();
290286

@@ -300,7 +296,7 @@ __global__ void KeMatrixClassificationError(real* in_A,
300296
__syncthreads();
301297

302298
if (tid == 0) {
303-
out_C[blockIdx.y] = (max_l[0] == in_B[blockIdx.y] ? 0 : 1.0f);
299+
out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f);
304300
}
305301
}
306302

@@ -313,12 +309,9 @@ void hl_matrix_classification_error(real* A_d,
313309
CHECK_NOTNULL(B_d);
314310
CHECK_NOTNULL(C_d);
315311

316-
int blocksX = 1;
317-
int blocksY = dimM;
318-
dim3 threads(1024, 1);
319-
dim3 grid(blocksX, blocksY);
320-
KeMatrixClassificationError<1024><<< grid, threads, 0, STREAM_DEFAULT >>>
321-
(A_d, B_d, C_d, dimM, dimN);
312+
// each sample is calculated by one block
313+
KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>>
314+
(A_d, B_d, C_d, dimN);
322315
CHECK_SYNC("hl_matrix_classification_error");
323316
}
324317

paddle/gserver/layers/CRFLayer.cpp

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -47,104 +47,49 @@ bool CRFLayer::init(const LayerMap& layerMap,
4747
// We don't need sequenceStartPositions because each sample of output_ is
4848
// for the cost of one sequence.
4949
setNeedSequenceInfo(false);
50-
if (useGpu_) {
51-
tmpCpuInput_.reserve(inputLayers_.size());
52-
for (size_t i = 0; i < inputLayers_.size(); i++) {
53-
tmpCpuInput_.push_back(Argument());
54-
}
55-
}
50+
5651
return true;
5752
}
5853

5954
void CRFLayer::forward(PassType passType) {
6055
Layer::forward(passType);
61-
if (useGpu_) {
62-
for (size_t i = 0; i < inputLayers_.size(); i++) {
63-
tmpCpuInput_[i].resizeAndCopyFrom(getInput(i), false, HPPL_STREAM_1);
64-
}
65-
VectorPtr cpuParameterValue;
66-
VectorPtr cpuParameterGradient;
67-
cpuParameterValue =
68-
Vector::create(parameter_->getBuf(PARAMETER_VALUE)->getSize(), false);
69-
cpuParameterValue->
70-
copyFrom(*parameter_->getBuf(PARAMETER_VALUE), HPPL_STREAM_1);
71-
if (parameter_->getBuf(PARAMETER_GRADIENT)) {
72-
cpuParameterGradient =
73-
Vector::create(parameter_->getBuf(PARAMETER_GRADIENT)->getSize(),
74-
false);
75-
cpuParameterGradient->
76-
copyFrom(*parameter_->getBuf(PARAMETER_GRADIENT), HPPL_STREAM_1);
77-
} else {
78-
cpuParameterGradient = nullptr;
79-
}
80-
forwardImp(tmpCpuInput_[0], tmpCpuInput_[1], cpuParameterValue,
81-
cpuParameterGradient);
82-
parameter_->getBuf(PARAMETER_VALUE)->copyFrom(*cpuParameterValue,
83-
HPPL_STREAM_1);
84-
if (parameter_->getBuf(PARAMETER_GRADIENT)) {
85-
parameter_->getBuf(PARAMETER_GRADIENT)->copyFrom(*cpuParameterGradient,
86-
HPPL_STREAM_1);
87-
}
88-
} else {
89-
forwardImp(getInput(0), getInput(1), parameter_->getBuf(PARAMETER_VALUE),
90-
parameter_->getBuf(PARAMETER_GRADIENT));
91-
}
92-
}
9356

94-
void CRFLayer::forwardImp(const Argument&output,
95-
const Argument& label,
96-
VectorPtr parameterValue,
97-
VectorPtr parameterGradient) {
57+
CHECK(!useGpu_) << "GPU is not supported";
58+
59+
const Argument& output = getInput(0);
60+
const Argument& label = getInput(1);
9861
CHECK(label.sequenceStartPositions);
9962
CHECK(label.ids);
10063

10164
int batchSize = output.getBatchSize();
10265
size_t numSequences = label.sequenceStartPositions->getSize() - 1;
10366
resizeOutput(numSequences, 1);
104-
std::vector<real> out(numSequences);
10567

10668
const int* starts = label.sequenceStartPositions->getData(false);
10769
CHECK_EQ(starts[numSequences], batchSize);
108-
VectorPtr cpuParameterValue;
109-
VectorPtr cpuParameterGradient;
110-
11170

11271
for (size_t i = 0; i < numSequences; ++i) {
11372
if (i >= crfs_.size()) {
11473
crfs_.emplace_back(numClasses_,
115-
parameterValue->getData(),
116-
parameterGradient
117-
? parameterGradient->getData()
74+
parameter_->getBuf(PARAMETER_VALUE)->getData(),
75+
parameter_->getBuf(PARAMETER_GRADIENT)
76+
? parameter_->getBuf(PARAMETER_GRADIENT)->getData()
11877
: nullptr);
11978
}
120-
out[i] = crfs_[i].forward(
79+
output_.value->getData()[i] = crfs_[i].forward(
12180
output.value->getData() + numClasses_ * starts[i],
12281
label.ids->getData() + starts[i], starts[i + 1] - starts[i]);
12382
}
124-
output_.value->copyFrom(out.data(), numSequences);
83+
12584
if (weightLayer_) {
12685
const MatrixPtr& weight = getInputValue(*weightLayer_);
12786
getOutputValue()->dotMul(*getOutputValue(), *weight);
12887
}
12988
}
13089

13190
void CRFLayer::backward(const UpdateCallback &callback) {
132-
(void)callback;
133-
if (useGpu_) {
134-
backwardImp(callback, tmpCpuInput_[0], tmpCpuInput_[1]);
135-
const_cast<Argument&>(getInput(0)).
136-
resizeAndCopyFrom(tmpCpuInput_[0], true, HPPL_STREAM_1);
137-
const_cast<Argument&>(getInput(1)).
138-
resizeAndCopyFrom(tmpCpuInput_[1], true, HPPL_STREAM_1);
139-
140-
} else {
141-
backwardImp(callback, getInput(0), getInput(1));
142-
}
143-
}
144-
145-
void CRFLayer::backwardImp(const UpdateCallback& callback,
146-
const Argument&output,
147-
const Argument& label) {
91+
const Argument& output = getInput(0);
92+
const Argument& label = getInput(1);
14893
const int* starts = label.sequenceStartPositions->getData(false);
14994
int numSequences = label.sequenceStartPositions->getSize() - 1;
15095

@@ -159,9 +104,11 @@ void CRFLayer::backwardImp(const UpdateCallback& callback,
159104
grad->mulScalar(weight);
160105
}
161106
}
107+
162108
if (coeff_ != real(1.0f)) {
163109
output.grad->mulScalar(coeff_);
164110
}
111+
165112
parameter_->incUpdate(callback);
166113
}
167114

paddle/gserver/layers/CRFLayer.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,14 @@ class CRFLayer : public Layer {
3232
explicit CRFLayer(const LayerConfig& config) : Layer(config) {}
3333
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
3434
virtual void forward(PassType passType);
35-
void forwardImp(const Argument&output, const Argument& label,
36-
VectorPtr parameterValue, VectorPtr parameterGradient);
3735
virtual void backward(const UpdateCallback& callback);
38-
void backwardImp(const UpdateCallback& callback, const Argument&output,
39-
const Argument& label);
4036

4137
protected:
4238
size_t numClasses_;
4339
ParameterPtr parameter_;
4440
std::vector<LinearChainCRF> crfs_;
4541
LayerPtr weightLayer_; // weight for each sequence
4642
real coeff_; // weight for the layer
47-
std::vector<Argument> tmpCpuInput_;
4843
};
4944

5045
} // namespace paddle

0 commit comments

Comments
 (0)