diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4aa4a2bfbbd17..d1f8bb703a1ed 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -81,7 +81,7 @@ paddle.fluid.layers.linear_chain_crf (ArgSpec(args=['input', 'label', 'param_att paddle.fluid.layers.crf_decoding (ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)), ('document', '462ddf2435e3392334e0c05ae57a01c4')) paddle.fluid.layers.cos_sim (ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None), ('document', 'cefab7c23ee5582727e8b22dffbafac8')) paddle.fluid.layers.cross_entropy (ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)), ('document', '535f1f6213dd7ca0fe5ed7cb4718c0e3')) -paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '30add751a0f99347a6257634c03ff254')) +paddle.fluid.layers.bpr_loss (ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6263dfdeb6c670fa0922c9cbc8fb1bf4')) paddle.fluid.layers.square_error_cost (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'f273bb26833ee88b349c4b8083e1dc67')) paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ee152a7ba3036e7b9ede9184545179b4')) paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)), ('document', 'b6543768e1afaa2ecb869709d6e9c7e2')) @@ -95,7 +95,7 @@ paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'po paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 
'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '859b887174d06f361658f69cb7c06d95')) paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '120f4323a3d7ed9c0916f15a59f0e497')) paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '581f9f99cd7f4b0cab9e0aad5fa0ea24')) -paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', 'e45e09e65a2658e07cad987222f0d9ab')) +paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787')) paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b0b8d53821716cd50c42e09b593f3feb')) paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, 
None, None, True, None, None)), ('document', '03993955ab1e6d3044c44e6f17fc85e9')) paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'ec113c6a3686ac94f8fccd1a7953d445')) @@ -207,7 +207,7 @@ paddle.fluid.layers.logical_not (ArgSpec(args=['x', 'out', 'name'], varargs=None paddle.fluid.layers.clip (ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0ce33756573c572da67302499455dbcd')) paddle.fluid.layers.clip_by_norm (ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a1ea0bc5a926f427458c4254ca022749')) paddle.fluid.layers.mean (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd638d915195ce86a8d7963b81110d4c8')) -paddle.fluid.layers.mul (ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)), ('document', 'ccd37fa6b53f074adbfb732d738c4c2d')) +paddle.fluid.layers.mul (ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)), ('document', '784b7e36cea88493f9e37a41b10fbf4d')) paddle.fluid.layers.sigmoid_cross_entropy_with_logits (ArgSpec(args=['x', 'label', 'ignore_index', 'name', 'normalize'], varargs=None, keywords=None, defaults=(-100, None, False)), ('document', '180c284317ea45ef89a460d8d79c0b72')) paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '71426e02d240d0daedae81a02ca1c191')) paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a9221eaef53884a00654e028551b78e2')) @@ 
-227,19 +227,19 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'fe4481fb31363b09cfdd228fc6776ddf')) paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb')) paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '42d5155374f69786300d90d751956998')) -paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99')) +paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '07cb0d95a646dba1b9cc7cdce89e59f0')) paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '11bb8e62cc9256958eff3991fe4834da')) paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec')) paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '2985a372ac897ea4e13aced7f930d6f8')) paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), 
('document', '46994d10276dd4cb803b4062b5d14329')) paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '132b6e74ff642a392bd6b14c10aedc65')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) -paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e')) +paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', '94e2819b7c9715ea71b62e9c78f36b29')) paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6')) paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9')) -paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cf12066a3139026119f97f9d4381a1bd')) -paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e')) +paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029')) +paddle.fluid.layers.read_file 
(ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6')) paddle.fluid.layers.shuffle (ArgSpec(args=['reader', 'buffer_size'], varargs=None, keywords=None, defaults=None), ('document', 'f967a73426db26f970bc70bfb03cffca')) paddle.fluid.layers.batch (ArgSpec(args=['reader', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', 'fcb24383c6eef2ca040ee824c26e22fd')) paddle.fluid.layers.double_buffer (ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '07e5b796674796eb1ef3fee9c10d24e3')) @@ -267,8 +267,8 @@ paddle.fluid.layers.argsort (ArgSpec(args=['input', 'axis', 'name'], varargs=Non paddle.fluid.layers.ones (ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)), ('document', 'b402489c62e668df42e7daceb63c142b')) paddle.fluid.layers.zeros (ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)), ('document', 'c155e2efc56ffa5ed4658cca0272e491')) paddle.fluid.layers.reverse (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None), ('document', '8ee7cb6ca639e7460e825f953b65d94d')) -paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '8f8c0306117ea441f20dcbbdba1f0ecc')) -paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '2e53e83127dbfd86e7098bdfe9a549e8')) +paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '51a0fa1cfaf2507c00a215adacdb8a63')) +paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '129cf426e71452fe8276d616a6dc21ae')) paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '0a437011c3906079fd8947ed3e52d292')) paddle.fluid.layers.range (ArgSpec(args=['start', 
'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '2ec937ede953ded2fdff2675883900bb')) paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '495e21e9a848c2d075a102802fc67756')) @@ -309,9 +309,9 @@ paddle.fluid.layers.StaticRNN.step (ArgSpec(args=['self'], varargs=None, keyword paddle.fluid.layers.StaticRNN.step_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', '903387ec11f3d0bf46821d31a68cffa5')) paddle.fluid.layers.StaticRNN.step_output (ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None), ('document', '252890d4c3199a7623ab8667e13fd837')) paddle.fluid.layers.StaticRNN.update_memory (ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None), ('document', '7a0000520f179f35239956a5ba55119f')) -paddle.fluid.layers.reorder_lod_tensor_by_rank (ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None), ('document', '3545f529ef04e8f6ecb76b47fa3df01a')) -paddle.fluid.layers.Print (ArgSpec(args=['input', 'first_n', 'message', 'summarize', 'print_tensor_name', 'print_tensor_type', 'print_tensor_shape', 'print_tensor_lod', 'print_phase'], varargs=None, keywords=None, defaults=(-1, None, -1, True, True, True, True, 'both')), ('document', '5fef91b0e21c93610785f2b1f7161732')) -paddle.fluid.layers.is_empty (ArgSpec(args=['x', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', 'bbe578dbb49ad13e15b014e98c22b519')) +paddle.fluid.layers.reorder_lod_tensor_by_rank (ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None), ('document', '5b552a1f0f7eb4dacb768a975ba15d08')) +paddle.fluid.layers.Print (ArgSpec(args=['input', 'first_n', 'message', 'summarize', 'print_tensor_name', 'print_tensor_type', 'print_tensor_shape', 'print_tensor_lod', 'print_phase'], varargs=None, keywords=None, defaults=(-1, None, -1, True, True, True, 
True, 'both')), ('document', 'a222dbad457441941e50b812e5af9c7e')) +paddle.fluid.layers.is_empty (ArgSpec(args=['x', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', '3011dc695f490afdf504dc24f628319a')) paddle.fluid.layers.sigmoid (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a4e395ab004e7da34e94a0a1f9eee183')) paddle.fluid.layers.logsigmoid (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5f2508c52e0a797bb9bd5e29d79ede78')) paddle.fluid.layers.exp (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '41c976b68542f4cbee178640f765d845')) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 2b3d26c396032..c19d293f09477 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -86,7 +86,7 @@ inference_analysis_test(test_analyzer_small_dam SRCS analyzer_dam_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${DAM_SMALL_INSTALL_DIR}/model --infer_data=${DAM_SMALL_INSTALL_DIR}/data.txt --max_turn_num=1) -# save model +#save model inference_analysis_api_test(test_analyzer_save_model ${DAM_SMALL_INSTALL_DIR} analyzer_save_model_tester.cc) # chinese_ner diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index a3eac7b200c37..cfbb3b1546103 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -321,7 +321,6 @@ TEST(Analyzer_dam, compare_determine) { CompareDeterministic(reinterpret_cast(&cfg), input_slots_all); } - // Save optim model TEST(Analyzer_dam, save_optim_model) { AnalysisConfig cfg; diff --git a/paddle/fluid/inference/tests/api/analyzer_save_model_tester.cc b/paddle/fluid/inference/tests/api/analyzer_save_model_tester.cc index 
578b420ea9247..4d99bbd36ffc9 100644 --- a/paddle/fluid/inference/tests/api/analyzer_save_model_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_save_model_tester.cc @@ -34,7 +34,8 @@ TEST(Analyzer, save_model) { AnalysisConfig cfg; SetConfig(&cfg); cfg.SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param"); - std::string optimModelPath = FLAGS_infer_model + "/saved_optim_model"; + // ensure the path being unique + std::string optimModelPath = FLAGS_infer_model + "/only_for_save_model_test"; mkdir(optimModelPath.c_str(), 0777); SaveOptimModel(&cfg, optimModelPath); diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.cc b/paddle/fluid/operators/ngraph/ngraph_bridge.cc index dafc31b546e3c..4ff50935d6c78 100644 --- a/paddle/fluid/operators/ngraph/ngraph_bridge.cc +++ b/paddle/fluid/operators/ngraph/ngraph_bridge.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include +#include #include #include "ngraph/ngraph.hpp" @@ -24,6 +25,8 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/ngraph_helper.h" +constexpr int64_t kNoPadding = -1; + namespace paddle { namespace operators { @@ -31,6 +34,34 @@ bool NgraphBridge::isRegister(const std::string& str) { return ops::NgraphSingleton::Lookup(str); } +bool NgraphBridge::isSupported( + const std::unique_ptr& op) { + static std::unordered_set skip_op_list{"reshape", "reshape2", + "lookup_table"}; + bool result = true; + auto& op_type = op->Type(); + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + if (!isRegister(op_type)) { + if (skip_op_list.count(op_type)) { + if (op_type == "lookup_table") { + if (op_attrs.Get("is_sparse") || + (op_attrs.Get("padding_idx") != kNoPadding)) { + result = false; + } + } else if ((op_type == "reshape") || (op_type == "reshape2")) { + if (op->Input("Shape") != paddle::framework::kEmptyVarName) { + result = false; + } + } else { + result = false; + } + } + } else { + result = false; + } + return result; +} + void NgraphBridge::BuildNgNode( const std::shared_ptr& op) { auto& op_type = op->Type(); diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.h b/paddle/fluid/operators/ngraph/ngraph_bridge.h index b609c28495923..0b43ec53874d9 100644 --- a/paddle/fluid/operators/ngraph/ngraph_bridge.h +++ b/paddle/fluid/operators/ngraph/ngraph_bridge.h @@ -39,6 +39,8 @@ class NgraphBridge { static bool isRegister(const std::string& str); + static bool isSupported(const std::unique_ptr& op); + private: std::shared_ptr< std::unordered_map>> diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index 2486ae6bb5caf..e459bb9edc686 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -134,12 +134,11 @@ static std::vector> NgraphOpIntervals( int pivot = left; while (pivot < right) { auto op_type = ops->at(pivot)->Type(); - if (NgraphBridge::isRegister(op_type)) { + if 
(!NgraphBridge::isSupported(ops->at(pivot))) { ++pivot; } else { int start = pivot, end = start; - while (pivot < right && - (!NgraphBridge::isRegister(ops->at(pivot)->Type()))) { + while (pivot < right && (NgraphBridge::isSupported(ops->at(pivot)))) { ++pivot; ++end; } diff --git a/paddle/fluid/operators/ngraph/ops/activation_op.h b/paddle/fluid/operators/ngraph/ops/activation_op.h index a66ec65a336f8..ef6c11bce706a 100644 --- a/paddle/fluid/operators/ngraph/ops/activation_op.h +++ b/paddle/fluid/operators/ngraph/ops/activation_op.h @@ -37,6 +37,16 @@ void BuildReluGradNode( platform::SetOutputNode(op, "X@GRAD", relu_grad, ngb_node_map); } +void BuildSquareNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "X", ngb_node_map); + auto out = input * input; + platform::SetOutputNode(op, "Out", out, ngb_node_map); +} + void BuildTanhGradNode( const std::shared_ptr& op, std::shared_ptr< @@ -55,4 +65,5 @@ void BuildTanhGradNode( } // namespace paddle REGISTER_NG_OP(relu_grad, BuildReluGradNode); +REGISTER_NG_OP(square, BuildSquareNode); REGISTER_NG_OP(tanh_grad, BuildTanhGradNode); diff --git a/paddle/fluid/operators/ngraph/ops/reshape_op.h b/paddle/fluid/operators/ngraph/ops/reshape_op.h new file mode 100644 index 0000000000000..be3d38af492d1 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/reshape_op.h @@ -0,0 +1,107 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +ngraph::Shape calc_output_shape(const ngraph::Shape& input_shape, + const std::vector& v_shape) { + auto out_shape = v_shape; + for (size_t i = 0; i < v_shape.size(); ++i) { + if (v_shape[i] == 0) { + out_shape[i] = input_shape[i]; + } + } + int size_input = ngraph::shape_size(input_shape); + int size_out = 1; + for (auto o : out_shape) { + if (o > 0) size_out *= o; + } + for (auto& o : out_shape) { + if (o == -1) o = size_input / size_out; + } + return ngraph::Shape(out_shape.begin(), out_shape.end()); +} + +template +static void BuildReshapeNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + std::shared_ptr input = + platform::GetInputNode(op, "X", ngb_node_map); + auto input_shape = input->get_shape(); + // TODO(mozga-intel) The vector of shape is not supported yet, that's + // asDispensable() operator" + std::shared_ptr shape = + platform::GetInputNode(op, "Shape", ngb_node_map); + + auto op_attrs = framework::AttrReader(op->Attrs()); + std::vector v_shape = op_attrs.Get>("shape"); + auto out = input; + if (shape != nullptr) { + ngraph::Shape new_shape; + for (auto& it : shape->get_shape()) { + new_shape.push_back(it); + } + out = platform::NgReshaper(input, shape->get_shape()); + } else { + auto out_shape = calc_output_shape(input_shape, v_shape); + out = platform::NgReshaper(input, out_shape); + } + + if (is_v2) { + platform::SetOutputNode(op, "XShape", input, ngb_node_map); + } + platform::SetOutputNode(op, "Out", out, ngb_node_map); +} + +template +void BuildReshapeGradNode( + const std::shared_ptr& op, + std::shared_ptr< + 
std::unordered_map>> + ngb_node_map) { + auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + std::shared_ptr input; + if (is_v2) { + input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map); + } else { + input = paddle::platform::GetInputNode(op, "X", ngb_node_map); + } + auto dx = platform::NgReshaper(dout, input->get_shape()); + paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(reshape, BuildReshapeNode); +REGISTER_NG_OP(reshape2, BuildReshapeNode); +REGISTER_NG_OP(reshape_grad, BuildReshapeGradNode); +REGISTER_NG_OP(reshape2_grad, BuildReshapeGradNode); diff --git a/paddle/fluid/platform/ngraph_helper.h b/paddle/fluid/platform/ngraph_helper.h index 9e6521653b80a..9c75f8dc6342e 100644 --- a/paddle/fluid/platform/ngraph_helper.h +++ b/paddle/fluid/platform/ngraph_helper.h @@ -77,9 +77,7 @@ std::shared_ptr GetNode( std::unordered_map>> ngb_node_map) { auto& var_names = var_map.at(name); - PADDLE_ENFORCE_EQ(var_names.size(), 1, - "op %s name %s expects one associated var", op->Type(), - name); + if (var_names.size() == 0) return nullptr; if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { return (*ngb_node_map)[var_names[0]]; } else { @@ -189,6 +187,22 @@ inline void TrimTrailingSingularDims(ngraph::Shape* shape) { } } } + +ngraph::element::Type GetNgType(paddle::framework::proto::VarType::Type dtype) { + ngraph::element::Type ng_dtype; + if (dtype == paddle::framework::proto::VarType::FP32) { + ng_dtype = ngraph::element::f32; + } else if (dtype == paddle::framework::proto::VarType::FP64) { + ng_dtype = ngraph::element::f64; + } else if (dtype == paddle::framework::proto::VarType::INT64) { + ng_dtype = ngraph::element::i64; + } else if (dtype == paddle::framework::proto::VarType::INT32) { + ng_dtype = ngraph::element::i32; + } else { + PADDLE_THROW("unsupported data type: %s", dtype); + } + return ng_dtype; +} } 
// namespace platform } // namespace paddle diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 40804015b9c7b..f02e8b55722f3 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -229,15 +229,17 @@ def forward(self, input): 'use_mkldnn': False, }) - pre_act = self._helper.create_variable_for_type_inference( - dtype=self._dtype) - - self._helper.append_op( - type='elementwise_add', - inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, - outputs={'Out': [pre_act]}, - attrs={'axis': 1}) + if self._bias_param is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self._bias_param]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias # Currently, we don't support inplace in dygraph mode return self._helper.append_activation(pre_act, act=self._act) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index f53394babf628..5f5a2d105a8e6 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,12 +169,16 @@ def Print(input, Examples: - .. code-block:: python + + import paddle.fluid as fluid + + input = fluid.layers.data(name="input", shape=[4, 32, 32], dtype="float32") + fluid.layers.Print(input, message = "The content of input layer:") + # value = some_layer(...) + # Print(value, summarize=10, + # message="The content of some_layer: ") - value = some_layer(...) 
- Print(value, summarize=10, - message="The content of some_layer: ") ''' helper = LayerHelper('print', **locals()) helper.append_op( @@ -2060,8 +2064,31 @@ def _assert_in_rnn_block_(self, method): method)) -@autodoc() +@templatedoc() def reorder_lod_tensor_by_rank(x, rank_table): + """ + ${comment} + + Args: + + x(${x_type}): ${x_comment} + rank_table(${rank_table_type}): ${rank_table_type} + + Returns: + out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data_desc = (['input', [9], 0], ['ref', [5], 1]) + data = fluid.layers.data(name=data_desc[0][0], shape=data_desc[0][1]) + rank_data = fluid.layers.data(name=data_desc[1][0], shape=data_desc[1][1]) + table = fluid.layers.control_flow.lod_rank_table(rank_data) + new_data = fluid.layers.reorder_lod_tensor_by_rank( + x=data, rank_table=table) + + """ helper = LayerHelper('reorder_lod_tensor_by_rank', **locals()) helper.is_instance('x', Variable) helper.is_instance('rank_table', Variable) @@ -2094,9 +2121,12 @@ def is_empty(x, cond=None): Examples: .. code-block:: python + import paddle.fluid as fluid + input = fluid.layers.data(name="input", shape=[4, 32, 32], dtype="float32") res = fluid.layers.is_empty(x=input) # or: - fluid.layers.is_empty(x=input, cond=res) + # fluid.layers.is_empty(x=input, cond=res) + """ helper = LayerHelper("is_empty", **locals()) if cond is None: diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index a2538fa0f9d29..db1b5599ab465 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -894,6 +894,7 @@ def open_files(filenames, Examples: .. code-block:: python + import paddle.fluid. as fluid reader = fluid.layers.io.open_files(filenames=['./data1.recordio', './data2.recordio'], shapes=[(3,224,224), (1,)], @@ -1089,15 +1090,16 @@ def read_file(reader): Examples: .. 
code-block:: python - + + import paddle.fluid as fluid data_file = fluid.layers.open_files( filenames=['mnist.recordio'], shapes=[(-1, 748), (-1, 1)], lod_levels=[0, 0], dtypes=["float32", "int64"]) - data_file = fluid.layers.double_buffer( + data_file = fluid.layers.double_buffer( fluid.layers.batch(data_file, batch_size=64)) - input, label = fluid.layers.read_file(data_file) + input, label = fluid.layers.read_file(data_file) """ helper = LayerHelper('read_file') out = [ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b889560a0caff..1e68732c1ff49 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1544,14 +1544,16 @@ def cross_entropy2(input, label, ignore_index=kIgnoreIndex): def bpr_loss(input, label, name=None): """ - Bayesian Personalized Ranking Loss Operator. + **Bayesian Personalized Ranking Loss Operator** This operator belongs to pairwise ranking loss. Label is the desired item. The loss at a given point in one session is defined as: - $Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j(https://arxiv.org/abs/1511.06939) + neural networks>. Args: input (Variable|list): a 2-D tensor with shape [N x D], where N is the @@ -1567,9 +1569,15 @@ def bpr_loss(input, label, name=None): Examples: .. code-block:: python + import paddle.fluid as fluid + + neg_size = 10 + label = fluid.layers.data( + name="label", shape=[1], dtype="int64") + predict = fluid.layers.data( + name="predict", shape=[neg_size + 1], dtype="float32") cost = fluid.layers.bpr_loss(input=predict, label=label) """ - helper = LayerHelper('bpr_loss', **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( @@ -3207,9 +3215,11 @@ def data_norm(input, Examples: .. 
code-block:: python + + import paddle.fluid as fluid - data = fluid.layers.data(input=x, size=200, param_attr='fc1.w') - hidden2 = fluid.layers.data_norm(input=hidden1) + hidden1 = fluid.layers.data(name="hidden1", shape=[200]) + hidden2 = fluid.layers.data_norm(name="hidden2", input=hidden1) """ helper = LayerHelper('data_norm', **locals()) dtype = helper.input_dtype() @@ -9875,6 +9885,18 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataX = fluid.layers.data(name="dataX", append_batch_size = False, shape=[2, 5], dtype="float32") + dataY = fluid.layers.data(name="dataY", append_batch_size = False, shape=[5, 3], dtype="float32") + output = fluid.layers.mul(dataX, dataY, + x_num_col_dims = 1, + y_num_col_dims = 1) + + """ helper = LayerHelper("mul", **locals()) @@ -10472,8 +10494,16 @@ def teacher_student_sigmoid_loss(input, Examples: .. code-block:: python + + import paddle.fluid as fluid + batch_size = 64 + label = fluid.layers.data( + name="label", shape=[batch_size, 1], dtype="int64", append_batch_size=False) + similarity = fluid.layers.data( + name="similarity", shape=[batch_size, 1], dtype="float32", append_batch_size=False) cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label) + """ helper = LayerHelper('teacher_student_sigmoid_loss', **locals()) out = helper.create_variable(dtype=input.dtype) @@ -11358,7 +11388,7 @@ def continuous_value_model(input, cvm, use_cvm=True): cvm (Variable): a 2-D Tensor with shape [N x 2], where N is the batch size, 2 is show and click. use_cvm (bool): use cvm or not. if use cvm, the output dim is the same as input if don't use cvm, the output dim is input dim - 2(remove show and click) - (cvm op is a customized op, which input is a sequence has embedd_with_cvm default, so we need an op named cvm to decided whever use it or not.) 
+ (cvm op is a customized op, whose input is a sequence that has embed_with_cvm by default, so we need an op named cvm to decide whether to use it or not.) Returns: diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index f1ad1cb9859e5..44c75be64f429 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -741,6 +741,14 @@ def has_inf(x): Returns: Variable: The tensor variable storing the output, only a bool value. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data = fluid.layers.data(name="input", shape=[4, 32, 32], dtype="float32") + res = fluid.layers.has_inf(data) + """ helper = LayerHelper("isinf", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -757,6 +765,14 @@ def has_nan(x): Returns: Variable: The tensor variable storing the output, only a bool value. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data = fluid.layers.data(name="input", shape=[4, 32, 32], dtype="float32") + res = fluid.layers.has_nan(data) + """ helper = LayerHelper("isnan", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py index b7ce1c0e4f59a..1778f4b55e7f9 100644 --- a/python/paddle/fluid/param_attr.py +++ b/python/paddle/fluid/param_attr.py @@ -202,11 +202,12 @@ class WeightNormParamAttr(ParamAttr): Examples: ..
code-block:: python - + + import paddle.fluid as fluid data = fluid.layers.data(name="data", shape=[3, 32, 32], dtype="float32") fc = fluid.layers.fc(input=data, size=1000, - param_attr=WeightNormParamAttr( + param_attr=fluid.WeightNormParamAttr( dim=None, name='weight_norm_param')) diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py index c7d62bd8ae1c8..3c1db3bf6406c 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_activation_ngraph_op.py @@ -18,7 +18,7 @@ import numpy as np import paddle.fluid.core as core from paddle.fluid.tests.unittests.op_test import OpTest -from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestRelu, TestTanh +from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestSquare, TestRelu, TestTanh class TestNGRAPHReluDim4(TestRelu): diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py new file mode 100644 index 0000000000000..cffa283271439 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_reshape_ngraph_op.py @@ -0,0 +1,23 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest, sys +sys.path.append("../") + +from test_reshape_op import TestReshapeOp, TestReshapeOpDimInfer1, TestReshapeOpDimInfer2, TestReshapeOpWithInputShape + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index d9ef08b3c491b..a0cfb27d47fd2 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -231,7 +231,8 @@ def test_resnet_float32(self): seed = 90 batch_size = train_parameters["batch_size"] - batch_num = 20 + batch_num = 10 + with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py index 77e6fc2734232..74560535074f2 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py @@ -71,7 +71,7 @@ def test_resnet_sort_gradient_float32(self): seed = 90 batch_size = train_parameters["batch_size"] - batch_num = 20 + batch_num = 10 with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed diff --git a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py index 3f3f92cde57c8..ae6a73904d6b5 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_se_resnext.py @@ -315,7 +315,7 @@ def test_se_resnext_float32(self): seed = 90 batch_size = train_parameters["batch_size"] - batch_num = 2 + batch_num = 1 epoch_num = 1 with 
fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed diff --git a/tools/document_preview.sh b/tools/document_preview.sh index d0e9b3178a664..17d9a1d10a3ba 100755 --- a/tools/document_preview.sh +++ b/tools/document_preview.sh @@ -1,10 +1,12 @@ #!/bin/bash -PADDLE_ROOT=/paddle +PADDLE_ROOT=/home +mkdir ${PADDLE_ROOT} cd ${PADDLE_ROOT} +pip install /paddle/build/opt/paddle/share/wheels/*.whl git clone https://github.com/PaddlePaddle/FluidDoc git clone https://github.com/tianshuo78520a/PaddlePaddle.org.git -sh ${PADDLE_ROOT}/FluidDoc/doc/fluid/api/gen_doc.sh -pip install ${PADDLE_ROOT}/build/opt/paddle/share/wheels/*.whl +cd ${PADDLE_ROOT}/FluidDoc/doc/fluid/api +sh gen_doc.sh apt-get update && apt-get install -y python-dev build-essential cd ${PADDLE_ROOT}/PaddlePaddle.org/portal pip install -r requirements.txt