29 changes: 26 additions & 3 deletions python/paddle/jit/sot/infer_meta.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property

import paddle
from paddle.amp.auto_cast import amp_state
from paddle.base import framework
@@ -20,7 +22,6 @@
UniqueNameGenerator,
guard as UniqueNameGuard,
)
from paddle.static import Program
from paddle.utils import flatten, is_sequence

from .utils import Cache, Singleton, map_if_extend, meta_str
@@ -105,8 +106,6 @@ class VariableCreator:

def __init__(self):
self.var_cache = {}
self.main_program = Program()
self.startup_program = Program()
self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")

def gen_name(self, meta):
@@ -115,6 +114,30 @@ def gen_name(self, meta):
name += f"_{l}"
return name

@cached_property
def legacy_programs(self):
        # Kept only for PIR and legacy IR compatibility.
        # This can be removed once PIR becomes the default.
return (paddle.static.Program(), paddle.static.Program())

@cached_property
def pir_programs(self):
return (paddle.static.Program(), paddle.static.Program())

@property
def main_program(self):
if paddle.base.framework.use_pir_api():
return self.pir_programs[0]
else:
return self.legacy_programs[0]

@property
def startup_program(self):
if paddle.base.framework.use_pir_api():
return self.pir_programs[1]
else:
return self.legacy_programs[1]

def create_var(self, meta):
if paddle.base.framework.use_pir_api():
with paddle.static.program_guard(
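The core of the `infer_meta.py` change is visible above: the legacy and PIR program pairs move from eager construction in `__init__` to `cached_property`, so each pair is built only if its IR mode is actually exercised, and `main_program`/`startup_program` dispatch on `use_pir_api()` at access time. A minimal self-contained sketch of the same pattern (stand-in names, not Paddle's actual classes):

```python
from functools import cached_property

USE_PIR = True  # stand-in for paddle.base.framework.use_pir_api()


class ProgramHolder:
    @cached_property
    def legacy_programs(self):
        # Built lazily on first access; never constructed if only PIR is used.
        return (object(), object())  # stand-ins for paddle.static.Program()

    @cached_property
    def pir_programs(self):
        return (object(), object())

    @property
    def main_program(self):
        return (self.pir_programs if USE_PIR else self.legacy_programs)[0]

    @property
    def startup_program(self):
        return (self.pir_programs if USE_PIR else self.legacy_programs)[1]
```

`cached_property` computes once per instance and caches the result in the instance `__dict__`, so repeated accesses return the same program pair.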
5 changes: 5 additions & 0 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
@@ -266,6 +266,11 @@ def test_sot_only(fn):
return fn


def test_legacy_only(fn):
fn = set_ir_mode(IrMode.LEGACY_IR)(fn)
return fn


def test_pir_only(fn):
fn = set_ir_mode(IrMode.PIR_EXE)(fn)
return fn
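`test_legacy_only` mirrors the existing `test_pir_only`, restricting a test to the legacy IR via `set_ir_mode`. The internals of `set_ir_mode` and `IrMode` are not shown in this diff; a plausible sketch of the mechanism, with illustrative internals, is:

```python
from enum import Flag, auto


class IrMode(Flag):  # illustrative; the real enum lives in the test utils
    LEGACY_IR = auto()
    PIR_EXE = auto()
    PIR_API = auto()


def set_ir_mode(mode):
    # Tag the test function; the Dy2St test harness would read this
    # attribute to decide which IR modes to run the case under.
    def decorator(fn):
        fn.ir_modes = mode
        return fn

    return decorator


def test_legacy_only(fn):
    return set_ir_mode(IrMode.LEGACY_IR)(fn)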
2 changes: 0 additions & 2 deletions test/dygraph_to_static/simnet_dygraph_model.py
@@ -16,7 +16,6 @@

import paddle
import paddle.base.param_attr as attr
from paddle.jit.api import to_static
from paddle.nn import Layer


@@ -484,7 +483,6 @@ def __init__(self, conf_dict):
self.bow_layer_po = FCLayer(self.bow_dim, None, "fc").ops()
self.softmax_layer = FCLayer(2, "softmax", "cos_sim").ops()

@to_static
def forward(self, left, right):
"""
Forward network
5 changes: 1 addition & 4 deletions test/dygraph_to_static/test_cast.py
@@ -18,7 +18,6 @@
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_exe_and_pir_api,
)

@@ -188,11 +187,9 @@ def set_func(self):
self.func = paddle.jit.to_static(full_graph=True)(test_not_var_cast)

@test_ast_only
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_cast_result(self):
self.set_func()
# breakpoint()
# print("run once!!!")
res = self.do_test()
self.assertTrue(type(res) == int, msg='The casted dtype is not int.')
ref_val = int(self.input)
54 changes: 31 additions & 23 deletions test/dygraph_to_static/test_mobile_net.py
@@ -24,8 +24,8 @@

import paddle
from paddle import base
from paddle.base.framework import unique_name
from paddle.base.param_attr import ParamAttr
from paddle.jit.api import to_static
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn import BatchNorm, Linear

@@ -267,7 +267,6 @@ def __init__(self, scale=1.0, class_dim=1000):
bias_attr=ParamAttr(name="fc7_offset"),
)

@to_static
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
@@ -433,7 +432,6 @@ def __init__(self, class_dim=1000, scale=1.0):
bias_attr=ParamAttr(name="fc10_offset"),
)

@to_static
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
@@ -496,7 +494,9 @@ class Args:
print_step = 1
train_step = 10
place = (
base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
model_save_dir = None
model_save_prefix = None
@@ -507,15 +507,20 @@

def train_mobilenet(args, to_static):
paddle.jit.enable_to_static(to_static)
with base.dygraph.guard(args.place):

with unique_name.guard():
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)

if args.model == "MobileNetV1":
net = MobileNetV1(class_dim=args.class_dim, scale=1.0)
net = paddle.jit.to_static(
MobileNetV1(class_dim=args.class_dim, scale=1.0)
)
elif args.model == "MobileNetV2":
net = MobileNetV2(class_dim=args.class_dim, scale=1.0)
net = paddle.jit.to_static(
MobileNetV2(class_dim=args.class_dim, scale=1.0)
)
else:
print(
"wrong model name, please try model = MobileNetV1 or MobileNetV2"
@@ -618,34 +623,37 @@ def predict_static(args, data):
feed={feed_target_names[0]: data},
fetch_list=fetch_targets,
)
paddle.disable_static()
return pred_res[0]


def predict_dygraph(args, data):
paddle.jit.enable_to_static(False)
with base.dygraph.guard(args.place):
if args.model == "MobileNetV1":
model = MobileNetV1(class_dim=args.class_dim, scale=1.0)
elif args.model == "MobileNetV2":
model = MobileNetV2(class_dim=args.class_dim, scale=1.0)
# load dygraph trained parameters
model_dict = paddle.load(args.dy_state_dict_save_path + '.pdparams')
model.set_dict(model_dict)
model.eval()
if args.model == "MobileNetV1":
model = paddle.jit.to_static(
MobileNetV1(class_dim=args.class_dim, scale=1.0)
)
elif args.model == "MobileNetV2":
model = paddle.jit.to_static(
MobileNetV2(class_dim=args.class_dim, scale=1.0)
)
# load dygraph trained parameters
model_dict = paddle.load(args.dy_state_dict_save_path + '.pdparams')
model.set_dict(model_dict)
model.eval()

pred_res = model(base.dygraph.to_variable(data))
pred_res = model(base.dygraph.to_variable(data))

return pred_res.numpy()
return pred_res.numpy()


def predict_dygraph_jit(args, data):
with base.dygraph.guard(args.place):
model = paddle.jit.load(args.model_save_prefix)
model.eval()
model = paddle.jit.load(args.model_save_prefix)
model.eval()

pred_res = model(data)
pred_res = model(data)

return pred_res.numpy()
return pred_res.numpy()


def predict_analysis_inference(args, data):
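A recurring change in this test file: `@to_static` decorators on `forward` are dropped in favor of wrapping the whole layer at the call site with `paddle.jit.to_static(...)`, which leaves the class itself usable as a plain dygraph layer. In brief, with a toy layer (not from this PR):

```python
import paddle


class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(3, 3)

    def forward(self, x):  # no @to_static decorator on the method
        return self.fc(x)


net = TinyNet()                         # plain dygraph use still works
static_net = paddle.jit.to_static(net)  # conversion applied only where needed
out = static_net(paddle.ones([2, 3]))
```

The call-site form also makes the conversion explicit in the training and prediction functions, rather than hidden in the model definition.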
133 changes: 65 additions & 68 deletions test/dygraph_to_static/test_resnet_amp.py
@@ -37,76 +37,73 @@ def train(to_static, build_strategy=None):
"""
Tests model decorated by `dygraph_to_static_output` in static graph mode. For users, the model is defined in dygraph mode and trained in static graph mode.
"""
with base.dygraph.guard(place):
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)

resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for epoch in range(epoch_num):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0

for batch_id in range(100):
start_time = time.time()
img = paddle.to_tensor(
np.random.random([batch_size, 3, 224, 224]).astype(
'float32'
)
)
label = paddle.to_tensor(
np.random.randint(0, 100, [batch_size, 1], dtype='int64')
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)

resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for epoch in range(epoch_num):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0

for batch_id in range(100):
start_time = time.time()
img = paddle.to_tensor(
np.random.random([batch_size, 3, 224, 224]).astype('float32')
)
label = paddle.to_tensor(
np.random.randint(0, 100, [batch_size, 1], dtype='int64')
)
img.stop_gradient = True
label.stop_gradient = True

with paddle.amp.auto_cast():
pred = resnet(img)
# FIXME(Aurelius84): The following cross_entropy seems to bring out a
# precision problem, need to figure out the underlying reason.
# If we remove it, the loss between dygraph and dy2stat is exactly same.
loss = paddle.nn.functional.cross_entropy(
input=pred,
label=label,
reduction='none',
use_softmax=False,
)
img.stop_gradient = True
label.stop_gradient = True

with paddle.amp.auto_cast():
pred = resnet(img)
# FIXME(Aurelius84): The following cross_entropy seems to bring out a
# precision problem, need to figure out the underlying reason.
# If we remove it, the loss between dygraph and dy2stat is exactly same.
loss = paddle.nn.functional.cross_entropy(
input=pred,
label=label,
reduction='none',
use_softmax=False,
avg_loss = paddle.mean(x=pred)
acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)

scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
resnet.clear_gradients()

total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
total_sample += 1

end_time = time.time()
if batch_id % 2 == 0:
print(
"epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f"
% (
epoch,
batch_id,
total_loss.numpy() / total_sample,
total_acc1.numpy() / total_sample,
total_acc5.numpy() / total_sample,
end_time - start_time,
)
avg_loss = paddle.mean(x=pred)
acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)

scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
resnet.clear_gradients()

total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
total_sample += 1

end_time = time.time()
if batch_id % 2 == 0:
print(
"epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f"
% (
epoch,
batch_id,
total_loss.numpy() / total_sample,
total_acc1.numpy() / total_sample,
total_acc5.numpy() / total_sample,
end_time - start_time,
)
)
if batch_id == 10:
break
)
if batch_id == 10:
break

return total_loss.numpy()

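The training loop above is only de-indented (the `base.dygraph.guard` context is gone); the AMP sequence itself is unchanged. Stripped to its essentials, with a toy model and optimizer rather than the test's ResNet:

```python
import paddle

model = paddle.nn.Linear(4, 2)
opt = paddle.optimizer.SGD(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

with paddle.amp.auto_cast():     # run the forward pass in mixed precision
    loss = model(paddle.rand([8, 4])).mean()

scaled = scaler.scale(loss)      # scale loss so fp16 grads don't underflow
scaled.backward()
scaler.minimize(opt, scaled)     # unscales grads, steps, updates the scale
model.clear_gradients()
```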
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_resnet_v2.py
@@ -197,7 +197,6 @@ def __init__(self, layers=50, class_dim=102):
),
)

@paddle.jit.to_static
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
@@ -280,7 +279,7 @@ def do_train(self, to_static):
dataset, batch_size=batch_size, drop_last=True
)

resnet = ResNet()
resnet = paddle.jit.to_static(ResNet())
optimizer = optimizer_setting(parameter_list=resnet.parameters())

for epoch in range(epoch_num):
@@ -339,7 +338,7 @@ def do_train(self, to_static):
def predict_dygraph(self, data):
paddle.jit.enable_to_static(False)
paddle.disable_static(place)
resnet = ResNet()
resnet = paddle.jit.to_static(ResNet())

model_dict = paddle.load(self.dy_state_dict_save_path + '.pdparams')
resnet.set_dict(model_dict)