Skip to content

Commit

Permalink
Merge pull request #2457 from flairNLP/trainer-details
Browse files Browse the repository at this point in the history
Automatic "Model Cards" for Flair models, and default ability to resume training
  • Loading branch information
alanakbik authored Oct 1, 2021
2 parents a72e263 + 7343c56 commit d58128a
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 187 deletions.
11 changes: 3 additions & 8 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,9 @@ def _init_model_with_state_dict(state):

rnn_type = "LSTM" if "rnn_type" not in state.keys() else state["rnn_type"]
use_dropout = 0.0 if "use_dropout" not in state.keys() else state["use_dropout"]
use_word_dropout = (
0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"]
)
use_locked_dropout = (
0.0
if "use_locked_dropout" not in state.keys()
else state["use_locked_dropout"]
)
use_word_dropout = 0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"]
use_locked_dropout = 0.0 if "use_locked_dropout" not in state.keys() else state["use_locked_dropout"]

train_initial_hidden_state = (
False
if "train_initial_hidden_state" not in state.keys()
Expand Down
59 changes: 57 additions & 2 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import itertools
import logging
import warnings
Expand Down Expand Up @@ -82,8 +81,37 @@ def save(self, model_file: Union[str, Path]):
"""
model_state = self._get_state_dict()

# in Flair <0.9.1, optimizer and scheduler used to train model are not saved
optimizer = scheduler = None

# write out a "model card" if one is set
if hasattr(self, 'model_card'):

# special handling for optimizer: remember optimizer class and state dictionary
if 'training_parameters' in self.model_card:
training_parameters = self.model_card['training_parameters']

if 'optimizer' in training_parameters:
optimizer = training_parameters['optimizer']
training_parameters['optimizer_state_dict'] = optimizer.state_dict()
training_parameters['optimizer'] = optimizer.__class__

if 'scheduler' in training_parameters:
scheduler = training_parameters['scheduler']
training_parameters['scheduler_state_dict'] = scheduler.state_dict()
training_parameters['scheduler'] = scheduler.__class__

model_state['model_card'] = self.model_card

# save model
torch.save(model_state, str(model_file), pickle_protocol=4)

# restore optimizer and scheduler to model card if set
if optimizer:
self.model_card['training_parameters']['optimizer'] = optimizer
if scheduler:
self.model_card['training_parameters']['scheduler'] = scheduler

@classmethod
def load(cls, model: Union[str, Path]):
"""
Expand All @@ -102,11 +130,38 @@ def load(cls, model: Union[str, Path]):

model = cls._init_model_with_state_dict(state)

if 'model_card' in state:
model.model_card = state['model_card']

model.eval()
model.to(flair.device)

return model

def print_model_card(self):
    """Log a human-readable summary of this model's "model card".

    If the instance has no ``model_card`` attribute, a single notice saying
    so is logged instead. Otherwise the card must contain the keys
    ``flair_version``, ``pytorch_version`` and ``training_parameters``;
    ``transformers_version`` is included in the output only when present.

    The whole card is emitted as one multi-line ``log.info`` record.
    """
    # guard clause: nothing to print if no card was ever attached
    if not hasattr(self, 'model_card'):
        log.info(
            "This model has no model card (likely because it is not yet trained or was trained with Flair version < 0.9.1)")
        return

    card = self.model_card
    rule = "------------------------------------\n"

    # header with the library versions used for training
    segments = [
        "\n" + rule,
        "--------- Flair Model Card ---------\n",
        rule,
        "- this Flair model was trained with:\n",
        f"-- Flair version {card['flair_version']}\n",
        f"-- PyTorch version {card['pytorch_version']}\n",
    ]
    if 'transformers_version' in card:
        segments.append(f"-- Transformers version {card['transformers_version']}\n")
    segments.append(rule)

    # one "-- name = value" line per training parameter
    segments.append("------- Training Parameters: -------\n")
    segments.append(rule)
    hyperparameters = card['training_parameters']
    parameter_lines = '\n'.join(f'-- {name} = {hyperparameters[name]}'
                                for name in hyperparameters)
    segments.append(parameter_lines + "\n")
    segments.append(rule)

    log.info("".join(segments))


class Classifier(Model):
"""Abstract base class for all Flair models that do classification, both single- and multi-label.
Expand Down Expand Up @@ -175,7 +230,7 @@ def evaluate(

for gold_label in datapoint.get_labels(gold_label_type):
representation = str(sentence_id) + ': ' + gold_label.identifier

value = gold_label.value
if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0:
value = '<unk>'
Expand Down
Loading

0 comments on commit d58128a

Please sign in to comment.