Extracting the discussion from #8716
Summary of the issue: `prediction_step()` has a `model` argument which is the plain model when `n_gpu < 2`, but a `DataParallel`-wrapped model when `n_gpu > 1`, so the API is ambiguous.
The user really has to use `self.model` to be able to call methods and attributes like `model.generate()` or `model.config`, which aren't available on the wrapped model. But it's very likely they will use `model` instead, since it behaves exactly like `self.model` unless running under multi-GPU. And why do we even have that `model` argument then?
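To make the ambiguity concrete, here is a minimal, self-contained sketch (with a toy module standing in for a transformers model) showing that the `DataParallel` wrapper hides the underlying model's extra methods and attributes, which then are only reachable via `.module`:

```python
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for a transformers model that exposes extra API beyond forward()."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.config = {"vocab_size": 10}  # plays the role of model.config

    def generate(self, x):
        # plays the role of model.generate()
        return self.linear(x).argmax(dim=-1)

model = ToyModel()
wrapped = nn.DataParallel(model)  # what the Trainer does when n_gpu > 1

print(hasattr(model, "generate"))           # True
print(hasattr(wrapped, "generate"))         # False: hidden by the wrapper
print(hasattr(wrapped.module, "generate"))  # True: only reachable via .module
```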
Possible solutions discussed:
1. monkeypatch `torch.nn.DataParallel` to expand its API to support all the methods of the original model transparently, by installing a catch-all `__getattr__` that delegates any failed attribute lookup to `self.module` (see the sketch after this list).
2. stop calling the function argument `model`, since under multi-GPU it isn't the model but something else.
3. remove the `model` argument completely and document that `self.model` should always be used. Currently in `seq2seq_trainer.py`, once we switch to `self.model`, `prediction_step()` no longer needs `model` as an argument (but is that always the case?).
4. pass `self.model` as the `model` argument, and make the wrapped model available via `self.wrapped_model` if the user needs it.
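For the first option, a minimal sketch of what such a monkeypatch could look like (an illustration, not an agreed-upon implementation): any attribute lookup that fails on the `DataParallel` wrapper itself falls back to the wrapped model.

```python
import torch.nn as nn

_original_getattr = nn.DataParallel.__getattr__

def _delegating_getattr(self, name):
    try:
        # normal nn.Module lookup (parameters, buffers, submodules, incl. .module)
        return _original_getattr(self, name)
    except AttributeError:
        # fall back to the wrapped model for anything DataParallel itself
        # doesn't define, e.g. generate() or config
        return getattr(self.module, name)

nn.DataParallel.__getattr__ = _delegating_getattr
```

This is the part later called "too magical": user code would silently work on the wrapper, blurring the distinction between the wrapped and unwrapped model.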
Summary of the discussion around the proposed solutions (numbers match the list above):
1. too magical.
2. calling the argument `wrapped_model` was proposed, but that is just as confusing, since most of the time it isn't wrapped.
3. need to check whether the wrapped model is ever needed inside user functions.
4. not discussed yet.
@sgugger, @LysandreJik