Skip to content

Commit bfb2b59

Browse files
committed
using CasADi's serialization for rollout function
1 parent 79021d6 commit bfb2b59

2 files changed

Lines changed: 35 additions & 13 deletions

File tree

open-codegen/opengen/ocp/builder.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ def __initialize_from_metadata(self, metadata):
9191
)
9292
self.__input_slices = metadata.get("input_slices")
9393
self.__state_slices = metadata.get("state_slices")
94-
rollout_serialized = metadata.get("rollout_function")
95-
if rollout_serialized is not None:
96-
self.__rollout_function = cs.Function.deserialize(rollout_serialized)
94+
self.__rollout_function = metadata.get("rollout_function")
9795

9896
@property
9997
def target_dir(self):
@@ -192,19 +190,33 @@ def __metadata_dict(self):
192190
],
193191
"input_slices": self.__input_slices,
194192
"state_slices": self.__state_slices,
195-
"rollout_function": None
196-
if self.__rollout_function is None
197-
else self.__rollout_function.serialize(),
193+
"rollout_file": "rollout.casadi" if self.__rollout_function is not None else None,
198194
}
199195

200-
def save(self, json_path):
201-
"""Save a JSON manifest that can later recreate this optimizer.
196+
def save(self, json_path=None):
197+
"""Save a manifest that can later recreate this optimizer.
202198
203-
:param json_path: destination manifest path
199+
The manifest is stored in the generated optimizer directory by default.
200+
For single-shooting optimizers, the rollout function is stored in a
201+
separate ``rollout.casadi`` file next to the manifest.
202+
203+
:param json_path: optional destination manifest path
204204
:return: current instance
205205
"""
206+
if json_path is None:
207+
json_path = os.path.join(self.__target_dir, "optimizer_manifest.json")
208+
209+
json_path = os.path.abspath(json_path)
210+
manifest_dir = os.path.dirname(json_path)
211+
os.makedirs(manifest_dir, exist_ok=True)
212+
213+
metadata = self.__metadata_dict()
214+
rollout_file = metadata.get("rollout_file")
215+
if rollout_file is not None:
216+
self.__rollout_function.save(os.path.join(manifest_dir, rollout_file))
217+
206218
with open(json_path, "w") as fh:
207-
json.dump(self.__metadata_dict(), fh, indent=2)
219+
json.dump(metadata, fh, indent=2)
208220
return self
209221

210222
@staticmethod
@@ -228,8 +240,14 @@ def load(cls, json_path):
228240
:param json_path: path to a JSON manifest created by :meth:`save`
229241
:return: reconstructed :class:`GeneratedOptimizer`
230242
"""
243+
json_path = os.path.abspath(json_path)
231244
with open(json_path, "r") as fh:
232245
metadata = json.load(fh)
246+
rollout_file = metadata.get("rollout_file")
247+
if rollout_file is not None:
248+
metadata["rollout_function"] = cs.Function.load(
249+
os.path.join(os.path.dirname(json_path), rollout_file)
250+
)
233251
backend = cls.__load_backend(
234252
metadata["target_dir"],
235253
metadata["optimizer_name"],

open-codegen/test/test_ocp.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,6 @@ def test_generated_optimizer_defaults(self):
325325

326326
def test_optimizer_manifest_roundtrip(self):
327327
optimizer_name = "ocp_manifest_bindings"
328-
manifest_path = os.path.join(OcpTestCase.TEST_DIR, f"{optimizer_name}.json")
329-
330328
ocp = og.ocp.OptimalControlProblem(nx=2, nu=1, horizon=3)
331329
ocp.add_parameter("x0", 2)
332330
ocp.add_parameter("xref", 2, default=[0.0, 0.0])
@@ -356,7 +354,13 @@ def test_optimizer_manifest_roundtrip(self):
356354
.with_max_outer_iterations(10),
357355
).build()
358356

359-
optimizer.save(manifest_path)
357+
optimizer.save()
358+
manifest_path = os.path.join(optimizer.target_dir, "optimizer_manifest.json")
359+
rollout_path = os.path.join(optimizer.target_dir, "rollout.casadi")
360+
361+
self.assertTrue(os.path.exists(manifest_path))
362+
self.assertTrue(os.path.exists(rollout_path))
363+
360364
loaded_optimizer = og.ocp.GeneratedOptimizer.load(manifest_path)
361365

362366
result = loaded_optimizer.solve(

0 commit comments

Comments
 (0)