From 312a5c2d10089f6fdcfd1290c40bec5ccbbea3df Mon Sep 17 00:00:00 2001
From: markus583
Date: Sun, 12 May 2024 11:49:49 +0000
Subject: [PATCH] also shuffle here

---
 wtpsplit/evaluation/llm_sentence.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/wtpsplit/evaluation/llm_sentence.py b/wtpsplit/evaluation/llm_sentence.py
index 87c4a4d9..98f005fc 100644
--- a/wtpsplit/evaluation/llm_sentence.py
+++ b/wtpsplit/evaluation/llm_sentence.py
@@ -194,11 +194,19 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
             test_sentences = dataset["data"]
             if not test_sentences:
                 continue
-            if isinstance(test_sentences[0], list):
-                max_n_test_sentences = args.max_n_test_sentences // 10
+            if (
+                isinstance(test_sentences[0], list)
+                and "lyrics" not in dataset_name
+                and "short" not in dataset_name
+            ):
+                # documents: only 10% of documents. 1000 sentences --> 100 docs
+                max_n_sentences = args.max_n_test_sentences // 10
+                # shuffle sentences
+                np.random.seed(42)
+                test_sentences = np.random.permutation(test_sentences).tolist()
             else:
-                max_n_test_sentences = args.max_n_test_sentences
-            test_sentences = test_sentences[:max_n_test_sentences]
+                max_n_sentences = args.max_n_test_sentences
+            test_sentences = test_sentences[:max_n_sentences]
             if isinstance(test_sentences[0], list):
                 # list of lists: chunk each sublist
                 if "short" in dataset_name or "lyrics" in dataset_name:
@@ -458,7 +466,6 @@ def calc_hallucination_deletion_rate(row):
     if all([char == args.gap_char for char in preds]):
         # all @
        return 0.0, 0.0
-
     hallucination_count = 0
     deletion_count = 0
 
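
Illustrative sketch (not part of the patch) of the sampling logic the first hunk introduces: for document-level datasets (lists of sentence lists, excluding "lyrics"/"short" datasets), the sentence limit is divided by 10 and the documents are shuffled with a fixed seed before truncation, so the subsample is not biased toward the start of the corpus. The dataset contents and the local max_n_test_sentences variable below are made up for the example; in the patch the limit comes from args.max_n_test_sentences.

import numpy as np

# Hypothetical document-level dataset: 500 documents of 10 sentences each.
test_sentences = [[f"doc{d} sentence{s}" for s in range(10)] for d in range(500)]
max_n_test_sentences = 1000  # stand-in for args.max_n_test_sentences

if isinstance(test_sentences[0], list):
    # Documents hold roughly 10 sentences each, so keep only 10% as many items.
    max_n_sentences = max_n_test_sentences // 10
    # Fixed seed: the shuffle before truncation is deterministic across runs.
    np.random.seed(42)
    test_sentences = np.random.permutation(test_sentences).tolist()
else:
    max_n_sentences = max_n_test_sentences
test_sentences = test_sentences[:max_n_sentences]

print(len(test_sentences))  # 100 shuffled documents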