Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions src/qibo/models/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,18 +749,21 @@ def gates_of_type(self, gate: Union[str, type]) -> List[Tuple[int, gates.Gate]]:

return [(i, g) for i, g in enumerate(self.queue) if isinstance(g, gate)]

def _set_parameters_list(self, parameters, n):
def _set_parameters_list(self, parameters, n, include_not_trainable: bool = False):
"""Helper method for ``set_parameters`` when a list is given.

Also works if ``parameters`` is ``np.ndarray`` or ``tf.Tensor``.
"""
if n == len(self.trainable_gates):
for i, gate in enumerate(self.trainable_gates):
_gates = (
self.parametrized_gates if include_not_trainable else self.trainable_gates
)
if n == len(_gates):
for i, gate in enumerate(_gates):
gate.parameters = parameters[i]
elif n == self.trainable_gates.nparams:
elif n == _gates.nparams:
parameters = list(parameters)
k = 0
for i, gate in enumerate(self.trainable_gates):
for i, gate in enumerate(_gates):
if gate.nparams == 1:
gate.parameters = parameters[i + k]
else:
Expand All @@ -770,10 +773,10 @@ def _set_parameters_list(self, parameters, n):
raise_error(
ValueError,
f"Given list of parameters has length {n} while "
+ f"the circuit contains {len(self.trainable_gates)} parametrized gates.",
+ f"the circuit contains {len(_gates)} parametrized gates.",
)

def set_parameters(self, parameters):
def set_parameters(self, parameters, include_not_trainable: bool = False):
"""Updates the parameters of the circuit's parametrized gates.

For more information on how to use this method we refer to the
Expand Down Expand Up @@ -819,8 +822,11 @@ def set_parameters(self, parameters):
params = [0.123, 0.456, 0.789, 0.321]
circuit.set_parameters(params)
"""
_gates = (
self.parametrized_gates if include_not_trainable else self.trainable_gates
)
if isinstance(parameters, dict):
diff = set(parameters.keys()) - self.trainable_gates.set
diff = set(parameters.keys()) - _gates.set
if diff:
raise_error(
KeyError,
Expand All @@ -836,7 +842,9 @@ def set_parameters(self, parameters):
nparams = int(parameters.shape[0])
except AttributeError:
nparams = len(parameters)
self._set_parameters_list(parameters, nparams)
self._set_parameters_list(
parameters, nparams, include_not_trainable=include_not_trainable
)
else:
raise_error(TypeError, f"Invalid type of parameters {type(parameters)}.")

Expand Down
Loading