Merged
29 commits
0417a21 add eval n1 policy (KumoLiu, Apr 16, 2025)
a59ba39 add readme (KumoLiu, Apr 16, 2025)
66b0f89 Merge branch 'main' into yunl/n1-policy-eval (KumoLiu, May 8, 2025)
99e712b add n1 policy and enhance readme (KumoLiu, May 8, 2025)
695f70b minor fix (KumoLiu, May 8, 2025)
a5ec425 remove action in the data point (KumoLiu, May 8, 2025)
4f2df07 update test file (KumoLiu, May 8, 2025)
24f0546 address comments (KumoLiu, May 8, 2025)
db7c9f5 combine run policy and update readme (KumoLiu, May 12, 2025)
c12aedd rename test_pi0 to test_policy (KumoLiu, May 12, 2025)
42c9004 update dead links (KumoLiu, May 12, 2025)
aa6ed30 add utils and runners for each policy (KumoLiu, May 12, 2025)
47f1c6e fix format (KumoLiu, May 12, 2025)
eacdddd update env setup support gr00t (KumoLiu, May 13, 2025)
bef6357 Merge remote-tracking branch 'origin/main' into yunl/n1-policy-eval (KumoLiu, May 13, 2025)
61df915 remove both and update core to base (KumoLiu, May 13, 2025)
b12693c update use gr00tn1 (KumoLiu, May 13, 2025)
bd19171 update readme (KumoLiu, May 13, 2025)
c274512 update readme (KumoLiu, May 13, 2025)
4d73b0a address comments (KumoLiu, May 14, 2025)
8701387 update readme (KumoLiu, May 14, 2025)
5556bc2 replace all the GR00TN1 with GR00T N1 (KumoLiu, May 14, 2025)
cd1fb54 exit if third party exist (KumoLiu, May 14, 2025)
9e3929f update README (KumoLiu, May 14, 2025)
f5a98ea update import (KumoLiu, May 14, 2025)
e9c1952 update for check (KumoLiu, May 14, 2025)
a13429c strip policy (KumoLiu, May 14, 2025)
2e08398 update import (KumoLiu, May 14, 2025)
c39057e Merge branch 'main' into yunl/n1-policy-eval (mingxin-zheng, May 15, 2025)
36 changes: 36 additions & 0 deletions workflows/robotic_ultrasound/scripts/policy_runner/README.md
@@ -45,3 +45,39 @@ Here's a markdown table describing the command-line arguments:
| NVIDIA RTX 4090 | 100 ms | 9 GB | 50 |

> **Note:** The model predicts the next 50 actions in a single 100 ms inference, so you can choose how many of these predictions to use based on your control frequency requirements.
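
One way to use only part of each predicted chunk is to queue the first few actions and request a new inference once the queue is empty. The snippet below is a minimal sketch of that idea, assuming a chunk of 50 actions with 6 values each; `consume_chunk` is a hypothetical helper, not part of this repository.

```python
from collections import deque

import numpy as np


def consume_chunk(action_plan: deque, chunk: np.ndarray, replan_steps: int) -> None:
    """Queue only the first `replan_steps` actions of a predicted chunk."""
    action_plan.extend(chunk[:replan_steps])


# Hypothetical stand-in for one inference result: 50 actions, 6 values each.
chunk = np.arange(50 * 6, dtype=np.float32).reshape(50, 6)

action_plan: deque = deque()
consume_chunk(action_plan, chunk, replan_steps=5)

# Each control step pops one action; a new inference is needed when the queue is empty.
first_action = action_plan.popleft()
```

A smaller `replan_steps` reacts faster to new observations at the cost of more frequent inference calls.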


## Run GR00T N1 Policy with DDS Communication

### Prepare Model Weights and Dependencies

For now, please refer to the official [NVIDIA Isaac GR00T Installation Guide](https://github.com/NVIDIA/Isaac-GR00T?tab=readme-ov-file#installation-guide) for detailed instructions on setting up the necessary dependencies and acquiring model weights. This section will be updated with more specific instructions later.

### Ensure the PYTHONPATH Is Set

Make sure your environment variables are properly set as described in the [Environment Setup - Set environment variables before running the scripts](../../README.md#set-environment-variables-before-running-the-scripts) section.

### Run Policy

From the [`policy_runner` folder](./), execute:
```sh
python run_n1_policy.py
```


### Command Line Arguments

| Argument | Type | Default | Description |
|---------------------------|--------|--------------------------------|-----------------------------------------------------------------------------|
| `--ckpt_path` | str | `None` | Path to the GR00T N1 policy model checkpoint file. If not provided, uses a default path based on assets. |
| `--rti_license_file` | str | Environment `RTI_LICENSE_FILE` | Path to the RTI Connext DDS license file. |
| `--domain_id` | int | `0` | DDS domain ID for communication. |
| `--height` | int | `224` | Expected height of the input camera images. |
| `--width` | int | `224` | Expected width of the input camera images. |
| `--topic_in_room_camera` | str | `"topic_room_camera_data_rgb"` | DDS topic name for subscribing to room camera RGB images. |
| `--topic_in_wrist_camera` | str | `"topic_wrist_camera_data_rgb"`| DDS topic name for subscribing to wrist camera RGB images. |
| `--topic_in_franka_pos` | str | `"topic_franka_info"` | DDS topic name for subscribing to Franka robot joint states (positions). |
| `--topic_out` | str | `"topic_franka_ctrl"` | DDS topic name for publishing predicted Franka robot actions. |
| `--verbose` | bool | `False` | If set, enables verbose logging output. |
| `--data_config` | str | `"single_panda_us"` | Name of the data configuration to use (e.g., for modality and transforms). Choices are from `DATA_CONFIG_MAP`. |
| `--embodiment_tag` | str | `"new_embodiment"` | The embodiment tag for the GR00T N1 model. |
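
How `--verbose` behaves depends on how the parser declares it. The snippet below is a minimal sketch of a general `argparse` pitfall, not a statement about this script's final form: an option declared with `type=bool` converts any non-empty string (including `"False"`) to `True`, whereas `action="store_true"` gives a plain on/off flag. `build_parser` is a hypothetical helper used only for illustration.

```python
import argparse


def build_parser(use_flag: bool) -> argparse.ArgumentParser:
    """Build a parser with one of two styles of boolean option."""
    parser = argparse.ArgumentParser()
    if use_flag:
        # On/off flag: present means True, absent means False.
        parser.add_argument("--verbose", action="store_true")
    else:
        # Value option: the string is passed through bool(), so "False" is truthy.
        parser.add_argument("--verbose", type=bool, default=False)
    return parser


typed = build_parser(use_flag=False).parse_args(["--verbose", "False"])
flag_on = build_parser(use_flag=True).parse_args(["--verbose"])
flag_off = build_parser(use_flag=True).parse_args([])
```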
178 changes: 178 additions & 0 deletions workflows/robotic_ultrasound/scripts/policy_runner/run_n1_policy.py
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import torch
from dds.publisher import Publisher
from dds.schemas.camera_info import CameraInfo
from dds.schemas.franka_ctrl import FrankaCtrlInput
from dds.schemas.franka_info import FrankaInfo
from dds.subscriber import SubscriberWithCallback
from gr00t.model.policy import BasePolicy, Gr00tPolicy
from policy_runner.utils import DATA_CONFIG_MAP

# Latest messages received over DDS; reset after each published action chunk
current_state = {
    "room_cam": None,
    "wrist_cam": None,
    "joint_pos": None,
}


def main():
    parser = argparse.ArgumentParser(description="Run the GR00T N1 policy runner")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        help="checkpoint path. Default to use policy checkpoint in the latest assets.",
    )
    parser.add_argument(
        "--rti_license_file", type=str, default=os.getenv("RTI_LICENSE_FILE"), help="the path of rti_license_file."
    )
    parser.add_argument("--domain_id", type=int, default=0, help="domain id.")
    parser.add_argument("--height", type=int, default=224, help="input image height.")
    parser.add_argument("--width", type=int, default=224, help="input image width.")
    parser.add_argument(
        "--topic_in_room_camera",
        type=str,
        default="topic_room_camera_data_rgb",
        help="topic name to consume room camera rgb.",
    )
    parser.add_argument(
        "--topic_in_wrist_camera",
        type=str,
        default="topic_wrist_camera_data_rgb",
        help="topic name to consume wrist camera rgb.",
    )
    parser.add_argument(
        "--topic_in_franka_pos",
        type=str,
        default="topic_franka_info",
        help="topic name to consume franka pos.",
    )
    parser.add_argument(
        "--topic_out",
        type=str,
        default="topic_franka_ctrl",
        help="topic name to publish generated franka actions.",
    )
    # store_true instead of type=bool: bool("False") is True, so any value would enable logging
    parser.add_argument("--verbose", action="store_true", help="whether to print the log.")
    parser.add_argument(
        "--data_config",
        type=str,
        default="single_panda_us",
        choices=list(DATA_CONFIG_MAP.keys()),
        help="data config name",
    )
    # parser.add_argument("--modality_keys", nargs="+", type=str, default=["panda_hand"])
    parser.add_argument(
        "--embodiment_tag",
        type=str,
        help="The embodiment tag for the model.",
        default="new_embodiment",
    )
    args = parser.parse_args()

    data_config = DATA_CONFIG_MAP[args.data_config]
    modality_config = data_config.modality_config()
    modality_transform = data_config.transform()

    policy: BasePolicy = Gr00tPolicy(
        model_path=args.ckpt_path,
        modality_config=modality_config,
        modality_transform=modality_transform,
        embodiment_tag=args.embodiment_tag,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    if args.rti_license_file is not None:
        # Only the path form is checked here, not whether the file exists
        if not os.path.isabs(args.rti_license_file):
            raise ValueError("RTI license file must be an absolute path.")
        os.environ["RTI_LICENSE_FILE"] = args.rti_license_file

    hz = 30

    class PolicyPublisher(Publisher):
        def __init__(self, topic: str, domain_id: int):
            super().__init__(topic, FrankaCtrlInput, 1 / hz, domain_id)

        def produce(self, dt: float, sim_time: float):
            # Process camera images directly to numpy arrays
            def _process_camera_image(buffer, height, width):
                img_buffer = np.frombuffer(buffer, dtype=np.uint8)
                return img_buffer.reshape(height, width, 3)

            # Get images from camera buffers
            room_img = _process_camera_image(current_state["room_cam"], args.height, args.width)
            wrist_img = _process_camera_image(current_state["wrist_cam"], args.height, args.width)
            joint_pos = current_state["joint_pos"]

            # Prepare input data with batch dimension for model
            data_point = {
                "video.room": np.expand_dims(room_img, axis=0),
                "video.wrist": np.expand_dims(wrist_img, axis=0),
                "state.panda_hand": np.expand_dims(np.array(joint_pos), axis=0),
                "annotation.human.task_description": "Perform a liver ultrasound.",
            }
            actions = policy.get_action(data_point)
            i = FrankaCtrlInput()
            # actions are relative positions, if run with absolute positions, need to add the current joint positions
            # actions shape is (16, 6), must reshape to (96,)
            i.joint_positions = (
                np.array(actions["action.panda_hand"])
                .astype(np.float32)
                .reshape(
                    16 * 6,
                )
                .tolist()
            )
            return i

    writer = PolicyPublisher(args.topic_out, args.domain_id)

    def dds_callback(topic, data):
        if args.verbose:
            print(f"[INFO]: Received data from {topic}")
        if topic == args.topic_in_room_camera:
            o: CameraInfo = data
            current_state["room_cam"] = o.data

        if topic == args.topic_in_wrist_camera:
            o: CameraInfo = data
            current_state["wrist_cam"] = o.data

        if topic == args.topic_in_franka_pos:
            o: FrankaInfo = data
            current_state["joint_pos"] = o.joints_state_positions
        # Publish only once all three inputs have arrived
        if (
            current_state["room_cam"] is not None
            and current_state["wrist_cam"] is not None
            and current_state["joint_pos"] is not None
        ):
            writer.write()
            if args.verbose:
                print(f"[INFO]: Published joint position to {args.topic_out}")
            # clean the buffer
            current_state["room_cam"] = current_state["wrist_cam"] = current_state["joint_pos"] = None

    SubscriberWithCallback(dds_callback, args.domain_id, args.topic_in_room_camera, CameraInfo, 1 / hz).start()
    SubscriberWithCallback(dds_callback, args.domain_id, args.topic_in_wrist_camera, CameraInfo, 1 / hz).start()
    SubscriberWithCallback(dds_callback, args.domain_id, args.topic_in_franka_pos, FrankaInfo, 1 / hz).start()


if __name__ == "__main__":
    main()
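
The camera handling in `produce` reshapes a raw byte buffer into an image, which fails with a fairly opaque NumPy error if `--height` and `--width` do not match what the camera publishes. The snippet below is a minimal sketch of the same decoding step with an explicit size check; `decode_rgb` is a hypothetical helper, and it assumes the buffer is a bytes-like object holding tightly packed 8-bit RGB pixels.

```python
import numpy as np


def decode_rgb(buffer: bytes, height: int, width: int) -> np.ndarray:
    """Interpret a raw byte buffer as an (H, W, 3) uint8 image."""
    expected = height * width * 3
    if len(buffer) != expected:
        raise ValueError(f"expected {expected} bytes for {height}x{width} RGB, got {len(buffer)}")
    return np.frombuffer(buffer, dtype=np.uint8).reshape(height, width, 3)


# Stand-in for one camera message: an all-black 224x224 frame.
frame = decode_rgb(bytes(224 * 224 * 3), 224, 224)

# The policy input adds a leading axis of length one, as in `data_point` above.
batched = np.expand_dims(frame, axis=0)
```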
4 changes: 2 additions & 2 deletions workflows/robotic_ultrasound/scripts/policy_runner/runners.py
@@ -26,15 +26,15 @@ class PI0PolicyRunner:
     Args:
         ckpt_path: Path to the checkpoint file.
         repo_id: Repository ID of the original training dataset.
-        task_description: Task description. Default is "Conduct a ultrasound scan on the liver."
+        task_description: Task description. Default is "Perform a liver ultrasound."
 
     """
 
     def __init__(
         self,
         ckpt_path,
         repo_id,
-        task_description="Conduct a ultrasound scan on the liver.",
+        task_description="Perform a liver ultrasound.",
     ):
         config = get_config(name="robotic_ultrasound", repo_id=repo_id)
         print(f"Loading model from {ckpt_path}...")
98 changes: 98 additions & 0 deletions workflows/robotic_ultrasound/scripts/policy_runner/utils.py
@@ -22,6 +22,13 @@
import openpi.training.data_loader as _data_loader
import openpi.transforms as _transforms
import tqdm
from gr00t.data.dataset import ModalityConfig
from gr00t.data.transform.base import ComposedModalityTransform
from gr00t.data.transform.concat import ConcatTransform
from gr00t.data.transform.state_action import StateActionToTensor, StateActionTransform
from gr00t.data.transform.video import VideoColorJitter, VideoCrop, VideoResize, VideoToNumpy, VideoToTensor
from gr00t.experiment.data_config import DATA_CONFIG_MAP, BaseDataConfig
from gr00t.model.transforms import GR00TTransform
from openpi import transforms
from openpi.compute_norm_stats import create_dataset
from openpi.models import model as _model
@@ -175,3 +182,94 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )


class SinglePandaUSDataConfig(BaseDataConfig):
    video_keys = [
        "video.room",
        "video.wrist",
    ]
    state_keys = [
        "state.panda_hand",
    ]
    action_keys = [
        "action.panda_hand",
    ]

    language_keys = ["annotation.human.task_description"]
    observation_indices = [0]
    action_indices = list(range(16))

    def modality_config(self):
        video_modality = ModalityConfig(
            delta_indices=self.observation_indices,
            modality_keys=self.video_keys,
        )
        state_modality = ModalityConfig(
            delta_indices=self.observation_indices,
            modality_keys=self.state_keys,
        )
        action_modality = ModalityConfig(
            delta_indices=self.action_indices,
            modality_keys=self.action_keys,
        )
        language_modality = ModalityConfig(
            delta_indices=self.observation_indices,
            modality_keys=self.language_keys,
        )
        modality_configs = {
            "video": video_modality,
            "state": state_modality,
            "action": action_modality,
            "language": language_modality,
        }
        return modality_configs

    def transform(self):
        transforms = [
            # video transforms
            VideoToTensor(apply_to=self.video_keys),
            VideoCrop(apply_to=self.video_keys, scale=0.95),
            VideoResize(apply_to=self.video_keys, height=224, width=224, interpolation="linear"),
            VideoColorJitter(
                apply_to=self.video_keys,
                brightness=0.3,
                contrast=0.4,
                saturation=0.5,
                hue=0.08,
            ),
            VideoToNumpy(apply_to=self.video_keys),
            # state transforms
            StateActionToTensor(apply_to=self.state_keys),
            StateActionTransform(
                apply_to=self.state_keys,
                normalization_modes={
                    "state.panda_hand": "min_max",
                },
            ),
            # action transforms
            StateActionToTensor(apply_to=self.action_keys),
            StateActionTransform(
                apply_to=self.action_keys,
                normalization_modes={
                    "action.panda_hand": "min_max",
                },
            ),
            # concat transforms
            ConcatTransform(
                video_concat_order=self.video_keys,
                state_concat_order=self.state_keys,
                action_concat_order=self.action_keys,
            ),
            GR00TTransform(
                state_horizon=len(self.observation_indices),
                action_horizon=len(self.action_indices),
                max_state_dim=64,
                max_action_dim=32,
            ),
        ]

        return ComposedModalityTransform(transforms=transforms)


DATA_CONFIG_MAP["single_panda_us"] = SinglePandaUSDataConfig()
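
The `"min_max"` normalization mode used for both state and action keys can be illustrated with a small NumPy example. This is a sketch under stated assumptions, not the `gr00t` implementation: it assumes per-dimension minimum and maximum statistics (which would come from the training dataset) and the common convention of scaling to the range [-1, 1]. Both function names are hypothetical.

```python
import numpy as np


def min_max_normalize(x, lo, hi):
    """Scale x from [lo, hi] to [-1, 1] per dimension (assumed convention)."""
    x, lo, hi = (np.asarray(a, dtype=np.float32) for a in (x, lo, hi))
    # Guard against division by zero for dimensions that never vary.
    span = np.where(hi > lo, hi - lo, 1.0)
    return 2.0 * (x - lo) / span - 1.0


def min_max_denormalize(y, lo, hi):
    """Inverse mapping from [-1, 1] back to [lo, hi]."""
    y, lo, hi = (np.asarray(a, dtype=np.float32) for a in (y, lo, hi))
    return (y + 1.0) / 2.0 * (hi - lo) + lo


# Made-up statistics for a two-dimensional action.
lo, hi = [0.0, -2.0], [10.0, 2.0]
normalized = min_max_normalize([5.0, 0.0], lo, hi)
restored = min_max_denormalize(normalized, lo, hi)
```

Predicted actions would be mapped back with the inverse step before being sent to the robot.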
@@ -104,6 +104,8 @@
 parser.add_argument(
     "--scale", type=float, default=1000.0, help="Scale factor to convert from omniverse to organ coordinate system."
 )
+parser.add_argument("--chunk_length", type=int, default=50, help="Length of the action chunk inferred by the policy.")
+
 # append AppLauncher cli arguments
 AppLauncher.add_app_launcher_args(parser)
 # parse the arguments
@@ -299,7 +301,7 @@ def main():
     while ret is None:
         ret = infer_reader.read_data()
     o: FrankaCtrlInput = ret
-    action_chunk = np.array(o.joint_positions, dtype=np.float32).reshape(50, 6)
+    action_chunk = np.array(o.joint_positions, dtype=np.float32).reshape(args_cli.chunk_length, 6)
     action_plan.extend(action_chunk[:replan_steps])
 
     action = action_plan.popleft()
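
The `--chunk_length` argument matters because the action chunk travels over DDS as a flat list. `run_n1_policy.py` publishes 16 x 6 = 96 values, while the default of 50 expects 300, so the receiver presumably needs `--chunk_length 16` when paired with that runner. The snippet below is a minimal sketch of the flatten and reshape round trip with an explicit length check; both helper names are hypothetical.

```python
import numpy as np


def flatten_chunk(actions: np.ndarray) -> list:
    """Flatten a (chunk_length, action_dim) array into a list for publishing."""
    return np.asarray(actions, dtype=np.float32).reshape(-1).tolist()


def unflatten_chunk(flat, chunk_length: int, action_dim: int = 6) -> np.ndarray:
    """Rebuild the (chunk_length, action_dim) array on the receiving side."""
    arr = np.asarray(flat, dtype=np.float32)
    if arr.size != chunk_length * action_dim:
        raise ValueError(
            f"got {arr.size} values, expected {chunk_length} x {action_dim}; check --chunk_length"
        )
    return arr.reshape(chunk_length, action_dim)


# A 16-step chunk, the shape noted in run_n1_policy.py.
sent = np.random.default_rng(0).random((16, 6), dtype=np.float32)
received = unflatten_chunk(flatten_chunk(sent), chunk_length=16)
```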
@@ -119,7 +119,7 @@ def main():
     policy_runner = PI0PolicyRunner(
         ckpt_path=args_cli.ckpt_path,
         repo_id=args_cli.repo_id,
-        task_description="Conduct a ultrasound scan on the liver.",
+        task_description="Perform a liver ultrasound.",
     )
     # Number of steps played before replanning
     replan_steps = 5
@@ -1 +1 @@
-0.2685, -0.2822, -0.1428, 0.5501, -0.0998, -0.6649, 0.1716
+0.2284, -0.7496, -0.2118, -2.4747, -0.1321, 1.7308, 0.0885