Skip to content
106 changes: 70 additions & 36 deletions tools/run_all_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,44 +27,65 @@


XVFB_TEST_CASES = [
"test_orientation",
"test_visualization",
]


def get_tests(test_root):
path = f"{test_root}/**/test_*.py"
def get_tests(test_root, pattern="test_*.py"):
path = f"{test_root}/**/{pattern}"
return glob.glob(path, recursive=True)


def _run_test_process(cmd, env, test_path):
"""Helper function to run a test process and handle its output"""
print(f"Running test: {test_path}")
process = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate()

# Filter out extension loading messages
filtered_stdout = "\n".join([line for line in stdout.split("\n") if not ("[ext:" in line and "startup" in line)])
filtered_stderr = "\n".join([line for line in stderr.split("\n") if not ("[ext:" in line and "startup" in line)])

# Print filtered output
if filtered_stdout.strip():
print(filtered_stdout)
if filtered_stderr.strip():
print(filtered_stderr)

return process.returncode == 0


def _setup_test_env(project_root, tests_dir):
"""Helper function to setup test environment"""
env = os.environ.copy()
pythonpath = [os.path.join(project_root, "scripts"), tests_dir]

if "PYTHONPATH" in env:
env["PYTHONPATH"] = ":".join(pythonpath) + ":" + env["PYTHONPATH"]
else:
env["PYTHONPATH"] = ":".join(pythonpath)

return env


def run_tests_with_coverage(workflow_name):
"""Run all unittest cases with coverage reporting"""
project_root = f"workflows/{workflow_name}"

try:
# TODO: add license file to secrets
default_license_file = os.path.join(os.getcwd(), project_root, "scripts", "dds", "rti_license.dat")
os.environ["RTI_LICENSE_FILE"] = os.environ.get("RTI_LICENSE_FILE", default_license_file)
all_tests_passed = True
tests_dir = os.path.join(project_root, "tests")
print(f"Looking for tests in {tests_dir}")
tests = get_tests(tests_dir)
env = _setup_test_env(project_root, tests_dir)

for test_path in tests:
test_name = os.path.basename(test_path).replace(".py", "")
print(f"\nRunning test: {test_path}")

# add project root to pythonpath
env = os.environ.copy()
pythonpath = [os.path.join(project_root, "scripts"), tests_dir]

if "PYTHONPATH" in env:
env["PYTHONPATH"] = ":".join(pythonpath) + ":" + env["PYTHONPATH"]
else:
env["PYTHONPATH"] = ":".join(pythonpath)

# Check if this test needs a virtual display
if test_name in XVFB_TEST_CASES: # virtual display for GUI tests
if test_name in XVFB_TEST_CASES:
cmd = [
"xvfb-run",
"-a",
Expand All @@ -77,9 +98,11 @@ def run_tests_with_coverage(workflow_name):
"unittest",
test_path,
]
# TODO: remove this as integration tests
# TODO: move these tests to integration tests
elif "test_sim_with_dds" in test_path or "test_pi0" in test_path:
continue
elif "test_integration" in test_path:
continue
else:
cmd = [
sys.executable,
Expand All @@ -92,25 +115,7 @@ def run_tests_with_coverage(workflow_name):
test_path,
]

process = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate()

# Filter out extension loading messages
filtered_stdout = "\n".join(
[line for line in stdout.split("\n") if not ("[ext:" in line and "startup" in line)]
)
filtered_stderr = "\n".join(
[line for line in stderr.split("\n") if not ("[ext:" in line and "startup" in line)]
)

# Print filtered output
if filtered_stdout.strip():
print(filtered_stdout)
if filtered_stderr.strip():
print(filtered_stderr)

result = process
if result.returncode != 0:
if not _run_test_process(cmd, env, test_path):
all_tests_passed = False

# combine coverage results
Expand All @@ -137,13 +142,42 @@ def run_tests_with_coverage(workflow_name):
return 1


def run_integration_tests(workflow_name):
"""Run integration tests for a workflow"""
project_root = f"workflows/{workflow_name}"
default_license_file = os.path.join(os.getcwd(), project_root, "scripts", "dds", "rti_license.dat")
os.environ["RTI_LICENSE_FILE"] = os.environ.get("RTI_LICENSE_FILE", default_license_file)
all_tests_passed = True
tests_dir = os.path.join(project_root, "tests")
print(f"Looking for tests in {tests_dir}")
tests = get_tests(tests_dir, pattern="test_integration_*.py")
env = _setup_test_env(project_root, tests_dir)

for test_path in tests:
cmd = [
sys.executable,
"-m",
"unittest",
test_path,
]

if not _run_test_process(cmd, env, test_path):
all_tests_passed = False

return 0 if all_tests_passed else 1


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run all tests for a workflow")
parser.add_argument("--workflow", type=str, default="robotic_ultrasound", help="Workflow name")
parser.add_argument("--integration", action="store_true", help="Run integration tests")
args = parser.parse_args()

if args.workflow not in WORKFLOWS:
raise ValueError(f"Invalid workflow name: {args.workflow}")

exit_code = run_tests_with_coverage(args.workflow)
if args.integration:
exit_code = run_integration_tests(args.workflow)
else:
exit_code = run_tests_with_coverage(args.workflow)
sys.exit(exit_code)
5 changes: 1 addition & 4 deletions workflows/robotic_ultrasound/scripts/simulation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ export PYTHONPATH=`pwd`
5. Return to this folder and run the following command:

```sh
python environments/state_machine/pi0_policy/eval.py \
--task Isaac-Teleop-Torso-FrankaUsRs-IK-RL-Rel-v0 \
--enable_camera \
--repo_id i4h/sim_liver_scan
python environments/state_machine/pi0_policy/eval.py --enable_camera
```

NOTE: You can also specify `--ckpt_path` to run a specific policy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,24 @@
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
"--task",
type=str,
default="Isaac-Teleop-Torso-FrankaUsRs-IK-RL-Rel-v0",
help="Name of the task.",
)
parser.add_argument(
"--ckpt_path",
type=str,
default=robot_us_assets.policy_ckpt,
help="checkpoint path. Default to use policy checkpoint in the latest assets.",
)
parser.add_argument("--repo_id", type=str, help="the LeRobot repo id for the dataset norm.")
parser.add_argument(
"--repo_id",
type=str,
default="i4h/sim_liver_scan",
help="the LeRobot repo id for the dataset norm.",
)

# append AppLauncher cli argr
AppLauncher.add_app_launcher_args(parser)
Expand Down Expand Up @@ -142,6 +152,7 @@ def main():
obs, rew, terminated, truncated, info_ = env.step(action)

env.reset()
print("Resetting the environment.")
for _ in range(reset_steps):
reset_tensor = get_reset_action(env)
obs, rew, terminated, truncated, info_ = env.step(reset_tensor)
Expand Down
103 changes: 103 additions & 0 deletions workflows/robotic_ultrasound/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,112 @@
# limitations under the License.

import os
import signal
import subprocess
import threading
import time
from unittest import skipUnless


def requires_rti(func):
RTI_AVAILABLE = bool(os.getenv("RTI_LICENSE_FILE") and os.path.exists(os.getenv("RTI_LICENSE_FILE")))
return skipUnless(RTI_AVAILABLE, "RTI Connext DDS is not installed or license not found")(func)


def monitor_output(process, found_event, target_line=None):
"""Monitor process output for target_line and set event when found."""
try:
if target_line:
for line in iter(process.stdout.readline, ""):
if target_line in line:
found_event.set()
break # TODO: should we force the process to exit here?
except (ValueError, IOError):
# Handle case where stdout is closed
pass


def run_with_monitoring(command, timeout_seconds, target_line=None):
# Start the process with pipes for output
env = os.environ.copy()
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout
text=True,
bufsize=1, # Line buffered
preexec_fn=os.setsid if os.name != "nt" else None, # Create a new process group on Unix
env=env,
)

# Event to signal when target line is found
found_event = threading.Event()

# Start monitoring thread
monitor_thread = threading.Thread(target=monitor_output, args=(process, found_event, target_line))
monitor_thread.daemon = True
monitor_thread.start()

target_found = False

try:
# Wait for either timeout or target line found
start_time = time.time()
while time.time() - start_time < timeout_seconds:
if target_line and found_event.is_set():
target_found = True

# Check if process has already terminated
if process.poll() is not None:
break

time.sleep(0.1)

# If we get here, either timeout occurred or process ended
if process.poll() is None: # Process is still running
print(f"Sending SIGINT after {timeout_seconds} seconds...")

if os.name != "nt": # Unix/Linux/MacOS
# Send SIGINT to the entire process group
os.killpg(os.getpgid(process.pid), signal.SIGINT)
else: # Windows
process.send_signal(signal.CTRL_C_EVENT)

# Give the process some time to handle the signal and exit gracefully
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
print("Process didn't terminate after SIGINT, force killing...")
if os.name != "nt": # Unix/Linux/MacOS
os.killpg(os.getpgid(process.pid), signal.SIGKILL)
else: # Windows
process.kill()

except Exception as e:
print(f"Error during process execution: {e}")
if process.poll() is None:
process.kill()

finally:
# Ensure we close all pipes and terminate the process
try:
# Try to get any remaining output, but with a short timeout
remaining_output, _ = process.communicate(timeout=2)
if remaining_output:
print(remaining_output)
except subprocess.TimeoutExpired:
# If communicate times out, force kill the process
process.kill()
process.communicate()

# If the process is somehow still running, make sure it's killed
if process.poll() is None:
process.kill()
process.communicate()

# Check if target was found
if not target_found and found_event.is_set():
target_found = True

return process.returncode, target_found
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from helpers import run_with_monitoring
from parameterized import parameterized

SM_CASES = [
(
"python -u -m simulation.environments.state_machine.pi0_policy.eval --enable_camera --headless",
120,
"Resetting the environment.",
),
]


class TestPolicyEval(unittest.TestCase):
@parameterized.expand(SM_CASES)
def test_policy_eval(self, command, timeout, target_line):
# Run and monitor command
_, found_target = run_with_monitoring(command, timeout, target_line)
self.assertTrue(found_target)


if __name__ == "__main__":
unittest.main()
Loading