Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify callbacks #4289

Merged
merged 1 commit into from
Aug 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,119 +58,118 @@ def get_registered_actions(self, hook=None):
else:
return self._callbacks

@staticmethod
def run_callbacks(register, *args, **kwargs):
def run_callbacks(self, hook, *args, **kwargs):
"""
Loop through the registered actions and fire all callbacks
"""
for logger in register:
for logger in self._callbacks[hook]:
# print(f"Running callbacks.{logger['callback'].__name__}()")
logger['callback'](*args, **kwargs)

def on_pretrain_routine_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)

def on_pretrain_routine_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)

def on_train_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training
"""
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
self.run_callbacks('on_train_start', *args, **kwargs)

def on_train_epoch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
self.run_callbacks('on_train_epoch_start', *args, **kwargs)

def on_train_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
self.run_callbacks('on_train_batch_start', *args, **kwargs)

def optimizer_step(self, *args, **kwargs):
"""
Fires all registered callbacks on each optimizer step
"""
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
self.run_callbacks('optimizer_step', *args, **kwargs)

def on_before_zero_grad(self, *args, **kwargs):
"""
Fires all registered callbacks before zero grad
"""
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
self.run_callbacks('on_before_zero_grad', *args, **kwargs)

def on_train_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
self.run_callbacks('on_train_batch_end', *args, **kwargs)

def on_train_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
self.run_callbacks('on_train_epoch_end', *args, **kwargs)

def on_val_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of the validation
"""
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
self.run_callbacks('on_val_start', *args, **kwargs)

def on_val_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
self.run_callbacks('on_val_batch_start', *args, **kwargs)

def on_val_image_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each val image
"""
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
self.run_callbacks('on_val_image_end', *args, **kwargs)

def on_val_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
self.run_callbacks('on_val_batch_end', *args, **kwargs)

def on_val_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of the validation
"""
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
self.run_callbacks('on_val_end', *args, **kwargs)

def on_fit_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each fit (train+val) epoch
"""
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)

def on_model_save(self, *args, **kwargs):
"""
Fires all registered callbacks after each model save
"""
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
self.run_callbacks('on_model_save', *args, **kwargs)

def on_train_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of training
"""
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
self.run_callbacks('on_train_end', *args, **kwargs)

def teardown(self, *args, **kwargs):
"""
Fires all registered callbacks before teardown
"""
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
self.run_callbacks('teardown', *args, **kwargs)