Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 13 additions & 25 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,12 +622,14 @@ def create_reference_model(
parameter_names = [n for n, _ in model.named_parameters()]
ref_model = deepcopy(model)

for param_name in parameter_names:
param = ref_model.get_parameter(param_name)
param.requires_grad = False
ref_model = ref_model.eval()

# if no layers are shared, return copy of model
if num_shared_layers is None:
for param_name in parameter_names:
param = ref_model.get_parameter(param_name)
param.requires_grad = False
return ref_model.eval()
return ref_model

# identify layer name pattern
if pattern is not None:
Expand All @@ -638,39 +640,25 @@ def create_reference_model(
if any(pattern_candidate in name for name in parameter_names):
pattern = pattern_candidate
break
if pattern is None:
raise ValueError("Layer pattern could not be matched.")

if pattern is None:
raise ValueError("Layer pattern could not be matched.")

# divide parameters in shared and unshared parameter lists
shared_param_list = []
unshared_param_list = []

# freeze the shared layers
shared_parameter = True
unshared_param_list = []
for name, _param in model.named_parameters():
if pattern in name:
shared_parameter = False
if shared_parameter:
shared_param_list.append(name)
param = model.get_parameter(name)
param.requires_grad = False
else:
unshared_param_list.append(name)

# create reference of the original parameter if they are shared
for param_name in shared_param_list:
param = model.get_parameter(param_name)
param.requires_grad = False

_ref_param = ref_model.get_parameter(param_name)

# for all other parameters just make sure they don't use gradients
for param_name in unshared_param_list:
param = ref_model.get_parameter(param_name)
param.requires_grad = False

if pattern is not None and len(unshared_param_list) == 0:
logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.")

return ref_model.eval()
return ref_model


class GeometricMixtureWrapper(GenerationMixin):
Expand Down
Loading