
Commit 98443b7

finally fix indices

markus583 committed May 19, 2024
1 parent 29e4d06
Showing 1 changed file with 12 additions and 5 deletions.
wtpsplit/evaluation/intrinsic_pairwise.py (17 changes: 12 additions & 5 deletions)
@@ -19,6 +19,7 @@

 import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
+from wtpsplit.evaluation.intrinsic_baselines import split_language_data
 from wtpsplit.extract import PyTorchWrapper
 from wtpsplit.extract_batched import extract_batched
 from wtpsplit.utils import Constants
@@ -241,11 +242,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
     total_test_time = 0  # Initialize total test processing time

     start_time = time.time()
-    with h5py.File(logits_path, "a") as f, torch.no_grad():
+    with h5py.File(logits_path, "w") as f, torch.no_grad():
         for lang_code in Constants.LANGINFO.index:
             if args.include_langs is not None and lang_code not in args.include_langs:
                 continue
-
             print(f"Processing {lang_code}...")
             if lang_code not in f:
                 lang_group = f.create_group(lang_code)
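The mode switch from "a" to "w" is the functional change in this hunk: h5py's "a" mode opens an existing file read/write, so logits cached by an earlier run survive and the `lang_code not in f` check can skip recomputation, while "w" truncates the file and forces a fresh pass. A minimal sketch of the difference, with a hypothetical file name:

```python
import h5py
import numpy as np

# "w" truncates: every run starts from an empty file, so no stale groups survive.
with h5py.File("logits_demo.h5", "w") as f:  # hypothetical file name
    f.create_group("en")["logits"] = np.zeros((4, 2))

# "a" opens the existing file read/write: groups written by a previous run
# are still present, which is why the group-existence check above is needed.
with h5py.File("logits_demo.h5", "a") as f:
    print("en" in f)  # True: the group persisted from the "w" run
```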
@@ -254,8 +254,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

             # eval data
             for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
-                if args.skip_corrupted and "corrupted" in dataset_name and"ted2020" not in dataset_name:
+                if args.skip_corrupted and "corrupted" in dataset_name and "ted2020" not in dataset_name:
                     continue
+                if "-" in lang_code and "canine" in args.model_path and "no-adapters" not in args.model_path:
+                    # code-switched data: eval 2x
+                    lang_code = lang_code.split("_")[1].lower()
                 try:
                     if args.adapter_path:
                         model.model.load_adapter(
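The new branch rewrites `lang_code` for code-switched evaluation sets when a CANINE model with adapters is used. The diff alone does not show the key convention, so the sketch below assumes a hypothetical code-switched key of the form "tr-de_TR", which is consistent with both the `"-" in lang_code` test and the `split("_")[1].lower()` rewrite:

```python
# Hypothetical code-switched key; real keys follow whatever convention the
# eval_data file uses. split("_")[1].lower() keeps the part after the
# underscore, mapping e.g. "tr-de_TR" -> "tr" and "tr-de_DE" -> "de".
lang_code = "tr-de_TR"
if "-" in lang_code:
    lang_code = lang_code.split("_")[1].lower()
print(lang_code)  # "tr"
```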
@@ -377,6 +380,8 @@ def main(args):

     print(save_str)
     eval_data = torch.load(args.eval_data_path)
+    if "canine" in args.model_path and not "no-adapters" in args.model_path:
+        eval_data = split_language_data(eval_data)
     if args.valid_text_path is not None:
         valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
     else:
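`split_language_data` comes from the newly imported `wtpsplit.evaluation.intrinsic_baselines`; together with the "eval 2x" comment above, it appears to duplicate each code-switched entry once per constituent language so that each side can be scored separately. A plausible sketch under that assumption, not the actual implementation:

```python
# Plausible sketch only -- not the real split_language_data from
# wtpsplit.evaluation.intrinsic_baselines. Assumes code-switched keys like
# "tr-de" and duplicates their data once per constituent language, producing
# keys such as "tr-de_TR" and "tr-de_DE" (the format the branch above expects).
def split_language_data_sketch(eval_data):
    out = {}
    for lang_code, data in eval_data.items():
        if "-" in lang_code:
            for part in lang_code.split("-"):
                out[f"{lang_code}_{part.upper()}"] = data  # eval 2x
        else:
            out[lang_code] = data
    return out
```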
@@ -530,7 +535,9 @@ def main(args):
            acc_t = np.mean(acc_t) if score_t else None
            acc_punct = np.mean(acc_punct) if score_punct else None
            threshold = np.mean(thresholds)
-
+            u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else [])
+            true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else [])
+            length.append(cur_u_indices["length"])

            results[lang_code][dataset_name] = {
                "u": score_u,
@@ -596,7 +603,7 @@ def main(args):
            ),
            indent=4,
        )
-
+
        if args.return_indices:
            json.dump(
                indices,