[MXNET-374] handle row_sparse weight in parameter and trainer #11001

piiswrong merged 23 commits into apache:master
Conversation
```python
if stype != 'default':
    raise ValueError("Cannot create a HybridBlock with Parameter '%s' " \
                     "because its storage type is %s. Please consider " \
                     "using a SparseBlock instead."%(param.name, stype))
```
PR for sparse block will be created separately after this one is merged.
"please consider using" -> "please use"
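The guard under discussion can be sketched as a standalone check. This is a minimal, hypothetical sketch (the real logic lives inside mxnet.gluon's HybridBlock; the function name here is illustrative):

```python
# Hypothetical sketch of the HybridBlock guard: reject any parameter whose
# storage type is not 'default'. Not the actual MXNet implementation.
def check_hybrid_param(name, stype):
    if stype != 'default':
        raise ValueError("Cannot create a HybridBlock with Parameter '%s' "
                         "because its storage type is %s. Please use "
                         "a SparseBlock instead." % (name, stype))

check_hybrid_param('weight', 'default')        # passes silently
try:
    check_hybrid_param('weight', 'row_sparse')
except ValueError as e:
    print(e)
```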
```diff
 def test_sparse_parameter():
-    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+    p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
     p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
```
Seems like constraining the contexts to CPU is causing test failures on GPU; is this necessary?
```python
    "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
    " but got '%s'" % (name, grad_stype)
# sparse related storage type information
valid_stypes = ['default', 'row_sparse', 'csr']
```
It only has 3 elements, so I don't think this makes any real difference.
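The point about collection size can be made concrete: for a three-element collection, an `in` test on a list is effectively as cheap as on a set. A small sketch (variable names mirror the diff, but this is illustrative code, not the MXNet source):

```python
# Membership test on a tiny list: the O(n) scan is negligible for n == 3,
# so switching to a set would make no measurable difference here.
valid_stypes = ['default', 'row_sparse', 'csr']

def check_stype(stype):
    # returns True iff stype is one of the three valid storage types
    return stype in valid_stypes

print(check_stype('row_sparse'))  # True
print(check_stype('dense'))       # False
```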
```python
""" Set the trainer this parameter is associated with. """
if self._trainer and self._trainer is not trainer:
    raise RuntimeError(
        "Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \
```
How can a user detach a parameter's association with a trainer without exiting Python?

Updated. Users can just call _set_trainer(None). I don't think this will be used by common users, so it remains private.
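The detach-via-None behavior described above can be sketched in isolation. This is a hypothetical stand-in class, not the actual gluon.Parameter:

```python
# Minimal sketch: a parameter that can drop its trainer association by
# passing None to _set_trainer, then accept a new trainer afterwards.
class Param:
    def __init__(self, name):
        self.name = name
        self._trainer = None

    def _set_trainer(self, trainer):
        # None detaches; a different non-None trainer while one is set is an error
        if trainer is not None and self._trainer is not None \
                and self._trainer is not trainer:
            raise RuntimeError("Parameter '%s' is already associated with "
                               "another trainer." % self.name)
        self._trainer = trainer

p = Param('weight')
t1, t2 = object(), object()
p._set_trainer(t1)
p._set_trainer(None)   # detach without exiting Python
p._set_trainer(t2)     # re-attaching a new trainer now succeeds
```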
```python
"""
def __init__(self, prefix=None, params=None):
    # check if any parameter is row_sparse
    if isinstance(params, ParameterDict):
```
This check shouldn't be done here.
Parameters are only added to the current block when self.params.get is called.
Removed. Will the checks in param.list_data() and param.data() be sufficient?
```python
raise RuntimeError(
    "Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \
    "More than one trainers for a single Parameter is not supported." %(
        self.name, str(trainer), str(self._trainer)))
```
What does str(trainer) show? It's likely not meaningful to users.
This is a breaking change.
Suppose users want to train with SGD for 10 epochs and then switch to Adam; this change would prevent that.
Now it only throws an exception for row_sparse params.
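The revised behavior can be sketched as follows: switching trainers stays legal for dense parameters (preserving the SGD-then-Adam workflow above) and is rejected only for row_sparse ones. Class and attribute names here are illustrative, not the MXNet source:

```python
# Hypothetical sketch: trainer replacement is gated on the parameter's stype.
class Param:
    def __init__(self, name, stype='default'):
        self.name = name
        self._stype = stype
        self._trainer = None

    def _set_trainer(self, trainer):
        # only row_sparse params forbid swapping one trainer for another
        if self._stype == 'row_sparse' and self._trainer is not None \
                and trainer is not None and self._trainer is not trainer:
            raise RuntimeError("Overwriting the trainer of row_sparse "
                               "Parameter '%s' is not supported." % self.name)
        self._trainer = trainer

dense = Param('w_dense')
dense._set_trainer('sgd')
dense._set_trainer('adam')      # allowed: e.g. switch optimizers mid-training

sparse = Param('w_sparse', stype='row_sparse')
sparse._set_trainer('sgd')
try:
    sparse._set_trainer('adam')  # rejected for row_sparse
except RuntimeError as e:
    print('rejected:', e)
```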
```python
""" Get row_sparse data from row_sparse parameters based on row_id. """
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
    raise TypeError("Cannot get 'row_sparse' Parameter %s with %s type. "
```
"row_id must have NDArray type, but %s is given"
```python
                    "NDArray type is expected." % (self.name, type(row_id)))
if not self._trainer:
    raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
                       "Trainer is created with it."%self.name)
```
What if the user wants to train with a single device?

For a single device, we encourage the user to use normal hybrid blocks with sparse_grad=True; there's no need to use a row_sparse weight. Even if the user chooses a row_sparse weight, a kvstore is still created for the row_sparse param and the code works.
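The row_id type check with the reworded message can be sketched in isolation. The `NDArray` class below is a stand-in for mxnet.ndarray.NDArray, and the function name is illustrative:

```python
# Hypothetical sketch of the row_id validation with the message wording
# suggested in the review.
class NDArray:          # stand-in for mxnet.ndarray.NDArray
    pass

def check_row_id(name, row_id):
    # reject anything that is not an NDArray
    if not isinstance(row_id, NDArray):
        raise TypeError("row_id must have NDArray type, but %s is given"
                        % type(row_id))

check_row_id('weight', NDArray())        # OK
try:
    check_row_id('weight', [0, 1, 2])    # a plain list is rejected
except TypeError as e:
    print(e)
```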
```python
"""(Re)initializes by loading from data."""
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
    raise RuntimeError("Cannot (Re)initialize Parameter '%s' when its Trainer " \
                       "already initialized the parameter on KVStore."%(self.name))
```
The message is cryptic: the real reason is multi-device training with update_on_kvstore=True. The error message should describe the reason and suggest a solution.

Updated the message.
```python
    NDArray on ctx
"""
if self._stype != 'default':
    raise ValueError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
```
These should be UserError?

Maybe I should change it to RuntimeError? There's UserWarning, but I'm not aware of a UserError in Python.
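The point is easy to verify against Python's builtins: `UserWarning` exists in the exception hierarchy, but there is no built-in `UserError`, so `ValueError`/`RuntimeError` are the natural choices here:

```python
# Check Python's builtin namespace for the two names discussed above.
import builtins

print(hasattr(builtins, 'UserWarning'))  # True
print(hasattr(builtins, 'UserError'))    # False
```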
```python
self._param2idx[param.name] = i
self._params.append(param)
self._params_to_init.append(param)
param._set_trainer(self)
```
Do we need to call _set_trainer when stype='default' and update_on_kvstore=False?
```python
            for _ in self._contexts]

def _init_params(self):
    """ Initialize parameters in the KVStore. Parameters whose
```

```python
        "when KVStore is not initialized."
    params_to_init = []
    if self._kvstore:
        params = [param for param in self._params_to_init \
```
Better to use a for loop with if/else here.
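The suggested refactor can be sketched as follows: an explicit for loop with if/else filters the pending parameters in a single pass instead of a list comprehension. Variable names mirror the diff, but the predicate is a hypothetical stand-in:

```python
# Hypothetical sketch of the for-loop/if-else refactor: split the pending
# params into "ready for the KVStore" and "still waiting" in one pass.
_params_to_init = ['a', 'b', 'c', 'd']
params = []
params_to_init = []

def is_initialized(p):
    # stand-in predicate; the real code checks whether param data exists
    return p in ('a', 'c')

for param in _params_to_init:
    if is_initialized(param):
        params.append(param)          # can be pushed to the KVStore now
    else:
        params_to_init.append(param)  # defer until the param is initialized

print(params)          # ['a', 'c']
print(params_to_init)  # ['b', 'd']
```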
```python
"""
if not self._kv_initialized:
    self._init_kvstore()
if self._params_to_init:
```
I don't quite understand this. If there are uninitialized parameters, wouldn't step fail?

I moved the logic of kv.init(param) from _init_kvstore to _init_params. _params_to_init refers to params that are not yet initialized on the kvstore.
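The lazy-initialization flow described above can be sketched end to end: the first step creates the kvstore, then drains the queue of pending params before any update. This is a toy model with hypothetical names, not the actual Trainer:

```python
# Minimal sketch of the lazy flow: step() first ensures the kvstore exists,
# then initializes any params still queued in _params_to_init.
class Trainer:
    def __init__(self, params):
        self._kv_initialized = False
        self._params_to_init = list(params)
        self._kv = {}

    def _init_kvstore(self):
        # stand-in for real kvstore creation
        self._kv_initialized = True

    def _init_params(self):
        # stand-in for kv.init(param); clear the pending queue afterwards
        for p in self._params_to_init:
            self._kv[p] = 0.0
        self._params_to_init = []

    def step(self):
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()

t = Trainer(['w', 'b'])
t.step()
print(t._kv_initialized)   # True
print(t._params_to_init)   # []
```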
…#11001)

* + rsp parameter
* draft
* Fix optimizer pickle
* refactor and document
* add test for save load with cast_stype
* refactor trainer tests
* add test
* add back test
* raise error for load params
* add comment
* remove print
* fix doc
* CR comments
* CR comments
* change error
* remove cast stype
* fix test
* add reset kvstore to trainer
* lint
* add test to CI
* add more checks
Description
@piiswrong @szha @ZiyueHuang @haojin2 @safrooze please review.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments