Skip to content

Commit

Permalink
finally fix idcs
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 20, 2024
1 parent d875c77 commit 6a09358
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions wtpsplit/evaluation/intrinsic_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
total_test_time = 0 # Initialize total test processing time

start_time = time.time()
with h5py.File(logits_path, "w") as f, torch.no_grad():
with h5py.File(logits_path, "a") as f, torch.no_grad():
for lang_code in Constants.LANGINFO.index:
if args.include_langs is not None and lang_code not in args.include_langs:
continue
Expand Down Expand Up @@ -525,6 +525,9 @@ def main(args):
)
score_u.append(single_score_u)
acc_u.append(info["info_newline"]["correct_pairwise"])
u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else [])
true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else [])
length.append(cur_u_indices["length"])

score_u = np.mean(score_u)
score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None
Expand All @@ -535,9 +538,7 @@ def main(args):
acc_t = np.mean(acc_t) if score_t else None
acc_punct = np.mean(acc_punct) if score_punct else None
threshold = np.mean(thresholds)
u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else [])
true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else [])
length.append(cur_u_indices["length"])


results[lang_code][dataset_name] = {
"u": score_u,
Expand Down

0 comments on commit 6a09358

Please sign in to comment.