-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Sparse-Gluon] embedding with sparse grad #10924
Changes from all commits
f9e160c
200dc6f
9c84de8
943e04e
6a64ac5
e1c6e38
9c1aba3
605bba6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,16 +110,30 @@ def _init_optimizer(self, optimizer, optimizer_params): | |
for _ in self._contexts] | ||
|
||
def _init_kvstore(self): | ||
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} | ||
arg_arrays = {} | ||
contains_sparse = False | ||
for param in self._params: | ||
arg_arrays[param.name] = param.data(self._contexts[0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you try calling `step` on a parameter with deferred initialization? What message did you get? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I tried the language model example. `param.data()` was called in the previous implementation (line 113). |
||
if param._grad_stype != 'default': | ||
contains_sparse = True | ||
# update_on_kvstore is set to False by the user | ||
if self._update_on_kvstore is False: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggestion: `if not self._update_on_kvstore`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If `self._update_on_kvstore` is `None`, I don't need to raise the error. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK |
||
raise RuntimeError("Cannot set update_on_kvstore to False when sparse " | ||
"gradients and/or sparse weights are present for " | ||
"Parameter %s." % param.name) | ||
kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts), | ||
arg_arrays) | ||
update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \ | ||
else update_on_kvstore | ||
if kvstore: | ||
if self._compression_params: | ||
kvstore.set_gradient_compression(self._compression_params) | ||
if 'dist' in kvstore.type: | ||
update_on_kvstore = False | ||
# kv.pull(row_sparse_grad) is not supported | ||
if contains_sparse: | ||
update_on_kvstore = True | ||
else: | ||
if 'dist' in kvstore.type: | ||
update_on_kvstore = False | ||
if update_on_kvstore: | ||
kvstore.set_optimizer(self._optimizer) | ||
# optimizer preferably needs to be set before init for multiprecision | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add documentation for this argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added