Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[trainer] deepspeed integration #9211

Merged
merged 69 commits into from
Jan 13, 2021
Merged

[trainer] deepspeed integration #9211

merged 69 commits into from
Jan 13, 2021

Conversation

stas00
Copy link
Contributor

@stas00 stas00 commented Dec 19, 2020

This PR adds experimental support for Deepspeed https://github.com/microsoft/deepspeed, whose main feature is ZeRO covered by the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He.

Recently added support for sharded DDP (fairscale) also implements parts of ZeRO. Deepspeed implements all of ZeRO.

I haven't experimented enough yet, but it indeed delivers incredible results.

For example I can get about a 5-8 times bigger batch onto the same hardware as compared to the same code running w/o deepspeed and the speedup is huge too. In the following example I was able to get 4.5x speedup on training, and ~2x on validation/testing:

# baseline
export BS=3; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0  python \
-m torch.distributed.launch --nproc_per_node=2  ./finetune_trainer.py --model_name_or_path \
sshleifer/distill-mbart-en-ro-12-4 --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro \
--do_eval --do_predict --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 \
--learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 \
--num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS \
--predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler --src_lang en_XX --task translation \
--test_max_target_length 128 --tgt_lang ro_RO --val_max_target_length 128 --warmup_steps 500 --n_train 2000 \
--n_val 2000 --n_test 2000 

2020-12-18 22:31:40 | INFO | __main__ |   train_runtime = 144.9132
2020-12-18 22:37:10 | INFO | __main__ |   val_runtime = 329.8146
2020-12-18 22:42:37 | INFO | __main__ |   test_runtime = 326.6212

# deepspeed
export BS=20; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 deepspeed  \
./finetune_trainer.py --model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --output_dir output_dir --adam_eps 1e-06 \
--data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps  --freeze_embeds --label_smoothing 0.1 \
--learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 \
--num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS \
--predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler --src_lang en_XX --task translation \
--test_max_target_length 128 --tgt_lang ro_RO --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 2000 \
--n_test 2000 --deepspeed  ds_config.json 

2020-12-18 22:51:46 | INFO | __main__ |   train_runtime = 32.6825
2020-12-18 22:54:47 | INFO | __main__ |   val_runtime = 180.5917
2020-12-18 22:57:51 | INFO | __main__ |   test_runtime = 183.7731

The bleu eval scores were slightly better than the baseline (~0.5 point higher), but it's not enough to make any conclusions based on a single run.

The cool thing is that deepspeed does everything by itself, even the --fp16 handling, so really it was all about getting out of its way, thus the main part of the integration was to disable a lot of things the trainer does when --deepspeed is enabled.

Note the different invocation pattern. If normally we run distributed as:

python -m torch.distributed.launch --nproc_per_node=2 ./program.py args

deepspeed performs its own DDP internally, and requires the program to be started with:

deepspeed  ./program.py args

The only thing I'm not sure about with this PR is that deepspeed enables all of its features via a json config file, so I'm not sure where to stash a sample one. I guess I will just add it to the documentation. Currently I put one under examples/seq2seq/ds_config.json since that's where the test that needs it lives.

But once this is merged all interested parties can start experimenting with various features, and it won't impact transformers code. They will just need to tweak ds_config.json. And we convert many trainer cl args into DS config on the fly.

There surely will be competition betweeen fairscale and deepspeed integrations. So far from the few experiments I did deepspeed allows for a bigger batch size than fairscale.

To install deepspeed you can just do pip install deepspeed - I'm not sure if all bug fixes are in the release. We can make a request to release a new version when this is merged.

If the build fails I recommend pre-compiling its CUDA extensions (otherwise they get built at run time via PTX) via master:

git clone https://github.com/microsoft/deepspeed
cd deepspeed
DS_BUILD_OPS=1 pip install --no-cache -v --disable-pip-version-check -e . 2>&1 | tee build.log

If you want a faster build add an env var TORCH_CUDA_ARCH_LIST with the cuda compute capabilities you need, e.g. I do:

TORCH_CUDA_ARCH_LIST="6.1;8.6" DS_BUILD_OPS=1 pip install --no-clean --no-cache -v --disable-pip-version-check -e . 2>&1 | tee build.log

It was awesome that @sgugger has just added fairscale support, so it was much easier to do the same for deepspeed seeing how fairscale was integrated, so I'm appreciating the work you have done, Sylvain.


Do try it so we get better testing!

You will need 2+ gpus to use it

First install it:

pip install deepspeed

At the very least do the test:

cd examples/seq2seq
pytest -sv test_finetune_trainer.py -k deepspeed

Or if you want to fiddle with the normal run, here is what I have been using.

cd examples/seq2seq
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export BS=20; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 deepspeed --num_gpus=2 ./finetune_trainer.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 100 --n_val 100 --n_test 100 --deepspeed ds_config.json --fp16 --save_steps 1

Questions that need to be addressed so that all Trainer features continue to work under deepspeed.

  • a notebook with benchmarks was requested

Probably at a later time, my uneven gpu-sized setup doesn't lend to impressive benchmarks - may be someone will send me another rtx-3090 card ;)

@sgugger, @LysandreJik, @patrickvonplaten

args=args,
model=model,
model_parameters=model_parameters,
# optimizer=optimizer,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are those commented-out statements not needed anymore?

Copy link
Contributor Author

@stas00 stas00 Dec 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are options I haven't explored yet, so they are there to see that this can be done.

                # optimizer=optimizer,
                # lr_scheduler=lr_scheduler,
                # training_data=trainset,

The first 2 are there in case DS doesn't have a particular scheduler/optimizer in its toolkit and the user wants to pass their own, but I haven't tried these yet.

Wrt the 3rd one it seems that our trainer already handles the batching, so I'm not sure if we need to delegate this feature to DS or not. I may experiment with it later, but it will require even more interfering.

Unlike fairscale's sharded features, DS has dozens of features, and exploring each is a process on its own. My main intention was to get the bulk off the road and then multiple devs can explore various sub-features.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great PR! Very little changes to existing code for an awesome new feature!
Only nit from my side would be to raise instead of silently disabling fp16 in "deepspeed" mode.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for digging into this @stas00 !
Even if this is an experimental feature, if we start putting into Trainer, people are going to want to all the Trainer features to work with it so I have two questions:

  • how is the optimizer/scheduler creation handled? E.g. how is the fact we don't apply weight decay to some parameters or the proper schedule with the number of training steps handled? I don't see it in the current version of the code since deepspeed is responsible for creating the optimizer and scheduler
  • how is the checkpoint of the optimizer and lr_scheduler handled? And if we need to reload them become we are resuming a previous training, does it work?

Last point is that we should save model.module in self.model when using deepspeed.

examples/seq2seq/seq2seq_trainer.py Outdated Show resolved Hide resolved
src/transformers/trainer.py Outdated Show resolved Hide resolved
src/transformers/trainer.py Outdated Show resolved Hide resolved
src/transformers/training_args.py Outdated Show resolved Hide resolved
@@ -217,6 +217,10 @@ class TrainingArguments:
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
deepspeed (:obj:`bool`, `optional`, defaults to :obj:`False`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature.
deepspeed_config (:obj:`str`, `optional`):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
deepspeed_config (:obj:`str`, `optional`):
deepspeed_config_file (:obj:`str`, `optional`):

This makes it clearer to me we should pass along a file and not a config object (which the library has plenty of).

Copy link
Contributor Author

@stas00 stas00 Dec 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. deepspeed.initialize expects to find args.deepspeed_config so if we follow your suggestion we will have to rewrite that key before passing args to deepspeed.initialize.

  2. As I mentioned elsewhere I think it'd be sufficient to just have a single argument deepspeed and have its value to be the config file and then re-assign it to args.deepspeed_config before deepspeed.initialize .

But either your suggestion or mine will break the deepspeed convention of here is how one runs deepspeed:

deepspeed myprog myargs --deepspeed --deepspeed_config ds_config.js

so it'd be slightly confusing to users.

Copy link
Collaborator

@sgugger sgugger Dec 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see if we keep the arguments as is or if we re-wrap them for deepspeed. I like having only one arg that is deepspeed_config_file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like args.local_rank is needed too, please have a look at how I made it clear what's being passed to deepspeed:
9cc3b63

Copy link
Contributor Author

@stas00 stas00 Dec 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So since I'm rewrapping them anyway - it's your call now - I can generate the 2 vars on the fly based on just deepspeed_config_file you suggested - though it is kind of an odd name for double function. From user's point of view If I pass a config file, does it also activate the feature? I suppose this is why they had 2 vars. Not sure.

@patrickvonplaten, what do you think?

See my comment here #9211 (comment) for a quick ramp up on what we are talking about. And then we are considering to collapse these 2 into a single cl arg that will provide both - the config file and at the same time activate deepspeed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also asked for suggestions at microsoft/DeepSpeed#616

Copy link
Contributor Author

@stas00 stas00 Dec 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I got feedback and no, there is no need to use both cl args, in fact, --deepspeed is not needed at all as long as we call deepspeed.initialize.

So let's use a single cl arg.

Let's just decide how to best name it.

So the proposals so far:

  1. I propose --deepspeed ds_config.js. While it's less obvious that it expects a file name argument, it's unambiguous about it activating deepspeed.
  2. We could make the value even optional and default to ds_config.js, so most of the time it'd be just --deepspeed
  3. Your proposal @sgugger was --deepspeed_config_file but that was in combination with --deepspeed

What are your thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, until everybody is back I changed it to just one cl arg: --deepspeed ds_config.js. If you prefer a different name please let me know - should be easy to rename.

src/transformers/trainer.py Outdated Show resolved Hide resolved
src/transformers/training_args.py Outdated Show resolved Hide resolved
@jeffra
Copy link
Contributor

jeffra commented Dec 22, 2020

Thanks @stas00 for putting this together! I think there might be a few things we can do on the deepspeed side to smooth a few pain points out. Most of us our out of office until the new year but will definitely be taking a close look at this soon and help where we can.

@stas00
Copy link
Contributor Author

stas00 commented Dec 22, 2020

Thank you, @jeffra!

New year sounds perfect as a time for you to make suggestions if any, but meanwhile I think it's coming along nicely.

@stas00
Copy link
Contributor Author

stas00 commented Dec 23, 2020

OK, so this PR also introduces a concept of self.wrapped_model so that we have less confusion about which is which in the trainer and user code.

  • self.model is always transformers model - this is the internal model
  • self.wrapped_model is a wrapped model - always the most external model which could be DDP(Transformers Model), DDP(Deepspeed(Transformers Model)), etc.

It's not documented yet, but @sgugger when you get a chance could you please check that it looks correct what I did here 1510444

Questions:

  1. is it correct that I set it to None if there is no wrapped model?
  2. would it be better to call it model_wrapped - so the two better align side by side during debug or IDE prediction completion engines?
  3. I'm not sure where to document this? And we should probably add a public API accessor?
  4. We can now probably remove and refactor the following code, as we now have a simpler way to get the internal model -
    def _actual_model(
    model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
    ) -> torch.nn.modules.Module:
    """
    Args:
    model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
    Model object used during training
    Returns:
    :obj:`torch.nn.modules.Module`: unwrapped module
    """
    if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
    model = model.module
    else:
    model = model
    return model

    or is it used in some deep code where there is no trainer object? In which case this code won't work as it needs double unwrapping under deepspeed - model.module.module (since we have DDP too).

I see it's used only in floating_point_ops which does have access to self so I'm not sure why it was needed in first place. Also in 2 tests, but that could be moved into the tests if need be.

If we want a general unwrap function it needs to do it recursively until there is no more .module.

@sgugger
Copy link
Collaborator

sgugger commented Dec 23, 2020

Let me think.
For 1, I think the wrapped_model should be the model, in this case, just to avoid the inconvenience of testing if None.
For 2, I have no strong opinion, so you can pick the version you prefer.
For 3, none of the attributes of the Trainer are properly documented yet. This could be added in the main docstring
For 4, yes, absolutely. This was a quick fix that was merged when I didn't get much time to do a nice solution, I though I had removed all uses of that function.

Could you add the wrapped_model or model_wrapped in a separate PR? This would be easier to follow and not highjack the discussion on the deepspeed integration. We can rebase this when that PR is merged.

@stas00
Copy link
Contributor Author

stas00 commented Dec 23, 2020

For 1, I think the wraped_model should be the model, in this case, just to avoid the inconvenience of testing if None.

Then we somewhat lose information - None is telling us that notihng is wrapping the model. But I suppose we could achieve the same by self.model == self.wrapped_mode - OK, that works!

Thank you for the rest of the answers, @sgugger. Will integrate those and make a separate PR with wrapped_model.

@g-karthik
Copy link

@stas00 Should it be mentioned in some README/documentation that folks can only use DeepSpeed with the PyTorch Trainer and not the TF Trainer? There's a hard dependency of using torch.distributed with the NCCL backend to use DeepSpeed.

Secondly, what is the plan in terms of introducing DeepSpeed in the transformers setup.py and the PyTorch GPU Dockerfile?

While DeepSpeed has a pip installable PyPI package, IIRC it is highly recommended that it be installed from source. Also, in order to use certain features in DeepSpeed such as 1-bit Adam, there are certain special installations to be done that do not come with the PyPI package. Will this PR support every underlying DeepSpeed feature? If not, can the scope of the initial DeepSpeed integration be defined clearly in some README/documentation, while allowing for further iterations in future to enable the utilization of more DeepSpeed features with the transformers Trainer?

@stas00
Copy link
Contributor Author

stas00 commented Dec 24, 2020

@stas00 Should it be mentioned in some README/documentation that folks can only use DeepSpeed with the PyTorch Trainer and not the TF Trainer? There's a hard dependency of using torch.distributed with the NCCL backend to use DeepSpeed.

Yes, we should definitely be clear about that. thank you!

At the moment the idea is to put all the ZeRO related docs here: #9208 (that PR covers fairscale at the moment)

Secondly, what is the plan in terms of introducing DeepSpeed in the transformers setup.py

It'll be up to users to install deepspeed, just like it's the case with fairscale or any other libraries the transformers core doesn't require. Currently if you use --deepspeed and you don't have it installed the trainer will assert with a suggestion to install that library.

and the PyTorch GPU Dockerfile?

I have no idea. I don't see any reason why it can't be included.

Let's do it in baby steps. First, make the support available, test it out, solve initial issues if any. Then worry about everything else?

While DeepSpeed has a pip installable PyPI package, IIRC it is highly recommended that it be installed from source. Also, in order to use certain features in DeepSpeed such as 1-bit Adam, there are certain special installations to be done that do not come with the PyPI package. Will this PR support every underlying DeepSpeed feature? If not, can the scope of the initial DeepSpeed integration be defined clearly in some README/documentation, while allowing for further iterations in future to enable the utilization of more DeepSpeed features with the transformers Trainer?

As I have shown in the example of the upcoming fairscale-support doc PR (we are waiting for fairscale to make a new pypi release before we merge it), we will document the same for DeepSpeed and address your questions. Your comments would be super-helpful for that document, so please save them for when we get to write that document. With your permission I can tag you on that future PR. Thank you.

Wrt to specifics let's see what ends up working out of box and what needs to be polished. I think the main issues will be bugs on the DS side. Otherwise there is a ton of features and I have only been testing a few.

If you feel inspired and are already experienced with DS it'd be awesome if you made a checklist of features and then between you and I, and anybody else who wants to contribute, test those features and check what's supported and report back to DS what is not. Since DS does most of the things on its own, I don't think there will be much to change in transformers Trainer once this PR is polished. I can be wrong of course.

edit: actually there is no point waiting - I started adding notes into docs/source/training.rst in this PR. So already added a few of your comments - will need to expand those later.

@stas00
Copy link
Contributor Author

stas00 commented Jan 8, 2021

deepspeed-0.3.10 has been just released by @jeffra on pypi - I verified that it works - so we are ready to merge this whenever you're happy with it.

it'd be great if you tried running it too, since I think it has been only me running it, so my work is only as good as my environment is and I may not know of other culprits - e.g. I can't test with pytorch < pt-nightly since my card doesn't work with those pytorch versions.

You will need 2+ gpus to use it

First install it:

pip install deepspeed

At the very least do the test:

cd examples/seq2seq
pytest -sv test_finetune_trainer.py -k deepspeed

Or if you want to fiddle with the normal run, here is what I have been using.

cd examples/seq2seq
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export BS=20; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 deepspeed --num_gpus=2 ./finetune_trainer.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 100 --n_val 100 --n_test 100 --deepspeed ds_config.json --fp16 --save_steps 1

@exelents exelents mentioned this pull request Jan 9, 2021
@stas00
Copy link
Contributor Author

stas00 commented Jan 10, 2021

@sgugger,

  1. While working on docs I discovered that DS does its own gradient clipping (that doc was buried and I didn't see it) so I had to undo the code in the trainer that did that on behalf of DS - just skipping it
  2. I did a major rewrite/expansion of the docs (including the fairscale section) - so please kindly have a look. It's mainly mirroring the config logic in the integration code.
  3. In the docs I used consistently Trainer (upcase) to refer to HF trainer. I know you didn't like it when I did that for Issue in a different PR, let me know if you prefer it to be a lowercase trainer.

While this PR is perfectly ready for a final review, I need to wait for microsoft/DeepSpeed#656 to be answered before we can merge this as I'm unsure about their defaults for gradient clipping.

Thank you.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went through the documentation and left comments. On the optimizer side, it doesn't seem like DeepSpeed supports AdamW from what you're saying, so we should document the default optimizer is changed at the very beginning of the DeepSpeed section. It does change drastically the value of weight_decay to use.

docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Show resolved Hide resolved
@stas00
Copy link
Contributor Author

stas00 commented Jan 11, 2021

Went through the documentation and left comments.

Awesome - thank you - all integrated.

On the optimizer side, it doesn't seem like DeepSpeed supports AdamW from what you're saying, so we should document the default optimizer is changed at the very beginning of the DeepSpeed section. It does change drastically the value of weight_decay to use.

I found a way to use AdamW, thank you for catching that, @sgugger. I documented the nuances.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Thanks for your work on this @stas00!

docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved

**Optimizer:**

DeepSpeed has several tested with ZeRO optimizers, which are Adam, OneBitAdam, and Lamb. It, however, can import other
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the first sentence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means that "It has tested only these optimizers to properly work with ZeRO". I will rewrite it not to use passive and it will be then straightforward.

Copy link
Contributor Author

@stas00 stas00 Jan 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote it as:

DeepSpeed's main optimizers are Adam, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are thus recommended to be used. It, however, can import other optimizers from torch.

Please let me know if it's still unclear.

docs/source/main_classes/trainer.rst Outdated Show resolved Hide resolved
stas00 and others added 2 commits January 12, 2021 09:43
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@stas00
Copy link
Contributor Author

stas00 commented Jan 13, 2021

I think the DeepSpeed team is on vacation, as there is no response since several days. And since I have no way of talking to anyone there, I have no way of knowing when they will be back. So I will go ahead and merge this so that others can start experimenting and then we can fix whatever needs to be fixed when I get the gradient clipping Issue answered.

@stas00 stas00 merged commit 2df34f4 into huggingface:master Jan 13, 2021
@stas00 stas00 deleted the ds branch January 13, 2021 03:05
@Narsil Narsil mentioned this pull request Jan 13, 2021
3 tasks
@patrickvonplaten
Copy link
Contributor

Amazing work @stas00 !

guyrosin pushed a commit to guyrosin/transformers that referenced this pull request Jan 15, 2021
* deepspeed integration

* style

* add test

* ds wants to do its own backward

* fp16 assert

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* for clarity extract what args are being passed to deepspeed

* introduce the concept of self.wrapped_model

* s/self.wrapped_model/self.model_wrapped/

* complete transition to self.wrapped_model / self.model

* fix

* doc

* give ds its own init

* add custom overrides, handle bs correctly

* fix test

* clean up model_init logic, fix small bug

* complete fix

* collapse --deepspeed_config into --deepspeed

* style

* start adding doc notes

* style

* implement hf2ds optimizer and scheduler configuration remapping

* oops

* call get_num_training_steps absolutely when needed

* workaround broken auto-formatter

* deepspeed_config arg is no longer needed - fixed in deepspeed master

* use hf's fp16 args in config

* clean

* start on the docs

* rebase cleanup

* finish up --fp16

* clarify the supported stages

* big refactor thanks to discovering deepspeed.init_distributed

* cleanup

* revert fp16 part

* add checkpoint-support

* more init ds into integrations

* extend docs

* cleanup

* unfix docs

* clean up old code

* imports

* move docs

* fix logic

* make it clear which file it's referring to

* document nodes/gpus

* style

* wrong format

* style

* deepspeed handles gradient clipping

* easier to read

* major doc rewrite

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* docs

* switch to AdamW optimizer

* style

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* clarify doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants