diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py
index 7b995ec1..4dff0684 100644
--- a/wtpsplit/evaluation/intrinsic_pairwise.py
+++ b/wtpsplit/evaluation/intrinsic_pairwise.py
@@ -291,6 +291,8 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
 
         if "test_logits" not in dset_group:
             test_sentences = dataset["data"][: args.max_n_test_sentences]
+            if not test_sentences:
+                continue
             if isinstance(test_sentences[0], list):
                 continue
             all_pairs_test = generate_k_mers(
@@ -428,6 +430,8 @@ def main(args):
 
         for dataset_name, dataset in dsets["sentence"].items():
             sentences = dataset["data"][: args.max_n_test_sentences]
+            if not sentences:
+                continue
             if isinstance(sentences[0], list):
                 continue
             sent_k_mers = generate_k_mers(
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index cd0c9d22..ea9cad87 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -105,6 +105,7 @@ class Args:
     lookahead_split_layers: Optional[int] = None
     sample_non_whitespace: int = 1
 
+
 def collate_fn(batch, args, label_args, label_dict, tokenizer, add_lang_ids: bool = False):
     all_input_ids = []
     all_labels = []
@@ -585,6 +586,11 @@ def compute_metrics(trainer):
         for dataset_name, dataset in lang_data["sentence"].items():
             # if "corrupt" in dataset_name:
             #     continue
+            if not dataset["data"] or not dataset["data"][0]:
+                continue
+
+            if isinstance(dataset["data"][0], list):
+                continue
             score, info = evaluate_sentence(
                 lang_code,
                 dataset["data"],