-
Notifications
You must be signed in to change notification settings - Fork 6k
upgrade the TraceLayer.save_inference_model method with add file suffix automatically #31989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import six | ||
| import unittest | ||
| import paddle.nn as nn | ||
| import os | ||
|
|
||
|
|
||
| class SimpleFCLayer(nn.Layer): | ||
|
|
@@ -115,36 +116,41 @@ def test_save_inference_model_err(self): | |
| dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace( | ||
| self.layer, [in_x]) | ||
|
|
||
| dirname = './traced_layer_err_msg' | ||
| path = './traced_layer_err_msg' | ||
| with self.assertRaises(TypeError) as e: | ||
| traced_layer.save_inference_model([0]) | ||
| self.assertEqual( | ||
| "The type of 'dirname' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. ". | ||
| "The type of 'path' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. ". | ||
| format(self.type_str, self.type_str), str(e.exception)) | ||
| with self.assertRaises(TypeError) as e: | ||
| traced_layer.save_inference_model(dirname, [0], [None]) | ||
| traced_layer.save_inference_model(path, [0], [None]) | ||
| self.assertEqual( | ||
| "The type of 'each element of fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ". | ||
| format(self.type_str, self.type_str), str(e.exception)) | ||
| with self.assertRaises(TypeError) as e: | ||
| traced_layer.save_inference_model(dirname, [0], False) | ||
| traced_layer.save_inference_model(path, [0], False) | ||
| self.assertEqual( | ||
| "The type of 'fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ". | ||
| format(self.type_str, self.type_str, self.type_str), | ||
| str(e.exception)) | ||
| with self.assertRaises(TypeError) as e: | ||
| traced_layer.save_inference_model(dirname, [None], [0]) | ||
| traced_layer.save_inference_model(path, [None], [0]) | ||
| self.assertEqual( | ||
| "The type of 'each element of feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ". | ||
| format(self.type_str, self.type_str), str(e.exception)) | ||
| with self.assertRaises(TypeError) as e: | ||
| traced_layer.save_inference_model(dirname, True, [0]) | ||
| traced_layer.save_inference_model(path, True, [0]) | ||
| self.assertEqual( | ||
| "The type of 'feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ". | ||
| format(self.type_str, self.type_str, self.type_str), | ||
| str(e.exception)) | ||
| with self.assertRaises(ValueError) as e: | ||
| traced_layer.save_inference_model("") | ||
| self.assertEqual( | ||
| "The input path MUST be format of dirname/file_prefix [dirname\\file_prefix in Windows system], " | ||
| "but received file_prefix is empty string.", str(e.exception)) | ||
|
|
||
| traced_layer.save_inference_model(dirname) | ||
| traced_layer.save_inference_model(path) | ||
|
|
||
| def _train_simple_net(self): | ||
| layer = None | ||
|
|
@@ -174,5 +180,25 @@ def test_linear_net_with_none(self): | |
| [in_x]) | ||
|
|
||
|
|
||
| class TestTracedLayerSaveInferenceModel(unittest.TestCase): | ||
| """test save_inference_model will automaticlly create non-exist dir""" | ||
|
|
||
| def setUp(self): | ||
| self.save_path = "./nonexist_dir/fc" | ||
| import shutil | ||
| if os.path.exists(os.path.dirname(self.save_path)): | ||
| shutil.rmtree(os.path.dirname(self.save_path)) | ||
|
|
||
| def test_mkdir_when_input_path_non_exist(self): | ||
| fc_layer = SimpleFCLayer(3, 4, 2) | ||
| input_var = paddle.to_tensor(np.random.random([4, 3]).astype('float32')) | ||
| with fluid.dygraph.guard(): | ||
| dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should write new test cases for Paddle 2.0 path. So
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for your suggestion,I will change these two in next pr. |
||
| fc_layer, inputs=[input_var]) | ||
| self.assertFalse(os.path.exists(os.path.dirname(self.save_path))) | ||
| traced_layer.save_inference_model(self.save_path) | ||
| self.assertTrue(os.path.exists(os.path.dirname(self.save_path))) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT (Not Important): However, since it is not so important and this PR needs many approves. I will approve it and you can change in next PR.
Python officially suggests to put imports at top of file, I think you should follow it in your case.
https://www.python.org/dev/peps/pep-0008/#imports
Let me explain my understanding of the pros and cons of importing at top of a file or in a function.
Importing at top of a file:
There are several cases I think we can import in a function:
But I don't think here we meet the cases for importing in a function :-) so I would suggest to import at the top.