igorsterner authored Aug 3, 2024
1 parent a164db1 commit 831b184
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions wtpsplit/train/train_SM.py
@@ -22,10 +22,10 @@
@dataclass
class Args:
block_size: int = 256
num_layers: int = 12
lim_lookahead: bool = False
without_pretraining: bool = False
no_sm_corruption: bool = False
num_layers: int = 12 # number of layers
lim_lookahead: bool = False # our "Lookahead" ablation
without_pretraining: bool = False # our "No pre-training" ablation
no_sm_corruption: bool = False # our "Only clean text" ablation

# Parsing command line arguments or JSON config files as needed
parser = HfArgumentParser([Args, TrainingArguments])
@@ -46,19 +46,20 @@ class Args:
punct_chars = set(Constants.PUNCTUATION_CHARS)


for lang_code in tqdm(all_data, desc="Loading train/dev data"):
for lang_code in tqdm(all_data, desc="Loading data"):
if "-" in lang_code or "_" in lang_code:
# we only train on monolingual data in SM, so no "en-de" code-switching for example!
pass
elif (
"ud" in all_data[lang_code]["sentence"]
and all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] is not None
):
train_data = all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"]
# cf. Appendix A.2

if len(train_data) < 10000:
train_data = train_data * (10000 // len(train_data) + 1)

if len(train_data) < 5000:
# some languages have an insufficient number of sentences to fill a single batch
# this is just a quick way to upsample these so we don't run into problems later
# later we will use a uniform round-robin sampler for all languages
train_data = train_data * (10000 // len(train_data) + 1)

train_sentences[lang_code]["uncorrupted"].extend(train_data)
@@ -67,13 +67,19 @@ class Args:
train_data = all_data[lang_code]["sentence"]["ud-corrupted-asr"]["meta"]["train_data"]

if len(train_data) < 5000:
# some languages have an insufficient number of sentences to fill a single batch
# this is just a quick way to upsample these so we don't run into problems later
# later we will use a uniform round-robin sampler for all languages
train_data = train_data * (10000 // len(train_data) + 1)

train_sentences[lang_code]["corrupted-asr"].extend(train_data)

train_data = all_data[lang_code]["sentence"]["ud-corrupted-social-media"]["meta"]["train_data"]

if len(train_data) < 5000:
# some languages have an insufficient number of sentences to fill a single batch
# this is just a quick way to upsample these so we don't run into problems later
# later we will use a uniform round-robin sampler for all languages
train_data = train_data * (10000 // len(train_data) + 1)

train_sentences[lang_code]["corrupted-social-media"].extend(train_data)
@@ -83,29 +83,23 @@ class Args:
and all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"] is not None
):
train_data = all_data[lang_code]["sentence"]["opus100"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["uncorrupted"].extend(train_data)

if not args.no_sm_corruption:
train_data = all_data[lang_code]["sentence"]["opus100-corrupted-asr"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["corrupted-asr"].extend(train_data)

train_data = all_data[lang_code]["sentence"]["opus100-corrupted-social-media"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["corrupted-social-media"].extend(train_data)
else:
train_data = all_data[lang_code]["sentence"]["nllb"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["uncorrupted"].extend(train_data)

if not args.no_sm_corruption:
train_data = all_data[lang_code]["sentence"]["nllb-corrupted-asr"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["corrupted-asr"].extend(train_data)

train_data = all_data[lang_code]["sentence"]["nllb-corrupted-social-media"]["meta"]["train_data"]
assert len(train_data) == 10000
train_sentences[lang_code]["corrupted-social-media"].extend(train_data)

for dataset in all_data[lang_code]["sentence"]:
@@ -145,7 +145,6 @@ class Args:
model_checkpoint = "segment-any-text/sat-12l-no-limited-lookahead"
else:
model_checkpoint = "segment-any-text/sat-12l"

else:
raise ValueError("Invalid number of layers. Valid values are 1, 3, 6, 9, 12.")

@@ -155,6 +155,7 @@ class Args:
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

if args.num_layers == 3 and args.without_pretraining:
# special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers
model = SubwordXLMForTokenClassification.from_pretrained(
model_checkpoint,
num_labels=1,
@@ -299,8 +300,6 @@ def pack_sentences(input_data_dict, block_size):

experiment_name = model_checkpoint.split("/")[-1]

# experiment_name += str(args.num_layers) + "L"

if args.no_sm_corruption:
experiment_name += "-no-corruption"

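For reference, here is a minimal standalone sketch of the upsampling step this commit introduces (threshold of 5000 sentences, target of roughly 10000). The upsample helper and the example sizes below are illustrative assumptions, not part of the commit; the script applies the same expression inline to each dataset inside the loading loop above.

# Hypothetical standalone version of the upsampling expression used in train_SM.py
def upsample(train_data: list, min_size: int = 5000, target: int = 10000) -> list:
    # Repeat a small sentence list so that later batching and the uniform
    # round-robin sampler never see a language with too few examples.
    if len(train_data) < min_size:
        # e.g. 1200 sentences are repeated 10000 // 1200 + 1 = 9 times, giving 10800
        train_data = train_data * (target // len(train_data) + 1)
    return train_data

if __name__ == "__main__":
    tiny = [f"sentence {i}" for i in range(1200)]
    print(len(upsample(tiny)))  # 10800

As the added comments note, this duplication only ensures each language can fill a batch; sampling across languages is handled later by a uniform round-robin sampler.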
