Skip to content

Commit e8acd19

Browse files
committed
fixing errors
1 parent c69ca7a commit e8acd19

2 files changed

Lines changed: 16 additions & 13 deletions

File tree

toqito/matrix_ops/partial_trace.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,19 @@ def partial_trace(
143143
if n % dim != 0:
144144
raise ValueError("Invalid: If `dim` is a scalar, it must evenly divide matrix dimension.")
145145
dim = np.array([dim, n // dim])
146-
elif isinstance(dim, list):
146+
elif isinstance(dim, (list, tuple, np.ndarray)):
147+
dim = np.array(dim)
148+
if dim.ndim != 1:
149+
raise ValueError("Invalid: `dim` must be a 1D array-like of ints.")
147150
if len(dim) == 1:
148151
d = dim[0]
149152
if n % d != 0:
150-
raise ValueError("Invalid: If `dim` is a scalar, it must evenly divide matrix dimension.")
153+
raise ValueError(
154+
"Invalid: If `dim` is a scalar, it must evenly divide matrix dimension."
155+
)
151156
dim = np.array([d, n // d])
152-
else:
153-
dim = np.array(dim)
154157
else:
155-
raise ValueError("Invalid: `dim` must be int or list of ints.")
158+
raise ValueError("Invalid: `dim` must be int or array-like of ints.")
156159

157160
num_sys = len(dim)
158161
prod_dim = np.prod(dim)

toqito/matrix_ops/tests/test_partial_trace.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -684,15 +684,15 @@ def test_dim_list_branches(dim_value):
684684
# dim=None and non-perfect-square
685685
(np.ones((6, 6)), [0], None, "Cannot infer subsystem dimensions directly"),
686686
687-
# Invalid dim type
688-
(np.eye(4), [0], (2, 2), None),
687+
# dim not 1D (2D array)
688+
(np.eye(4), [0], np.array([[2, 2]]), "1D"),
689+
690+
# completely invalid dim type
691+
(np.eye(4), [0], "invalid_dim", "int or array-like"),
689692
],
690693
)
691694
def test_partial_trace_invalid_inputs(input_mat, sys_value, dim_value, error_msg):
692695
"""Test various invalid parameter combinations."""
693-
if error_msg:
694-
with pytest.raises(ValueError, match=error_msg):
695-
partial_trace(input_mat, sys_value, dim_value)
696-
else:
697-
with pytest.raises(ValueError):
698-
partial_trace(input_mat, sys_value, dim_value)
696+
with pytest.raises(ValueError, match=error_msg):
697+
partial_trace(input_mat, sys_value, dim_value)
698+

0 commit comments

Comments
 (0)