Skip to content

Commit 2d647f1

Browse files
committed
issue #22: [fix] avoid TypeError by branching the processing based on the variable's type
1 parent d536384 commit 2d647f1

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

pybullet_tools/pr2_primitives.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import random
66
import time
77
from itertools import islice, count
8+
from typing import Tuple, Optional, Union
89

910
import numpy as np
1011

@@ -193,15 +194,23 @@ def create_trajectory(robot, joints, path):
193194
##################################################
194195

195196
class GripperCommand(Command):
196-
def __init__(self, robot, arm, position, teleport=False):
197+
def __init__(self, robot, arm, position: Union[float, Tuple[float, ...]], teleport=False):
197198
self.robot = robot
198199
self.arm = arm
199-
self.position = position
200+
self.position: Union[float, Tuple[float, ...]] = position
200201
self.teleport = teleport
201202
def apply(self, state, **kwargs):
202203
joints = get_gripper_joints(self.robot, self.arm)
203204
start_conf = get_joint_positions(self.robot, joints)
204-
end_conf = [self.position] * len(joints)
205+
206+
# NOTE: position can be a single float (applied to all joints) or an array-like object of floats (one per joint)
207+
if isinstance(self.position, float):
208+
end_conf = [self.position] * len(joints)
209+
elif isinstance(self.position, (list, tuple, np.ndarray)) and len(self.position) == len(joints):
210+
end_conf = self.position
211+
else:
212+
raise ValueError(f'Invalid gripper position: {self.position}')
213+
205214
if self.teleport:
206215
path = [start_conf, end_conf]
207216
else:
@@ -328,7 +337,7 @@ def fn(body):
328337
for g in get_side_grasps(body, grasp_length=GRASP_LENGTH))
329338
filtered_grasps = []
330339
for grasp in grasps:
331-
grasp_width = compute_grasp_width(problem.robot, arm, body, grasp.value) if collisions else 0.0
340+
grasp_width: Optional[Union[float, Tuple[float, ...]]] = compute_grasp_width(problem.robot, arm, body, grasp.value) if collisions else 0.0
332341
if grasp_width is not None:
333342
grasp.grasp_width = grasp_width
334343
filtered_grasps.append(grasp)

pybullet_tools/pr2_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
from collections import namedtuple
66
from itertools import combinations
7+
from typing import List, Optional, Tuple
78

89
import numpy as np
910

@@ -17,7 +18,7 @@
1718
movable_from_joints, quat_from_axis_angle, LockRenderer, Euler, get_links, get_link_name, \
1819
get_extend_fn, get_moving_links, link_pairs_collision, get_link_subtree, \
1920
clone_body, get_all_links, pairwise_collision, tform_point, get_camera_matrix, ray_from_pixel, pixel_from_ray, dimensions_from_camera_matrix, \
20-
wrap_angle, TRANSPARENT, PI, OOBB, pixel_from_point, set_all_color, wait_if_gui
21+
wrap_angle, TRANSPARENT, PI, OOBB, pixel_from_point, set_all_color, wait_if_gui, POSE
2122

2223
# TODO: restrict number of pr2 rotations to prevent from wrapping too many times
2324

@@ -291,7 +292,7 @@ def close_arm(robot, arm):
291292
SIDE_HEIGHT_OFFSET = 0.03 # z distance from top of object
292293

293294
def get_top_grasps(body, under=False, tool_pose=TOOL_POSE, body_pose=unit_pose(),
294-
max_width=MAX_GRASP_WIDTH, grasp_length=GRASP_LENGTH):
295+
max_width=MAX_GRASP_WIDTH, grasp_length=GRASP_LENGTH) -> List[POSE]:
295296
# TODO: rename the box grasps
296297
center, (w, l, h) = approximate_as_prism(body, body_pose=body_pose)
297298
reflect_z = Pose(euler=[0, math.pi, 0])
@@ -757,7 +758,9 @@ def get_base_extend_fn(robot):
757758

758759
#####################################
759760

760-
def close_until_collision(robot, gripper_joints, bodies=[], open_conf=None, closed_conf=None, num_steps=25, **kwargs):
761+
def close_until_collision(
762+
robot, gripper_joints, bodies=[], open_conf: List[float] = None, closed_conf: List[float] = None, num_steps=25, **kwargs
763+
) -> Optional[Tuple[float, ...]]:
761764
if not gripper_joints:
762765
return None
763766
if open_conf is None:
@@ -766,7 +769,7 @@ def close_until_collision(robot, gripper_joints, bodies=[], open_conf=None, clos
766769
closed_conf = [get_min_limit(robot, joint) for joint in gripper_joints]
767770
resolutions = np.abs(np.array(open_conf) - np.array(closed_conf)) / num_steps
768771
extend_fn = get_extend_fn(robot, gripper_joints, resolutions=resolutions)
769-
close_path = [open_conf] + list(extend_fn(open_conf, closed_conf))
772+
close_path: List[Tuple[float, ...]] = [open_conf] + list(extend_fn(open_conf, closed_conf))
770773
collision_links = frozenset(get_moving_links(robot, gripper_joints))
771774

772775
for i, conf in enumerate(close_path):
@@ -778,7 +781,7 @@ def close_until_collision(robot, gripper_joints, bodies=[], open_conf=None, clos
778781
return close_path[-1]
779782
#return None # False
780783

781-
def compute_grasp_width(robot, arm, body, grasp_pose, **kwargs):
784+
def compute_grasp_width(robot, arm, body, grasp_pose, **kwargs) -> Optional[Tuple[float, ...]]:
782785
tool_link = get_gripper_link(robot, arm)
783786
tool_pose = get_link_pose(robot, tool_link)
784787
body_pose = multiply(tool_pose, grasp_pose)

pybullet_tools/utils.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@
2626
from itertools import product, combinations, count, cycle, islice
2727
from multiprocessing import TimeoutError
2828
from contextlib import contextmanager
29+
from typing import List, Tuple
2930

3031
from pybullet_utils.transformations import quaternion_from_matrix, unit_vector, euler_from_quaternion, quaternion_slerp, \
3132
random_quaternion, quaternion_about_axis
3233

34+
VEC3 = Tuple[float, float, float]
35+
QUATERNION = Tuple[float, float, float, float]
36+
POSE = Tuple[VEC3, QUATERNION]
37+
3338
def join_paths(*paths):
3439
return os.path.abspath(os.path.join(*paths))
3540

@@ -118,8 +123,9 @@ def get_parent_dir(file): # __file__
118123

119124
inf_generator = count # count | lambda: iter(int, 1)
120125

121-
List = lambda *args: list(args)
122-
Tuple = lambda *args: tuple(args)
126+
# Disable these lines because they prevent developers from using type hints
127+
# List = lambda *args: list(args)
128+
# Tuple = lambda *args: tuple(args)
123129

124130
def empty_sequence():
125131
return iter([])
@@ -1619,26 +1625,26 @@ def Point(x=0., y=0., z=0.):
16191625
def Euler(roll=0., pitch=0., yaw=0.):
16201626
return np.array([roll, pitch, yaw])
16211627

1622-
def Pose(point=None, euler=None):
1628+
def Pose(point=None, euler=None) -> POSE:
16231629
point = Point() if point is None else point
16241630
euler = Euler() if euler is None else euler
16251631
return point, quat_from_euler(euler)
16261632

16271633
def Pose2d(x=0., y=0., yaw=0.):
16281634
return np.array([x, y, yaw])
16291635

1630-
def invert(pose):
1636+
def invert(pose: POSE) -> POSE:
16311637
point, quat = pose
16321638
return p.invertTransform(point, quat)
16331639

1634-
def multiply(*poses):
1635-
pose = poses[0]
1640+
def multiply(*poses) -> POSE:
1641+
pose: POSE = poses[0]
16361642
for next_pose in poses[1:]:
16371643
pose = p.multiplyTransforms(pose[0], pose[1], *next_pose)
16381644
return pose
16391645

1640-
def invert_quat(quat):
1641-
pose = (unit_point(), quat)
1646+
def invert_quat(quat: QUATERNION) -> QUATERNION:
1647+
pose: POSE = (unit_point(), quat)
16421648
return quat_from_pose(invert(pose))
16431649

16441650
def multiply_quats(*quats):
@@ -2162,17 +2168,17 @@ def get_joint_limits(body, joint):
21622168

21632169
get_joint_interval = get_joint_limits # TODO: get box limits?
21642170

2165-
def get_min_limit(body, joint):
2171+
def get_min_limit(body, joint) -> float:
21662172
# TODO: rename to min_position
21672173
return get_joint_limits(body, joint)[0]
21682174

2169-
def get_min_limits(body, joints):
2175+
def get_min_limits(body, joints) -> List[float]:
21702176
return [get_min_limit(body, joint) for joint in joints]
21712177

2172-
def get_max_limit(body, joint):
2178+
def get_max_limit(body, joint) -> float:
21732179
return get_joint_limits(body, joint)[1]
21742180

2175-
def get_max_limits(body, joints):
2181+
def get_max_limits(body, joints) -> List[float]:
21762182
return [get_max_limit(body, joint) for joint in joints]
21772183

21782184
def get_joint_intervals(body, joints):
@@ -4626,7 +4632,7 @@ def end_effector_from_body(body_pose, grasp_pose):
46264632
def approach_from_grasp(approach_pose, end_effector_pose):
46274633
return multiply(approach_pose, end_effector_pose)
46284634

4629-
def get_grasp_pose(constraint):
4635+
def get_grasp_pose(constraint) -> POSE:
46304636
"""
46314637
Grasps are parent_from_child
46324638
"""

0 commit comments

Comments
 (0)