Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic "Model Cards" for Flair models, and default ability to resume training #2457

Merged
merged 9 commits into from
Oct 1, 2021
11 changes: 3 additions & 8 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,9 @@ def _init_model_with_state_dict(state):

rnn_type = "LSTM" if "rnn_type" not in state.keys() else state["rnn_type"]
use_dropout = 0.0 if "use_dropout" not in state.keys() else state["use_dropout"]
use_word_dropout = (
0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"]
)
use_locked_dropout = (
0.0
if "use_locked_dropout" not in state.keys()
else state["use_locked_dropout"]
)
use_word_dropout = 0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"]
use_locked_dropout = 0.0 if "use_locked_dropout" not in state.keys() else state["use_locked_dropout"]

train_initial_hidden_state = (
False
if "train_initial_hidden_state" not in state.keys()
Expand Down
59 changes: 57 additions & 2 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import itertools
import logging
import warnings
Expand Down Expand Up @@ -82,8 +81,37 @@ def save(self, model_file: Union[str, Path]):
"""
model_state = self._get_state_dict()

# in Flair <0.9.1, optimizer and scheduler used to train model are not saved
optimizer = scheduler = None

# write out a "model card" if one is set
if hasattr(self, 'model_card'):

# special handling for optimizer: remember optimizer class and state dictionary
if 'training_parameters' in self.model_card:
training_parameters = self.model_card['training_parameters']

if 'optimizer' in training_parameters:
optimizer = training_parameters['optimizer']
training_parameters['optimizer_state_dict'] = optimizer.state_dict()
training_parameters['optimizer'] = optimizer.__class__

if 'scheduler' in training_parameters:
scheduler = training_parameters['scheduler']
training_parameters['scheduler_state_dict'] = scheduler.state_dict()
training_parameters['scheduler'] = scheduler.__class__

model_state['model_card'] = self.model_card

# save model
torch.save(model_state, str(model_file), pickle_protocol=4)

# restore optimizer and scheduler to model card if set
if optimizer:
self.model_card['training_parameters']['optimizer'] = optimizer
if scheduler:
self.model_card['training_parameters']['scheduler'] = scheduler

@classmethod
def load(cls, model: Union[str, Path]):
"""
Expand All @@ -102,11 +130,38 @@ def load(cls, model: Union[str, Path]):

model = cls._init_model_with_state_dict(state)

if 'model_card' in state:
model.model_card = state['model_card']

model.eval()
model.to(flair.device)

return model

def print_model_card(self):
    """Log the model card attached to this model, if there is one.

    The card is read from ``self.model_card`` and written as a single
    multi-line message through ``log.info``. It lists the Flair and PyTorch
    versions, the Transformers version when the card records one, and every
    entry of the card's ``'training_parameters'`` mapping.

    When the model has no ``model_card`` attribute, a short notice is logged
    instead. A card without a ``'training_parameters'`` entry is printed with
    an empty parameter section rather than raising ``KeyError``.
    """
    if not hasattr(self, 'model_card'):
        log.info(
            "This model has no model card (likely because it is not yet trained or was trained with Flair version < 0.9.1)")
        return

    card = self.model_card
    separator = "------------------------------------"

    # 'training_parameters' is treated as optional when the card is written
    # out, so a card lacking it is handled here as "no parameters recorded".
    training_parameters = card['training_parameters'] if 'training_parameters' in card else {}

    # The leading empty string makes the joined message start with a newline,
    # so the banner begins on its own line in the log output.
    lines = [
        "",
        separator,
        "--------- Flair Model Card ---------",
        separator,
        "- this Flair model was trained with:",
        f"-- Flair version {card['flair_version']}",
        f"-- PyTorch version {card['pytorch_version']}",
    ]
    if 'transformers_version' in card:
        lines.append(f"-- Transformers version {card['transformers_version']}")

    lines += [
        separator,
        "------- Training Parameters: -------",
        separator,
        # Added as one element so that an empty parameter set still yields
        # one blank line between the separators.
        '\n'.join(f'-- {param} = {training_parameters[param]}' for param in training_parameters),
        separator,
    ]

    log.info('\n'.join(lines) + "\n")


class Classifier(Model):
"""Abstract base class for all Flair models that do classification, both single- and multi-label.
Expand Down Expand Up @@ -175,7 +230,7 @@ def evaluate(

for gold_label in datapoint.get_labels(gold_label_type):
representation = str(sentence_id) + ': ' + gold_label.identifier

value = gold_label.value
if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0:
value = '<unk>'
Expand Down
Loading