diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index fcf21d5cf..65375b284 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -254,14 +254,9 @@ def _init_model_with_state_dict(state): rnn_type = "LSTM" if "rnn_type" not in state.keys() else state["rnn_type"] use_dropout = 0.0 if "use_dropout" not in state.keys() else state["use_dropout"] - use_word_dropout = ( - 0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"] - ) - use_locked_dropout = ( - 0.0 - if "use_locked_dropout" not in state.keys() - else state["use_locked_dropout"] - ) + use_word_dropout = 0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"] + use_locked_dropout = 0.0 if "use_locked_dropout" not in state.keys() else state["use_locked_dropout"] + train_initial_hidden_state = ( False if "train_initial_hidden_state" not in state.keys() diff --git a/flair/nn/model.py b/flair/nn/model.py index 1a95044d8..007dde46d 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -1,4 +1,3 @@ -import copy import itertools import logging import warnings @@ -82,8 +81,37 @@ def save(self, model_file: Union[str, Path]): """ model_state = self._get_state_dict() + # in Flair <0.9.1, optimizer and scheduler used to train model are not saved + optimizer = scheduler = None + + # write out a "model card" if one is set + if hasattr(self, 'model_card'): + + # special handling for optimizer: remember optimizer class and state dictionary + if 'training_parameters' in self.model_card: + training_parameters = self.model_card['training_parameters'] + + if 'optimizer' in training_parameters: + optimizer = training_parameters['optimizer'] + training_parameters['optimizer_state_dict'] = optimizer.state_dict() + training_parameters['optimizer'] = optimizer.__class__ + + if 'scheduler' in training_parameters: + scheduler = training_parameters['scheduler'] + training_parameters['scheduler_state_dict'] = scheduler.state_dict() + training_parameters['scheduler'] = scheduler.__class__ + + model_state['model_card'] = self.model_card + + # save model torch.save(model_state, str(model_file), pickle_protocol=4) + # restore optimizer and scheduler to model card if set + if optimizer: + self.model_card['training_parameters']['optimizer'] = optimizer + if scheduler: + self.model_card['training_parameters']['scheduler'] = scheduler + @classmethod def load(cls, model: Union[str, Path]): """ @@ -102,11 +130,38 @@ def load(cls, model: Union[str, Path]): model = cls._init_model_with_state_dict(state) + if 'model_card' in state: + model.model_card = state['model_card'] + model.eval() model.to(flair.device) return model + def print_model_card(self): + if hasattr(self, 'model_card'): + param_out = "\n------------------------------------\n" + param_out += "--------- Flair Model Card ---------\n" + param_out += "------------------------------------\n" + param_out += "- this Flair model was trained with:\n" + param_out += f"-- Flair version {self.model_card['flair_version']}\n" + param_out += f"-- PyTorch version {self.model_card['pytorch_version']}\n" + if 'transformers_version' in self.model_card: + param_out += f"-- Transformers version {self.model_card['transformers_version']}\n" + param_out += "------------------------------------\n" + + param_out += "------- Training Parameters: -------\n" + param_out += "------------------------------------\n" + training_params = '\n'.join(f'-- {param} = {self.model_card["training_parameters"][param]}' + for param in self.model_card['training_parameters']) + param_out += training_params + "\n" + param_out += "------------------------------------\n" + + log.info(param_out) + else: + log.info( + "This model has no model card (likely because it is not yet trained or was trained with Flair version < 0.9.1)") + class Classifier(Model): """Abstract base class for all Flair models that do classification, both single- and multi-label. @@ -175,7 +230,7 @@ def evaluate( for gold_label in datapoint.get_labels(gold_label_type): representation = str(sentence_id) + ': ' + gold_label.identifier - + value = gold_label.value if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0: value = '' diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index c97821cb4..34f0e00e9 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -6,6 +6,7 @@ import sys import time import warnings +from inspect import signature from pathlib import Path from typing import Union, Tuple, Optional @@ -13,6 +14,8 @@ from torch.optim.sgd import SGD from torch.utils.data.dataset import ConcatDataset +from flair.nn import Model + try: from apex import amp except ImportError: @@ -54,44 +57,17 @@ def __init__( self.corpus: Corpus = corpus @staticmethod - def check_for_and_delete_previous_best_models(base_path, save_checkpoint): + def check_for_and_delete_previous_best_models(base_path): all_best_model_names = [filename for filename in os.listdir(base_path) if filename.startswith("best-model")] if len(all_best_model_names) != 0: warnings.warn( - "There should be no best model saved at epoch 1 except there is a model from previous trainings in your training folder. All previous best models will be deleted.") + "There should be no best model saved at epoch 1 except there is a model from previous trainings" + " in your training folder. All previous best models will be deleted.") for single_model in all_best_model_names: previous_best_path = os.path.join(base_path, single_model) if os.path.exists(previous_best_path): os.remove(previous_best_path) - if save_checkpoint: - best_checkpoint_path = previous_best_path.replace("model", "checkpoint") - if os.path.exists(best_checkpoint_path): - os.remove(best_checkpoint_path) - - def fine_tune(self, - base_path: Union[Path, str], - learning_rate: float = 5e-5, - max_epochs: int = 10, - optimizer=torch.optim.AdamW, - scheduler=LinearSchedulerWithWarmup, - warmup_fraction: float = 0.1, - mini_batch_size: int = 4, - embeddings_storage_mode: str = 'none', - **trainer_args, - ): - - return self.train( - base_path=base_path, - learning_rate=learning_rate, - max_epochs=max_epochs, - optimizer=optimizer, - scheduler=scheduler, - warmup_fraction=warmup_fraction, - mini_batch_size=mini_batch_size, - embeddings_storage_mode=embeddings_storage_mode, - **trainer_args, - ) def train( self, @@ -100,17 +76,19 @@ def train( mini_batch_size: int = 32, mini_batch_chunk_size: Optional[int] = None, max_epochs: int = 100, + train_with_dev: bool = False, + train_with_test: bool = False, + monitor_train: bool = False, + monitor_test: bool = False, + main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'), scheduler=AnnealOnPlateau, - cycle_momentum: bool = False, anneal_factor: float = 0.5, patience: int = 3, - initial_extra_patience: int = 0, min_learning_rate: float = 0.0001, + initial_extra_patience: int = 0, + optimizer: torch.optim.Optimizer = SGD, + cycle_momentum: bool = False, warmup_fraction: float = 0.1, - train_with_dev: bool = False, - train_with_test: bool = False, - monitor_train: bool = False, - monitor_test: bool = False, embeddings_storage_mode: str = "cpu", checkpoint: bool = False, save_final_model: bool = True, @@ -128,19 +106,18 @@ def train( eval_on_train_fraction: float = 0.0, eval_on_train_shuffle: bool = False, save_model_each_k_epochs: int = 0, - main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'), tensorboard_comment: str = '', - save_best_checkpoints: bool = False, use_swa: bool = False, use_final_model_for_eval: bool = False, gold_label_dictionary_for_eval: Optional[Dictionary] = None, create_file_logs: bool = True, create_loss_file: bool = True, - optimizer: torch.optim.Optimizer = SGD, epoch: int = 0, use_tensorboard: bool = False, tensorboard_log_dir=None, metrics_for_tensorboard=[], + optimizer_state_dict: Optional = None, + scheduler_state_dict: Optional = None, **kwargs, ) -> dict: """ @@ -151,18 +128,19 @@ def train( :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed. :param scheduler: The learning rate scheduler to use + :param checkpoint: If True, a full checkpoint is saved at end of each epoch :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum :param anneal_factor: The factor by which the learning rate is annealed :param patience: Patience is the number of epochs with no improvement the Trainer waits until annealing the learning rate :param min_learning_rate: If the learning rate falls below this threshold, training terminates :param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup - :param train_with_dev: If True, training is performed using both train+dev data + :param train_with_dev: If True, the data from dev split is added to the training data + :param train_with_test: If True, the data from test split is added to the training data :param monitor_train: If True, training data is evaluated at end of each epoch :param monitor_test: If True, test data is evaluated at end of each epoch :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed), 'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU) - :param checkpoint: If True, a full checkpoint is saved at end of each epoch :param save_final_model: If True, final model is saved :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate :param shuffle: If True, data is shuffled during training @@ -177,10 +155,8 @@ def train( and kept fixed during training, otherwise it's sampled at beginning of each epoch :param save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will be saved each 5 epochs. Default is 0 which means no model saving. - :param save_model_epoch_step: Each save_model_epoch_step'th epoch the thus far trained model will be saved - :param classification_main_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model + :param main_evaluation_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model :param tensorboard_comment: Comment to use for tensorboard logging - :param save_best_checkpoints: If True, in addition to saving the best model also the corresponding checkpoint is saved :param create_file_logs: If True, the logs will also be stored in a file 'training.log' in the model folder :param create_loss_file: If True, the loss will be writen to a file 'loss.tsv' in the model folder :param optimizer: The optimizer to use (typically SGD or Adam) @@ -192,6 +168,26 @@ def train( :return: """ + # create a model card for this model with Flair and PyTorch version + model_card = {'flair_version': flair.__version__, 'pytorch_version': torch.__version__} + + # also record Transformers version if library is loaded + try: + import transformers + model_card['transformers_version'] = transformers.__version__ + except: + pass + + # remember all parameters used in train() call + local_variables = locals() + training_parameters = {} + for parameter in signature(self.train).parameters: + training_parameters[parameter] = local_variables[parameter] + model_card['training_parameters'] = training_parameters + + # add model card to model + self.model.model_card = model_card + if use_tensorboard: try: from torch.utils.tensorboard import SummaryWriter @@ -203,9 +199,7 @@ def train( except: log_line(log) - log.warning( - "ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!" - ) + log.warning("ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!") log_line(log) use_tensorboard = False pass @@ -261,7 +255,7 @@ def train( log.warning(f'WARNING: Specified class weights will not take effect when using CRF') # check for previously saved best models in the current training folder and delete them - self.check_for_and_delete_previous_best_models(base_path, save_best_checkpoints) + self.check_for_and_delete_previous_best_models(base_path) # determine what splits (train, dev, test) to evaluate and log log_train = True if monitor_train else False @@ -270,27 +264,22 @@ def train( log_train_part = True if (eval_on_train_fraction == "dev" or eval_on_train_fraction > 0.0) else False if log_train_part: - train_part_size = ( - len(self.corpus.dev) - if eval_on_train_fraction == "dev" + train_part_size = len(self.corpus.dev) if eval_on_train_fraction == "dev" \ else int(len(self.corpus.train) * eval_on_train_fraction) - ) + assert train_part_size > 0 if not eval_on_train_shuffle: train_part_indices = list(range(train_part_size)) - train_part = torch.utils.data.dataset.Subset( - self.corpus.train, train_part_indices - ) + train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices) - if create_loss_file: - # prepare loss logging file and set up header - loss_txt = init_output_file(base_path, "loss.tsv") - else: - loss_txt = None + # prepare loss logging file and set up header + loss_txt = init_output_file(base_path, "loss.tsv") if create_loss_file else None weight_extractor = WeightExtractor(base_path) - optimizer: torch.optim.Optimizer = optimizer(self.model.parameters(), lr=learning_rate, **kwargs) + # if optimizer class is passed, instantiate: + if inspect.isclass(optimizer): + optimizer: torch.optim.Optimizer = optimizer(self.model.parameters(), lr=learning_rate, **kwargs) if use_swa: import torchcontrib @@ -301,6 +290,10 @@ def train( self.model, optimizer, opt_level=amp_opt_level ) + # load existing optimizer state dictionary if it exists + if optimizer_state_dict: + optimizer.load_state_dict(optimizer_state_dict) + # minimize training loss if training with dev data, else maximize dev score anneal_mode = "min" if train_with_dev or anneal_against_dev_loss else "max" best_validation_score = 100000000000 if train_with_dev or anneal_against_dev_loss else 0. @@ -309,33 +302,43 @@ def train( if train_with_dev: dataset_size += len(self.corpus.dev) - if scheduler == OneCycleLR: - lr_scheduler = OneCycleLR(optimizer, - max_lr=learning_rate, - steps_per_epoch=dataset_size // mini_batch_size + 1, - epochs=max_epochs - epoch, - # if we load a checkpoint, we have already trained for epoch - pct_start=0.0, - cycle_momentum=cycle_momentum) - elif scheduler == LinearSchedulerWithWarmup: - steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size - num_train_steps = int(steps_per_epoch * max_epochs) - num_warmup_steps = int(num_train_steps * warmup_fraction) - - lr_scheduler = LinearSchedulerWithWarmup(optimizer, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps) - else: - lr_scheduler = scheduler( - optimizer, - factor=anneal_factor, - patience=patience, - initial_extra_patience=initial_extra_patience, - mode=anneal_mode, - verbose=True, - ) + # if scheduler is passed as a class, instantiate + if inspect.isclass(scheduler): + if scheduler == OneCycleLR: + scheduler = OneCycleLR(optimizer, + max_lr=learning_rate, + steps_per_epoch=dataset_size // mini_batch_size + 1, + epochs=max_epochs - epoch, + # if we load a checkpoint, we have already trained for epoch + pct_start=0.0, + cycle_momentum=cycle_momentum) + elif scheduler == LinearSchedulerWithWarmup: + steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size + num_train_steps = int(steps_per_epoch * max_epochs) + num_warmup_steps = int(num_train_steps * warmup_fraction) + + scheduler = LinearSchedulerWithWarmup(optimizer, + num_train_steps=num_train_steps, + num_warmup_steps=num_warmup_steps) + else: + scheduler = scheduler( + optimizer, + factor=anneal_factor, + patience=patience, + initial_extra_patience=initial_extra_patience, + mode=anneal_mode, + verbose=True, + ) - if isinstance(lr_scheduler, OneCycleLR) and batch_growth_annealing: + # load existing scheduler state dictionary if it exists + if scheduler_state_dict: + scheduler.load_state_dict(scheduler_state_dict) + + # update optimizer and scheduler in model card + model_card['training_parameters']['optimizer'] = optimizer + model_card['training_parameters']['scheduler'] = scheduler + + if isinstance(scheduler, OneCycleLR) and batch_growth_annealing: raise ValueError("Batch growth with OneCycle policy is not implemented.") train_data = self.corpus.train @@ -375,6 +378,9 @@ def train( for epoch in range(epoch + 1, max_epochs + 1): log_line(log) + # update epoch in model card + self.model.model_card['training_parameters']['epoch'] = epoch + if anneal_with_prestarts: last_epoch_model_state_dict = copy.deepcopy(self.model.state_dict()) @@ -382,9 +388,7 @@ def train( train_part_indices = list(range(self.corpus.train)) random.shuffle(train_part_indices) train_part_indices = train_part_indices[:train_part_size] - train_part = torch.utils.data.dataset.Subset( - self.corpus.train, train_part_indices - ) + train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices) # get new learning rate for group in optimizer.param_groups: @@ -415,7 +419,7 @@ def train( writer.add_scalar("learning_rate", learning_rate, epoch) # stop training if learning rate becomes too small - if ((not isinstance(lr_scheduler, (OneCycleLR, LinearSchedulerWithWarmup)) and + if ((not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)) and learning_rate < min_learning_rate)): log_line(log) log.info("learning rate too small - quitting training!") @@ -453,16 +457,14 @@ def train( # if necessary, make batch_steps batch_steps = [batch] if len(batch) > micro_batch_size: - batch_steps = [ - batch[x: x + micro_batch_size] - for x in range(0, len(batch), micro_batch_size) - ] + batch_steps = [batch[x: x + micro_batch_size] for x in range(0, len(batch), micro_batch_size)] # forward and backward for batch for batch_step in batch_steps: # forward pass loss = self.model.forward_loss(batch_step) + if isinstance(loss, Tuple): average_over += loss[1] loss = loss[0] @@ -480,8 +482,8 @@ def train( optimizer.step() # do the scheduler step if one-cycle or linear decay - if isinstance(lr_scheduler, (OneCycleLR, LinearSchedulerWithWarmup)): - lr_scheduler.step() + if isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)): + scheduler.step() # get new learning rate for group in optimizer.param_groups: learning_rate = group["lr"] @@ -507,9 +509,7 @@ def train( batch_time = 0 iteration = epoch * total_number_of_batches + batch_no if not param_selection_mode and write_weights: - weight_extractor.extract_weights( - self.model.state_dict(), iteration - ) + weight_extractor.extract_weights(self.model.state_dict(), iteration) if average_over != 0: train_loss /= average_over @@ -517,9 +517,7 @@ def train( self.model.eval() log_line(log) - log.info( - f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}" - ) + log.info(f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}") if use_tensorboard: writer.add_scalar("train_loss", train_loss, epoch) @@ -552,9 +550,8 @@ def train( main_evaluation_metric=main_evaluation_metric, gold_label_dictionary=gold_label_dictionary_for_eval, ) - result_line += ( - f"\t{train_part_loss}\t{train_part_eval_result.log_line}" - ) + result_line += f"\t{train_part_loss}\t{train_part_eval_result.log_line}" + log.info( f"TRAIN_SPLIT : loss {train_part_loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(train_part_eval_result.main_score, 4)}" ) @@ -592,9 +589,7 @@ def train( if use_tensorboard: writer.add_scalar("dev_loss", dev_eval_result.loss, epoch) - writer.add_scalar( - "dev_score", dev_eval_result.main_score, epoch - ) + writer.add_scalar("dev_score", dev_eval_result.main_score, epoch) for (metric_class_avg_type, metric_type) in metrics_for_tensorboard: writer.add_scalar( f"dev_{metric_class_avg_type}_{metric_type}", @@ -622,9 +617,7 @@ def train( if use_tensorboard: writer.add_scalar("test_loss", test_eval_result.loss, epoch) - writer.add_scalar( - "test_score", test_eval_result.main_score, epoch - ) + writer.add_scalar("test_score", test_eval_result.main_score, epoch) for (metric_class_avg_type, metric_type) in metrics_for_tensorboard: writer.add_scalar( f"test_{metric_class_avg_type}_{metric_type}", @@ -639,8 +632,8 @@ def train( current_epoch_has_best_model_so_far = True best_validation_score = dev_score - if isinstance(lr_scheduler, AnnealOnPlateau): - lr_scheduler.step(dev_score, dev_eval_result.loss) + if isinstance(scheduler, AnnealOnPlateau): + scheduler.step(dev_score, dev_eval_result.loss) # alternative: anneal against dev loss if not train_with_dev and anneal_against_dev_loss: @@ -648,8 +641,8 @@ def train( current_epoch_has_best_model_so_far = True best_validation_score = dev_eval_result.loss - if isinstance(lr_scheduler, AnnealOnPlateau): - lr_scheduler.step(dev_eval_result.loss) + if isinstance(scheduler, AnnealOnPlateau): + scheduler.step(dev_eval_result.loss) # alternative: anneal against train loss if train_with_dev: @@ -657,14 +650,14 @@ def train( current_epoch_has_best_model_so_far = True best_validation_score = train_loss - if isinstance(lr_scheduler, AnnealOnPlateau): - lr_scheduler.step(train_loss) + if isinstance(scheduler, AnnealOnPlateau): + scheduler.step(train_loss) train_loss_history.append(train_loss) # determine bad epoch number try: - bad_epochs = lr_scheduler.num_bad_epochs + bad_epochs = scheduler.num_bad_epochs except: bad_epochs = 0 for group in optimizer.param_groups: @@ -704,7 +697,7 @@ def train( # if checkpoint is enabled, save model at each epoch if checkpoint and not param_selection_mode: - self.save_checkpoint(base_path / "checkpoint.pt") + self.model.save(base_path / "checkpoint.pt") # Check whether to save best model if ( @@ -773,17 +766,54 @@ def train( "dev_loss_history": dev_loss_history, } - def save_checkpoint(self, model_file: Union[str, Path]): - corpus = self.corpus - self.corpus = None - torch.save(self, str(model_file), pickle_protocol=4) - self.corpus = corpus - - @classmethod - def load_checkpoint(cls, checkpoint: Union[Path, str], corpus: Corpus): - model: ModelTrainer = torch.load(checkpoint, map_location=flair.device) - model.corpus = corpus - return model + def resume(self, + model: Optional[Model], + **trainer_args, + ): + + self.model = model + + # recover all arguments that were used to train this model + args_used_to_train_model = self.model.model_card['training_parameters'] + + # you can overwrite params with your own + for param in trainer_args: + args_used_to_train_model[param] = trainer_args[param] + if param == 'optimizer' and 'optimizer_state_dict' in args_used_to_train_model: + del args_used_to_train_model['optimizer_state_dict'] + if param == 'scheduler' and 'scheduler_state_dict' in args_used_to_train_model: + del args_used_to_train_model['scheduler_state_dict'] + + # surface nested arguments + kwargs = args_used_to_train_model['kwargs'] + del args_used_to_train_model['kwargs'] + + # resume training with these parameters + self.train(**args_used_to_train_model, **kwargs) + + def fine_tune(self, + base_path: Union[Path, str], + learning_rate: float = 5e-5, + max_epochs: int = 10, + optimizer=torch.optim.AdamW, + scheduler=LinearSchedulerWithWarmup, + warmup_fraction: float = 0.1, + mini_batch_size: int = 4, + embeddings_storage_mode: str = 'none', + **trainer_args, + ): + + return self.train( + base_path=base_path, + learning_rate=learning_rate, + max_epochs=max_epochs, + optimizer=optimizer, + scheduler=scheduler, + warmup_fraction=warmup_fraction, + mini_batch_size=mini_batch_size, + embeddings_storage_mode=embeddings_storage_mode, + **trainer_args, + ) def final_test( self, diff --git a/flair/training_utils.py b/flair/training_utils.py index 7c82b22e1..138283709 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -315,7 +315,7 @@ def state_dict(self): def load_state_dict(self, state_dict): self.__dict__.update(state_dict) - self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) + self._init_is_better(mode=self.mode) def init_output_file(base_path: Union[str, Path], file_name: str) -> Path: diff --git a/tests/test_sequence_tagger.py b/tests/test_sequence_tagger.py index d9566261e..bd6ae7c22 100644 --- a/tests/test_sequence_tagger.py +++ b/tests/test_sequence_tagger.py @@ -345,6 +345,7 @@ def test_train_load_use_tagger_multicorpus(results_base_path, tasks_base_path): @pytest.mark.integration def test_train_resume_tagger(results_base_path, tasks_base_path): + corpus_1 = flair.datasets.ColumnCorpus( data_folder=tasks_base_path / "fashion", column_format={0: "text", 3: "ner"} ) @@ -361,13 +362,16 @@ def test_train_resume_tagger(results_base_path, tasks_base_path): use_crf=False, ) + # train model for 2 epochs trainer = ModelTrainer(model, corpus) trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True) - del trainer, model - trainer = ModelTrainer.load_checkpoint(results_base_path / "checkpoint.pt", corpus) + del model - trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True) + # load the checkpoint model and train until epoch 4 + checkpoint_model = SequenceTagger.load(results_base_path / "checkpoint.pt") + trainer.resume(model=checkpoint_model, + max_epochs=4) # clean up results directory shutil.rmtree(results_base_path) diff --git a/tests/test_text_classifier.py b/tests/test_text_classifier.py index 76f69d4a8..99d10c719 100644 --- a/tests/test_text_classifier.py +++ b/tests/test_text_classifier.py @@ -263,51 +263,17 @@ def test_train_resume_classifier(results_base_path, tasks_base_path): multi_label=False, label_type="topic") + # train model for 2 epochs trainer = ModelTrainer(model, corpus) trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True) - del trainer, model - trainer = ModelTrainer.load_checkpoint(results_base_path / "checkpoint.pt", corpus) - trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True) + del model + + # load the checkpoint model and train until epoch 4 + checkpoint_model = TextClassifier.load(results_base_path / "checkpoint.pt") + trainer.resume(model=checkpoint_model, + max_epochs=4) # clean up results directory shutil.rmtree(results_base_path) - del trainer - - -# def test_labels_to_indices(tasks_base_path): -# corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "ag_news", label_type="topic") -# label_dict = corpus.make_label_dictionary() -# model = TextClassifier(document_embeddings, -# label_dictionary=label_dict, -# label_type="topic", -# multi_label=False) -# -# result = model._labels_to_indices(corpus.train) -# -# for i in range(len(corpus.train)): -# expected = label_dict.get_idx_for_item(corpus.train[i].labels[0].value) -# actual = result[i].item() -# -# assert expected == actual -# -# -# def test_labels_to_one_hot(tasks_base_path): -# corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "ag_news", label_type="topic") -# label_dict = corpus.make_label_dictionary() -# model = TextClassifier(document_embeddings, -# label_dictionary=label_dict, -# label_type="topic", -# multi_label=False) -# -# result = model._labels_to_one_hot(corpus.train) -# -# for i in range(len(corpus.train)): -# expected = label_dict.get_idx_for_item(corpus.train[i].labels[0].value) -# actual = result[i] -# -# for idx in range(len(label_dict)): -# if idx == expected: -# assert actual[idx] == 1 -# else: -# assert actual[idx] == 0 + del trainer \ No newline at end of file