Skip to content

Commit e9a1f2a

Browse files
committed
fix DeConv3D switch(imgSize*_, output*_)
1 parent 4fbc03d commit e9a1f2a

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

paddle/gserver/layers/DeConv3DLayer.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
5353

5454
size_t DeConv3DLayer::getSize() {
5555
CHECK_NE(inputLayers_.size(), 0UL);
56-
outputH_.clear();
57-
outputW_.clear();
58-
outputD_.clear();
56+
imgSizeW_.clear();
57+
imgSizeH_.clear();
58+
imgSizeD_.clear();
5959
N_.clear();
6060
NOut_.clear();
6161
size_t layerSize = 0;
6262
for (size_t i = 0; i < inputLayers_.size(); ++i) {
63-
outputW_.push_back(
64-
imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
65-
outputH_.push_back(imageSize(
66-
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
67-
outputD_.push_back(imageSize(
68-
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
69-
NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
70-
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
63+
imgSizeW_.push_back(
64+
imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true));
65+
imgSizeH_.push_back(imageSize(
66+
outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
67+
imgSizeD_.push_back(imageSize(
68+
outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
69+
NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
70+
N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
7171
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
7272
layerSize += NOut_[i] * numFilters_;
7373
}
74-
getOutput().setFrameHeight(outputH_[0]);
75-
getOutput().setFrameWidth(outputW_[0]);
76-
getOutput().setFrameDepth(outputD_[0]);
74+
getOutput().setFrameHeight(imgSizeH_[0]);
75+
getOutput().setFrameWidth(imgSizeW_[0]);
76+
getOutput().setFrameDepth(imgSizeD_[0]);
7777
return layerSize;
7878
}
7979

@@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) {
103103
}
104104
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
105105
numFilters_,
106-
outputD_[i],
107-
outputH_[i],
108-
outputW_[i],
106+
imgSizeD_[i],
107+
imgSizeH_[i],
108+
imgSizeW_[i],
109109
filterSizeZ_[i],
110110
filterSizeY_[i],
111111
filterSize_[i],
@@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) {
144144
colBuf_->vol2Col(
145145
getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
146146
numFilters_,
147-
outputD_[i],
148-
outputH_[i],
149-
outputW_[i],
147+
imgSizeD_[i],
148+
imgSizeH_[i],
149+
imgSizeW_[i],
150150
filterSizeZ_[i],
151151
filterSizeY_[i],
152152
filterSize_[i],

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
23022302
conv->set_stride(2);
23032303
conv->set_stride_y(2);
23042304
conv->set_stride_z(2);
2305-
conv->set_img_size(IMAGE_SIZE);
2306-
conv->set_img_size_y(IMAGE_SIZE_Y);
2307-
conv->set_img_size_z(IMAGE_SIZE_Z);
2308-
conv->set_output_x(imageSize(conv->img_size(),
2305+
conv->set_output_x(IMAGE_SIZE);
2306+
conv->set_output_y(IMAGE_SIZE_Y);
2307+
conv->set_output_z(IMAGE_SIZE_Z);
2308+
2309+
conv->set_img_size(imageSize(conv->output_x(),
23092310
conv->filter_size(),
23102311
conv->padding(),
23112312
conv->stride(),
23122313
true));
2313-
conv->set_output_y(imageSize(conv->img_size_y(),
2314-
conv->filter_size_y(),
2315-
conv->padding_y(),
2316-
conv->stride_y(),
2317-
true));
2318-
conv->set_output_z(imageSize(conv->img_size_z(),
2319-
conv->filter_size_z(),
2320-
conv->padding_z(),
2321-
conv->stride_z(),
2322-
true));
2323-
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
2324-
conv->output_z() * NUM_FILTERS);
2314+
conv->set_img_size_y(imageSize(conv->output_y(),
2315+
conv->filter_size_y(),
2316+
conv->padding_y(),
2317+
conv->stride_y(),
2318+
true));
2319+
conv->set_img_size_z(imageSize(conv->output_z(),
2320+
conv->filter_size_z(),
2321+
conv->padding_z(),
2322+
conv->stride_z(),
2323+
true));
2324+
config.layerConfig.set_size(conv->img_size() * conv->img_size_y() *
2325+
conv->img_size_z() * NUM_FILTERS);
23252326
conv->set_groups(1);
23262327
conv->set_filter_channels(conv->channels() / conv->groups());
23272328
config.inputDefs.push_back(

0 commit comments

Comments
 (0)