Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d66b37b
add variableSpec
Aurelius84 Aug 5, 2020
aa5f696
add unittest for tensorSpec and SimpleNet
Aurelius84 Aug 6, 2020
8d35104
refine code
Aurelius84 Aug 6, 2020
19dad37
fix unittest failed
Aurelius84 Aug 7, 2020
8d3e936
Merge remote-tracking branch 'upstream/develop' into input_spec
Aurelius84 Aug 7, 2020
165b7fc
fix unittest failed
Aurelius84 Aug 10, 2020
d4300eb
remove print statment
Aurelius84 Aug 10, 2020
7d6d5ce
specific gast==0.3.3
Aurelius84 Aug 10, 2020
0b91df0
rm uncomment code
Aurelius84 Aug 10, 2020
3dd8ddb
fix compatibility with Py2 and Py3
Aurelius84 Aug 12, 2020
4022d6b
add get_concrete_program interface
Aurelius84 Aug 12, 2020
a89c5d2
fix full_name unittest failed
Aurelius84 Aug 12, 2020
c6bfcb1
fix import error
Aurelius84 Aug 12, 2020
41eec74
fix typo
Aurelius84 Aug 12, 2020
9ac061c
Fix TODO in test_error
Aurelius84 Aug 13, 2020
e6d044e
add unittest for prune with input_spec
Aurelius84 Aug 14, 2020
ddf9863
refine code and comment
Aurelius84 Aug 14, 2020
5f4c264
fix conflict
Aurelius84 Aug 14, 2020
e191e94
rename Translator into StaticLayer
Aurelius84 Aug 17, 2020
7b49c91
Merge remote-tracking branch 'upstream/develop' into input_spec
Aurelius84 Aug 18, 2020
403849f
merge develop
Aurelius84 Aug 20, 2020
299d13d
fix unwrap function name
Aurelius84 Aug 20, 2020
4cffdcf
refine code and file name
Aurelius84 Aug 21, 2020
e76dabc
Merge remote-tracking branch 'upstream/develop' into input_spec
Aurelius84 Aug 21, 2020
fa4482e
refine Input.py code
Aurelius84 Aug 21, 2020
5e1ad4a
add verifying shape is None in inputSpec
Aurelius84 Aug 23, 2020
de68ea4
Merge remote-tracking branch 'upstream/develop' into input_spec
Aurelius84 Aug 23, 2020
5a343c1
rm __slots__ in input.py
Aurelius84 Aug 23, 2020
b85dd8e
fix sample code failure
Aurelius84 Aug 23, 2020
7424799
add sample code of InputSpec's interface
Aurelius84 Aug 23, 2020
f899443
fix Arg typo
Aurelius84 Aug 23, 2020
ab27d94
modify according reviewer advice
Aurelius84 Aug 24, 2020
c4bbd04
Merge upstream/develop
Aurelius84 Aug 25, 2020
5ab6f30
modify according to reviewer advise
Aurelius84 Aug 25, 2020
5cfcec3
fix try/except
Aurelius84 Aug 25, 2020
d911902
Enrich doc in InputSpec
Aurelius84 Aug 25, 2020
4d4901f
fix problem with parsing sample code in en doc
Aurelius84 Aug 25, 2020
0f0b863
Merge remote-tracking branch 'upstream/develop' into input_spec
Aurelius84 Aug 26, 2020
b922612
Replace concrete_program in model.py
Aurelius84 Aug 26, 2020
be9ddd4
modify declarative into to_static
Aurelius84 Aug 26, 2020
69c480b
Enrich comment about __hash__ in InputSpec
Aurelius84 Aug 26, 2020
214a6c3
refine sampple code according to TPM advise
Aurelius84 Aug 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import numpy as np
import sys
import paddle
from paddle.fluid import dygraph
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Linear
Expand Down Expand Up @@ -195,13 +196,16 @@ def save_quantized_model(self,
with dygraph.guard():
model.eval()
input_vars = []
for shape, dtype in zip(input_shape, input_dtype):
raw_data = np.random.random(shape)
input_data = raw_data[np.newaxis, :].astype(
dtype) if append_batch_size else raw_data.astype(dtype)
input_var = dygraph.to_variable(input_data)
input_vars.append(input_var)
outputs = prog_trans.get_output(model.forward, model, *input_vars)
for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)):
if append_batch_size:
shape = [None] + list(shape)
# Note(Aurelius84): need a elegant way to name this.
in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i)
input_vars.append(in_spec)
# use `declarative` to convert dygraph into static program
model.forward = dygraph.jit.declarative(
model.forward, input_spec=input_vars)
outputs = model.forward.concrete_program.outputs
input_spec = [input_vars[i] for i in feed]
configs = dygraph.jit.SaveLoadConfig()
configs.separate_params = True
Expand Down
48 changes: 36 additions & 12 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@
import numpy
import six

from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len

DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
to_static_func = program_translator.get_func


def is_builtin(func):
Expand Down Expand Up @@ -63,7 +62,7 @@ def is_paddle_func(func):

def convert_call(func):
"""
Converts a function call which needs to be transformed to static fucntion.
Converts a function call which needs to be transformed to static function.

Args:
func (callable): A callable function or method to convert.
Expand Down Expand Up @@ -98,6 +97,15 @@ def dyfunc(x):
func_self = None
converted_call = None

# Function in convert_call may be decorated by another `@declarative`,
# in this case, unwraps it into a raw method or function.
if isinstance(func, StaticLayer):
instance = func._class_instance
if instance is not None:
func = func.dygraph_function.__get__(instance)
else:
func = func.dygraph_function

if is_builtin_len(func):
return convert_len

Expand All @@ -109,11 +117,27 @@ def dyfunc(x):
if func.__name__ == '<lambda>':
return func
try:
global_funcs = set([
fn for fn in func.__globals__.values() if inspect.isfunction(fn)
])
if func in global_funcs:
converted_call = to_static_func(func)
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.

# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticLayer`.
# And `foo.__globals__['foo']` will still return this `StaticLayer` instead of
# `foo` function. So `isinstance(fn, StaticLayer)` is added here.
global_functions = set()
for fn in func.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticLayer):
global_functions.add(fn.dygraph_function)

if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except AttributeError:
# NOTE:
Expand All @@ -127,7 +151,7 @@ def dyfunc(x):
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = to_static_func(func)
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
Expand All @@ -136,7 +160,7 @@ def dyfunc(x):
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
forward_func = to_static_func(func.forward)
forward_func = convert_to_static(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
Expand All @@ -146,7 +170,7 @@ def dyfunc(x):
else:
try:
call_func = func.__class__.__call__
converted_call = to_static_func(call_func)
converted_call = convert_to_static(call_func)
func_self = func
except Exception:
# NOTE:
Expand Down
Loading