
Commit

Bug fix: Switch order of OEEvalTask and OEEvalTaskWithNewline
liujch1998 committed Sep 15, 2024
1 parent 48b17e2 commit 499ac18
Showing 1 changed file with 61 additions and 61 deletions.
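The underlying bug is a class-ordering problem: OEEvalTaskWithNewline subclasses OEEvalTask, but in the previous version of olmo/eval/downstream.py the subclass was defined above its base class, so Python could not resolve the base name when the module was imported. This commit simply swaps the two definitions. A minimal standalone sketch of why the order matters (toy class names, not the actual olmo code; Python resolves base classes at class-definition time):

class Base:                          # must be defined first
    def prep(self) -> str:
        return "context"

class WithNewline(Base):             # defined after Base, so this works;
    def prep(self) -> str:           # with the order reversed, the import fails
        return super().prep() + "\n" # with NameError: name 'Base' is not defined

print(WithNewline().prep())          # -> "context" followed by a newline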
122 changes: 61 additions & 61 deletions olmo/eval/downstream.py
@@ -1579,7 +1579,55 @@ def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"

class OEEvalTaskWithNewline(OEEvalTask):
class OEEvalTask(ICLMultiChoiceTaskDataset):
"""Generic class for OE evaluation tasks"""

def __init__(
self,
tokenizer: Tokenizer,
dataset_path: str,
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split=None,
metric_type=None,
prompts=[None], # List of prompt variants to use
):
self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.dataset_name = dataset_name
self.model_ctx_len = model_ctx_len
self.log_instances = 0 # Set to > 0 to log the first few instances as a sanity check

self.samples: List[Dict[str, Any]] = []
dataset_names: Sequence[Optional[str]]
if isinstance(dataset_name, str) or dataset_name is None:
dataset_names = [dataset_name]
else:
dataset_names = dataset_name

requests_list = []
configs = []
for ds_name in dataset_names:
config, requests = load_oe_eval_requests(self.dataset_path, ds_name, split)
requests_list.append(requests)
configs.append(config)
if metric_type is not None:
self.metric_type = metric_type
else:
# Use metric type from associated task config
for config in configs:
if config is not None:
metric_type_raw = config["task_config"].get("primary_metric")
if metric_type_raw is not None:
# acc, len_norm, pmi_dc
metric_type = METRIC_FROM_OE_EVAL[metric_type_raw]
if self.metric_type is not None and self.metric_type != metric_type:
raise ValueError(f"Conflicting metric types: {self.metric_type} and {metric_type}")
self.metric_type = metric_type
self.dataset = requests_list

# prep examples
self.prep_examples()

def prep_examples(self):
current_doc_id_offset = 0
@@ -1610,8 +1658,6 @@ def prep_examples(self):

request_dict = request["request"]
continuation_str = request_dict["continuation"]
if continuation_str.startswith(' '):
continuation_str = continuation_str.lstrip(' ')
label_id = request["label"]
cont_id = request["idx"]
if self.metric_type in ["ce_loss", "bpb"]:
@@ -1623,9 +1669,6 @@
cont_id = 0
label_id = 0
doc_text = request_dict["context"]
doc_text.replace('Answer: ', 'Answer:\n')
if not doc_text.endswith('\n'):
doc_text += '\n'
ctx = self.token_encode(doc_text)
dc = self.token_encode(self.doc_to_domain_conditional(doc))
if self.log_instances > 0:
@@ -1671,56 +1714,17 @@ def prep_examples(self):
}
)

def doc_to_text(self, doc) -> str:
raise NotImplementedError

class OEEvalTask(ICLMultiChoiceTaskDataset):
"""Generic class for OE evaluation tasks"""

def __init__(
self,
tokenizer: Tokenizer,
dataset_path: str,
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split=None,
metric_type=None,
prompts=[None], # List of prompt variants to use
):
self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.dataset_name = dataset_name
self.model_ctx_len = model_ctx_len
self.log_instances = 0 # Set to > 0 to log the first few instances as a sanity check
def doc_to_continuations(self, doc) -> List[str]:
raise NotImplementedError

self.samples: List[Dict[str, Any]] = []
dataset_names: Sequence[Optional[str]]
if isinstance(dataset_name, str) or dataset_name is None:
dataset_names = [dataset_name]
else:
dataset_names = dataset_name
def doc_to_label(self, doc) -> int:
raise NotImplementedError

requests_list = []
configs = []
for ds_name in dataset_names:
config, requests = load_oe_eval_requests(self.dataset_path, ds_name, split)
requests_list.append(requests)
configs.append(config)
if metric_type is not None:
self.metric_type = metric_type
else:
# Use metric type from associated task config
for config in configs:
if config is not None:
metric_type_raw = config["task_config"].get("primary_metric")
if metric_type_raw is not None:
# acc, len_norm, pmi_dc
metric_type = METRIC_FROM_OE_EVAL[metric_type_raw]
if self.metric_type is not None and self.metric_type != metric_type:
raise ValueError(f"Conflicting metric types: {self.metric_type} and {metric_type}")
self.metric_type = metric_type
self.dataset = requests_list

# prep examples
self.prep_examples()
class OEEvalTaskWithNewline(OEEvalTask):

def prep_examples(self):
current_doc_id_offset = 0
@@ -1751,6 +1755,8 @@ def prep_examples(self):

request_dict = request["request"]
continuation_str = request_dict["continuation"]
if continuation_str.startswith(' '):
continuation_str = continuation_str.lstrip(' ')
label_id = request["label"]
cont_id = request["idx"]
if self.metric_type in ["ce_loss", "bpb"]:
@@ -1762,6 +1768,9 @@
cont_id = 0
label_id = 0
doc_text = request_dict["context"]
doc_text.replace('Answer: ', 'Answer:\n')
if not doc_text.endswith('\n'):
doc_text += '\n'
ctx = self.token_encode(doc_text)
dc = self.token_encode(self.doc_to_domain_conditional(doc))
if self.log_instances > 0:
@@ -1807,15 +1816,6 @@ def prep_examples(self):
}
)

def doc_to_text(self, doc) -> str:
raise NotImplementedError

def doc_to_continuations(self, doc) -> List[str]:
raise NotImplementedError

def doc_to_label(self, doc) -> int:
raise NotImplementedError
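
For reference, the behavioral difference between the two prep_examples implementations shown above comes down to whitespace handling of the request fields. A condensed sketch of that delta (the helper name with_newline_preprocess is hypothetical and not in the repository; the real logic sits inline in OEEvalTaskWithNewline.prep_examples):

def with_newline_preprocess(context: str, continuation: str):
    # OEEvalTaskWithNewline strips leading spaces from the continuation string
    if continuation.startswith(' '):
        continuation = continuation.lstrip(' ')
    # and makes sure the context ends with a newline before tokenization.
    # (The diff also calls doc_text.replace('Answer: ', 'Answer:\n') without
    # assigning the result; str.replace returns a new string, so that call
    # has no effect as written.)
    if not context.endswith('\n'):
        context += '\n'
    return context, continuation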


class Vera(ICLMultiChoiceTaskDataset):

