diff --git a/src/qibo/models/circuit.py b/src/qibo/models/circuit.py index 7c5d1ea591..dbbff01428 100644 --- a/src/qibo/models/circuit.py +++ b/src/qibo/models/circuit.py @@ -749,18 +749,21 @@ def gates_of_type(self, gate: Union[str, type]) -> List[Tuple[int, gates.Gate]]: return [(i, g) for i, g in enumerate(self.queue) if isinstance(g, gate)] - def _set_parameters_list(self, parameters, n): + def _set_parameters_list(self, parameters, n, include_not_trainable: bool = False): """Helper method for ``set_parameters`` when a list is given. Also works if ``parameters`` is ``np.ndarray`` or ``tf.Tensor``. """ - if n == len(self.trainable_gates): - for i, gate in enumerate(self.trainable_gates): + _gates = ( + self.parametrized_gates if include_not_trainable else self.trainable_gates + ) + if n == len(_gates): + for i, gate in enumerate(_gates): gate.parameters = parameters[i] - elif n == self.trainable_gates.nparams: + elif n == _gates.nparams: parameters = list(parameters) k = 0 - for i, gate in enumerate(self.trainable_gates): + for i, gate in enumerate(_gates): if gate.nparams == 1: gate.parameters = parameters[i + k] else: @@ -770,10 +773,10 @@ def _set_parameters_list(self, parameters, n): raise_error( ValueError, f"Given list of parameters has length {n} while " - + f"the circuit contains {len(self.trainable_gates)} parametrized gates.", + + f"the circuit contains {len(_gates)} parametrized gates.", ) - def set_parameters(self, parameters): + def set_parameters(self, parameters, include_not_trainable: bool = False): """Updates the parameters of the circuit's parametrized gates. For more information on how to use this method we refer to the @@ -819,8 +822,11 @@ def set_parameters(self, parameters): params = [0.123, 0.456, 0.789, 0.321] circuit.set_parameters(params) """ + _gates = ( + self.parametrized_gates if include_not_trainable else self.trainable_gates + ) if isinstance(parameters, dict): - diff = set(parameters.keys()) - self.trainable_gates.set + diff = set(parameters.keys()) - _gates.set if diff: raise_error( KeyError, @@ -836,7 +842,9 @@ def set_parameters(self, parameters): nparams = int(parameters.shape[0]) except AttributeError: nparams = len(parameters) - self._set_parameters_list(parameters, nparams) + self._set_parameters_list( + parameters, nparams, include_not_trainable=include_not_trainable + ) else: raise_error(TypeError, f"Invalid type of parameters {type(parameters)}.")