Skip to content
35 changes: 25 additions & 10 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,10 +1191,10 @@ def onnx_export(


def ckpt_export(
net_id: str | None = None,
filepath: PathLike | None = None,
ckpt_file: str | None = None,
meta_file: str | Sequence[str] | None = None,
net_id: str | None = "network_def",
Comment thread
wyli marked this conversation as resolved.
Outdated
filepath: PathLike | None = "models/model.ts",
Comment thread
wyli marked this conversation as resolved.
Outdated
ckpt_file: str | None = "models/model.pt",
meta_file: str | Sequence[str] | None = "configs/metadata.json",
config_file: str | Sequence[str] | None = None,
key_in_ckpt: str | None = None,
use_trace: bool | None = None,
Expand Down Expand Up @@ -1250,9 +1250,10 @@ def ckpt_export(
)
_log_input_summary(tag="ckpt_export", args=_args)
(
config_file_,
filepath_,
ckpt_file_,
config_file_,
bundle_root_,
net_id_,
meta_file_,
key_in_ckpt_,
Expand All @@ -1261,11 +1262,12 @@ def ckpt_export(
converter_kwargs_,
) = _pop_args(
_args,
"filepath",
"ckpt_file",
"config_file",
net_id="",
meta_file=None,
filepath="models/model.ts",
ckpt_file="models/model.pt",
bundle_root=os.getcwd(),
net_id="network_def",
meta_file="configs/metadata.json",
key_in_ckpt="",
use_trace=False,
input_shape=None,
Expand All @@ -1275,9 +1277,22 @@ def ckpt_export(
parser = ConfigParser()

parser.read_config(f=config_file_)
if meta_file_ is not None:
meta_file_ = (
os.path.join(bundle_root_, "configs/metadata.json") if meta_file_ == "configs/metadata.json" else meta_file_
)
filepath_ = os.path.join(bundle_root_, "models/model.ts") if filepath_ == "models/model.ts" else filepath_
ckpt_file_ = os.path.join(bundle_root_, "models/model.pt") if ckpt_file_ == "models/model.pt" else ckpt_file_
if not os.path.exists(ckpt_file_):
raise FileNotFoundError(f"ckpt_file in {ckpt_file_} does not exist, please specify it.")
Comment thread
KumoLiu marked this conversation as resolved.
Outdated
if os.path.exists(meta_file_):
parser.read_meta(f=meta_file_)

if net_id_ == "network_def":
try:
parser.get_parsed_content(net_id_)
except ValueError as e:
raise ValueError(f"Default net_id: network_def in {config_file_} does not exist.") from e

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v
Expand Down
53 changes: 38 additions & 15 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,61 @@ def tearDown(self):
else:
del os.environ["CUDA_VISIBLE_DEVICES"] # previously unset

# @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
# def test_export(self, key_in_ckpt, use_trace):
# meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json")
# config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
# with tempfile.TemporaryDirectory() as tempdir:
# def_args = {"meta_file": "will be replaced by `meta_file` arg"}
# def_args_file = os.path.join(tempdir, "def_args.yaml")

# ckpt_file = os.path.join(tempdir, "model.pt")
# ts_file = os.path.join(tempdir, "model.ts")

# parser = ConfigParser()
# parser.export_config_file(config=def_args, filepath=def_args_file)
# parser.read_config(config_file)
# net = parser.get_parsed_content("network_def")
# save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

# cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
# cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
# cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
# if use_trace == "True":
# cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
# command_line_tests(cmd)
# self.assertTrue(os.path.exists(ts_file))

# _, metadata, extra_files = load_net_with_metadata(
# ts_file, more_extra_files=["inference.json", "def_args.json"]
# )
# self.assertTrue("schema" in metadata)
# self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
# self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_export(self, key_in_ckpt, use_trace):
meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json")
def test_default_value(self, key_in_ckpt, use_trace):
config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
def_args_file = os.path.join(tempdir, "def_args.yaml")

ckpt_file = os.path.join(tempdir, "model.pt")
ts_file = os.path.join(tempdir, "model.ts")
ckpt_file = os.path.join(tempdir, "models/model.pt")
ts_file = os.path.join(tempdir, "models/model.ts")

parser = ConfigParser()
parser.export_config_file(config=def_args, filepath=def_args_file)
parser.read_config(config_file)
net = parser.get_parsed_content("network_def")
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)

cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
# check with default value
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
cmd += ["--config_file", config_file, "--bundle_root", tempdir]
if use_trace == "True":
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
command_line_tests(cmd)
self.assertTrue(os.path.exists(ts_file))

_, metadata, extra_files = load_net_with_metadata(
ts_file, more_extra_files=["inference.json", "def_args.json"]
)
self.assertTrue("schema" in metadata)
self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))


if __name__ == "__main__":
unittest.main()