@@ -151,11 +151,6 @@ def test_build_model_from_intermediate_tensor(self):
151151 model .fit (
152152 np .random .randn (batch_size , 32 ), np .random .randn (batch_size , 16 )
153153 )
154- # Test for model saving
155- output_path = os .path .join (self .get_temp_dir (), "tf_keras_saved_model" )
156- model .save (output_path , save_format = "tf" )
157- loaded_model = models .load_model (output_path )
158- self .assertEqual (model .summary (), loaded_model .summary ())
159154
160155 # Also make sure the original inputs and y can still be used to build
161156 # model
@@ -167,6 +162,27 @@ def test_build_model_from_intermediate_tensor(self):
167162 self .assertIs (new_model .layers [1 ], layer1 )
168163 self .assertIs (new_model .layers [2 ], layer2 )
169164
165+ # Test for model saving
166+ with self .subTest ("savedmodel" ):
167+ output_path = os .path .join (
168+ self .get_temp_dir (), "tf_keras_saved_model"
169+ )
170+ model .save (output_path , save_format = "tf" )
171+ loaded_model = models .load_model (output_path )
172+ self .assertEqual (model .summary (), loaded_model .summary ())
173+
174+ with self .subTest ("keras_v3" ):
175+ if not tf .__internal__ .tf2 .enabled ():
176+ self .skipTest (
177+ "TF2 must be enabled to use the new `.keras` saving."
178+ )
179+ output_path = os .path .join (
180+ self .get_temp_dir (), "tf_keras_v3_model.keras"
181+ )
182+ model .save (output_path , save_format = "keras_v3" )
183+ loaded_model = models .load_model (output_path )
184+ self .assertEqual (model .summary (), loaded_model .summary ())
185+
170186 def test_build_model_from_intermediate_tensor_with_complicated_model (self ):
171187 # The topology is like below:
172188 # input1 -> dense1 -> a
@@ -212,10 +228,6 @@ def test_build_model_from_intermediate_tensor_with_complicated_model(self):
212228 ],
213229 np .random .randn (batch_size , 8 ),
214230 )
215- output_path = os .path .join (self .get_temp_dir (), "tf_keras_saved_model" )
216- model .save (output_path , save_format = "tf" )
217- loaded_model = models .load_model (output_path )
218- self .assertEqual (model .summary (), loaded_model .summary ())
219231
220232 model2 = models .Model ([a , b ], d )
221233 # 2 input layers and 2 Add layer.
@@ -230,6 +242,26 @@ def test_build_model_from_intermediate_tensor_with_complicated_model(self):
230242 np .random .randn (batch_size , 8 ),
231243 )
232244
245+ with self .subTest ("savedmodel" ):
246+ output_path = os .path .join (
247+ self .get_temp_dir (), "tf_keras_saved_model"
248+ )
249+ model .save (output_path , save_format = "tf" )
250+ loaded_model = models .load_model (output_path )
251+ self .assertEqual (model .summary (), loaded_model .summary ())
252+
253+ with self .subTest ("keras_v3" ):
254+ if not tf .__internal__ .tf2 .enabled ():
255+ self .skipTest (
256+ "TF2 must be enabled to use the new `.keras` saving."
257+ )
258+ output_path = os .path .join (
259+ self .get_temp_dir (), "tf_keras_v3_model.keras"
260+ )
261+ model .save (output_path , save_format = "keras_v3" )
262+ loaded_model = models .load_model (output_path )
263+ self .assertEqual (model .summary (), loaded_model .summary ())
264+
233265
234266if __name__ == "__main__" :
235267 tf .test .main ()
0 commit comments