1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import gc
1516import os
1617
1718import numpy as np
2324from keras_cv .backend import ops
2425from keras_cv .backend .config import keras_3
2526from keras_cv .models import BASNet
26- from keras_cv .models import ResNet34Backbone
27+ from keras_cv .models import ResNet18Backbone
2728from keras_cv .tests .test_case import TestCase
2829
2930
3031class BASNetTest (TestCase ):
3132 def test_basnet_construction (self ):
32- backbone = ResNet34Backbone ()
33+ backbone = ResNet18Backbone ()
3334 model = BASNet (
3435 input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
3536 )
@@ -41,7 +42,7 @@ def test_basnet_construction(self):
4142
4243 @pytest .mark .large
4344 def test_basnet_call (self ):
44- backbone = ResNet34Backbone ()
45+ backbone = ResNet18Backbone ()
4546 model = BASNet (
4647 input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
4748 )
@@ -61,7 +62,7 @@ def test_weights_change(self):
6162 ds = ds .repeat (2 )
6263 ds = ds .batch (2 )
6364
64- backbone = ResNet34Backbone ()
65+ backbone = ResNet18Backbone ()
6566 model = BASNet (
6667 input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
6768 )
@@ -99,7 +100,7 @@ def test_with_model_preset_forward_pass(self):
99100 def test_saved_model (self ):
100101 target_size = [288 , 288 , 3 ]
101102
102- backbone = ResNet34Backbone ()
103+ backbone = ResNet18Backbone ()
103104 model = BASNet (
104105 input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
105106 )
@@ -112,6 +113,9 @@ def test_saved_model(self):
112113 model .save (save_path )
113114 else :
114115 model .save (save_path , save_format = "keras_v3" )
116+ # Free up model memory
117+ del model
118+ gc .collect ()
115119 restored_model = keras .models .load_model (save_path )
116120
117121 # Check we got the real object back.
0 commit comments