diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index fe9f8d128..7e7d893bd 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 import transformers
 from datasets.arrow_dataset import Dataset
+from datasets.iterable_dataset import IterableDataset
 from datasets.load import load_dataset
 from huggingface_hub import hf_hub_download
 from jaxtyping import Float, Int
@@ -265,13 +266,15 @@ def keep_single_column(dataset: Dataset, col_name: str):
 
 
 def tokenize_and_concatenate(
-    dataset: Dataset,
+    dataset: Union[Dataset, IterableDataset],
     tokenizer: AutoTokenizer,
     streaming: bool = False,
     max_length: int = 1024,
     column_name: str = "text",
     add_bos_token: bool = True,
     num_proc: int = 10,
+    remove_pad_tokens: bool = True,
+    set_format: bool = True,
 ) -> Dataset:
     """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.
@@ -284,6 +287,10 @@ def tokenize_and_concatenate(
         max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
         column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
         add_bos_token (bool, optional): . Defaults to True.
+        num_proc (int, optional): The number of processes to use for parallel tokenization. Defaults
+            to 10.
+        remove_pad_tokens (bool, optional): Whether to remove the padding tokens. Defaults to True.
+        set_format (bool, optional): Whether to set the format of the dataset to torch and remove a column. Defaults to True.
 
     Returns:
         Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
@@ -310,8 +317,9 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
         chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
         # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
         tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
-        # Drop padding tokens
-        tokens = tokens[tokens != tokenizer.pad_token_id]
+        if remove_pad_tokens:
+            # Drop padding tokens
+            tokens = tokens[tokens != tokenizer.pad_token_id]
         num_tokens = len(tokens)
         num_batches = num_tokens // (seq_len)
         # Drop the final tokens if not enough to make a full sequence
@@ -327,10 +335,23 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
     tokenized_dataset = dataset.map(
         tokenize_function,
         batched=True,
-        num_proc=(num_proc if not streaming else None),
         remove_columns=[column_name],
+        # Don't even pass the num_proc argument if we're streaming
+        **({"num_proc": num_proc} if not streaming else {}),
     )
-    tokenized_dataset.set_format(type="torch", columns=["tokens"])
+
+    if set_format:
+        # This cleans up the dataset, removing the column name and setting the format to torch
+        # Doesn't work for all datasets (eg when streaming)
+        # Creating a generator which will be lazily loaded, e.g.
+        # ```
+        # formatted_tokenized_dataset = (torch.LongTensor(example['tokens']) for example in
+        #     tokenized_dataset)
+        # ```
+        # may work for your use case
+
+        tokenized_dataset.set_format(type="torch", columns=["tokens"])
+
     return tokenized_dataset
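
A minimal usage sketch of the new arguments (not part of the patch): the dataset name `stas/openwebtext-10k` and the `gpt2` tokenizer are placeholders, and `set_format=False` mirrors the streaming caveat in the comment above.

```python
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from transformer_lens.utils import tokenize_and_concatenate

# Hypothetical example: any HuggingFace text dataset and tokenizer would do here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("stas/openwebtext-10k", split="train", streaming=True)

# With streaming=True, num_proc is no longer forwarded to dataset.map;
# set_format=False skips the torch formatting step, which (per the patch
# comment) does not work for all datasets, e.g. when streaming.
tokens_dataset = tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming=True,
    max_length=1024,
    set_format=False,
)

# Lazily convert each example to a tensor, as suggested in the patch comment.
token_batches = (torch.LongTensor(example["tokens"]) for example in tokens_dataset)
```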