
[T5 model parallel] implement input auto-relocation + lots of refactoring/cleanup #9323

Closed
wants to merge 7 commits

Conversation

stas00
Contributor

@stas00 stas00 commented Dec 28, 2020

As I commented in #9316 (another incarnation of generalizing t5 model parallelism so that it can be easily ported to other models), I realized that it's quite unnecessary to try to remap inputs ahead of time to the specific devices where they will be needed later. Since inside forward we have access to the device of that layer's parameters, we can completely automate the relocation of inputs to the correct devices just before forward is called. So this PR builds upon #9316 and:

  • creates a @model_parallel_inputs_to_device decorator for forward, which automatically takes any inputs and puts them on the same device as that layer's parameters (a rough sketch of such a decorator is shown after this list). This allowed a complete removal of most of the .to() juggling logic for inputs, which was quite complex and noisy.
  • does a lot of refactoring to make MP as minimally invasive and noisy as possible, fixing some small issues along the way.
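Roughly, the decorator does something like this (a simplified sketch of the idea, not the exact code in the diff):

import functools

import torch


def model_parallel_inputs_to_device(func):
    # move any tensor args/kwargs of forward to the device of this module's parameters
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            # any parameter of this layer tells us which device the layer lives on
            device = next(self.parameters(recurse=True)).device
        except StopIteration:
            # the layer has no parameters of its own - nothing to relocate
            return func(self, *args, **kwargs)

        args = tuple(a.to(device) if torch.is_tensor(a) else a for a in args)
        kwargs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in kwargs.items()}
        return func(self, *args, **kwargs)

    return wrapper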

I have tested this with:

pyt -sv tests/test_modeling_t5.py -k parallel

I'm not sure that covers all bases, but the above tests pass.

@alexorona, please let me know what you think. And if you have real applications besides the great tests you wrote please see if it still works correctly. (It was so awesome having those tests in place! Thank you!) If it looks good and others support this proposal we can then look at doing the same for gpt2 and meanwhile I will look at bart.

@patrickvonplaten, @LysandreJik

return device_map


def model_parallel_inputs_to_device(func):
Contributor Author

@stas00 stas00 Dec 28, 2020


This forward decorator is the key element of this PR.

@patrickvonplaten
Contributor

I don't have an in-depth knowledge of our model parallelism features, so it would be great if @LysandreJik can take a look here as well.

I think in general, I'm in favor of this PR. However, I'm not sure if a function decorator is better than just having two lines of

if self.is_parallel: 
       # call map to device function

in the respective forward function. We've decided against using function decorators in Pytorch at multiple points (gradient checkpointing e.g.), so I'm not convinced it's the better option to do it here. Function decorators do reduce code readability quite a lot IMO.

@stas00
Contributor Author

stas00 commented Dec 28, 2020

I'm not sure how your suggestion would work since it needs to be generic, and once inside forward the function args are no longer generic. Remember, I'm trying to build a generic functionality that can work out of the box in any transformers model and not specific to t5.

The other approach, which doesn't need a decorator, is to have self.parallelize override self.__call__ with a variation of this wrapper.

    def parallelize(self, device_map=None):
        self.__call__ = model_parallel__call__
        [...]

    def deparallelize(self):
        self.__call__ = nn.Module.__call__
        [...]

and:

import torch
from torch import nn


def model_parallel__call__(self, *input, **kwargs):

    # get the device of any of the params of this layer
    try:
        device = next(self.parameters(recurse=True)).device
    except StopIteration:
        device = None

    if device is not None:

        # move tensor positional args to this layer's device
        input = list(input)
        for i, v in enumerate(input):
            if v is not None and torch.is_tensor(v):
                input[i] = v.to(device)
        input = tuple(input)

        # move tensor keyword args to this layer's device
        for k in kwargs.keys():
            if kwargs[k] is not None and torch.is_tensor(kwargs[k]):
                kwargs[k] = kwargs[k].to(device)

    return nn.Module.__call__(self, *input, **kwargs)

(or could save the original self.__call__ to be more flexible and to allow for others to override this too)

This is in fact even better, since it will have zero impact on non-MP functionality, as this wrapper will be called only under MP.

@@ -773,41 +784,34 @@ def __init__(self, config, embed_tokens=None):
self.dropout = nn.Dropout(config.dropout_rate)

self.init_weights()
# Model parallel

Contributor


Wouldn't something like:

if self.model_parallel:
        hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value = _call__mp(hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value)

directly under the function signature work?

But I agree that it's a lot of boilerplate code ... -> maybe the best way is indeed to use a decorator function here.
I think in general I'm not deep enough into the model parallelism to have a good view here. Let's see what @LysandreJik thinks of the design :-).

Contributor Author

@stas00 stas00 Dec 29, 2020


Oh, I see - you are suggesting reverting to the original way, where each input was manually switched to its destination device, albeit at the point of use.

There is a definite benefit to your suggestion: it's explicit, rather than a magical, behind-the-scenes switch. My initial reaction was 'meh', but sitting with it some more, it fits well with all the other .to() switches elsewhere in the MP code, since the rest aren't magical - otherwise a reader of the code might be puzzled as to why these inputs aren't being switched.

The more I think about it, the more I suspect a future pytorch might find a way to not need manually moving inputs to devices, and not just with MP. It should figure it all out based on the params, which already know which devices they are on. And then it will all be magical.

maybe the best way is indeed to use a decorator function here.

In case it wasn't clear, since you decided against a decorator, my follow-up didn't suggest one either, but an even more efficient way, which would only impact MP code.

Let's see what @LysandreJik thinks.

Member


Thanks for waiting for my input. I would prefer to go down the explicit road rather than the magical one.

Patrick's proposition seems like the most understandable to me, while keeping all the features you proposed @stas00.

The _call__mp name should probably be made more explicit, however, as right now it can't be understood from its name.

Contributor Author


explicit calls it is then, works for me.

wrt _call_mp - yeah, it has already been renamed in the Bart goes MP PR - there are a lot of moving parts. And a lot of things will be renamed at the end - these are just all temp names anyway.

@alexorona
Contributor

alexorona commented Dec 30, 2020

This is great progress, @stas00! From my perspective, to create a general way of doing model parallelism, we need four things:

  • a format for device_map that can be used on any model
  • device_map and model_parallel need to be attributes on all models, probably by assigning them to PreTrainedModel
  • parallelize() and deparallelize() should be on all models, again probably by assigning them to PreTrainedModel
  • changes to the forward methods need to be abstracted if at all possible (this is by far the most challenging)

This PR makes a lot of progress, the strongest part of which is a potential abstraction/simplification of the changes to the forward method. Not sure if a decorator is the solution. @LysandreJik will have that insight when he's back. I like @patrickvonplaten's suggestion of a two-line if self.model_parallel implementation instead of a decorator. But the BIG thing is whether most or all of the code in the forward method can be replaced with something like:

if self.model_parallel:
        hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value = _call__mp(hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value)

If we can get that right, it might turn model parallelism from a day or weekend project per model into something that takes a few minutes. Much more scalable and sustainable.

Supporting non-sequential GPUs could be more trouble than it's worth -- not entirely sure on this, it's just my instinct. With the billion+ parameter models that we're dealing with -- and all indications are that they're only getting bigger going forward -- it's pretty fair to say that most workflows in enterprise and research will be:

  1. develop locally on a machine with one or maybe two GPUs on a small sized version of a model, and

  2. train a final model on a cloud instance or cluster with multiple identical GPUs.

Sequential hand-offs between GPUs will be the norm in cases like that, which I think are going to be most of them.

The other thing I worry about is a challenge with PyTorch 1.5 and 1.6 model parallelism behavior. The seemingly redundant clauses and set_device statements are there to prevent PyTorch's inferential logic from moving modules or inputs around after .to() assignments have been called. It's very annoying. I don't know if it's fixed in 1.7. You'll notice that output layers like the lm_head are always on the first device instead of the last device. A more logical workflow would have the embedding layers on the first device and the output layers on the last device.

I got that to work just fine in forward passes, but I must've tried 10 different ways to get it to behave in backprop before conceding that for whatever reason PyTorch's quantum device superposition just wouldn't allow it. So the output layers are on the same device as the embedding layers. You'd think that matters for load balance between GPUs, and it does -- for gpt2-xl. But since we're practically limited in most situations to PyTorch's 8 GPU per machine preference (inherited from CUDA), by the time you're at 3 billion parameters the embedding and lm_head layers are so small in comparison to the attention blocks that it doesn't matter that they're both on the first device, and a custom device_map solves the problem for cases where that matters. The implementation implies that there is an extra hand-off or two of a large tensor between GPUs, but I don't think having a perfectly optimal setup will save even 10% on training time. Happy to be proven wrong on this though. What will save a TON of time and $$$ though is deepspeed integration.

I got t5-11b with 1024 tokens to train quickly on the new p4 instance AWS released last month with its epic 320 GB of GPU memory so I was like "ok fine whatever... that's pretty good".

@stas00
Contributor Author

stas00 commented Dec 30, 2020

That's awesome, @alexorona. Do continue to share your insights from the frontier!

Let's wait for @LysandreJik to come back to plan ahead and meanwhile I will experiment with Bart.

The seemingly redundant clauses and set_device statements are there to prevent PyTorch's inferential logic from moving modules or inputs around after .to() assignments have been called. It's very annoying. I don't know if it's fixed in 1.7.

Oh, so glad you flagged that. Would it be enough to run the existing parallel tests with pt-1.5, pt-1.6 to detect these failures?

I'm developing on pt-nightly since rtx-30* work only there (well 1.7.1 should be usable too, but mainly waiting for cuda-11.2 support in pytorch, which is again pt-nightly - won't be in 1.7.x). But it means I can't use it with older pt versions.

But since we have to support pt-1.4, I will put the set_device calls back as you had them originally. But this time let's add specific comments explaining why they are there, otherwise someone like myself will think they are left-overs from earlier experimentation and sweep them away.

Actually, I think we should have a design document where we explain why this or that is done. Rather than make a lot of noise in the model files. A developer-oriented doc.

The set_device was just one thing, right? Or have I naively nuked any other essentials?

Thanks again!

@stas00
Contributor Author

stas00 commented Dec 30, 2020

@alexorona, one more question. If pt-1.7+ removes the need for jumping through hoops (as you're suggesting, older versions have all kinds of issues), perhaps it'd be a reasonable approach to make MP in transformers require pt-1.7?

If and when you get some time could you please test if what wasn't working in pt < 1.7 works in pt-1.7? And if not - perhaps we need to file some Issues with pytorch if there are bugs to be solved.

Thank you.

@alexorona
Contributor

@stas00 Will try to do so, but I'm in the middle of moving, so I don't think I'll get to this until the end of January at the soonest. The team would have to make the call about only supporting model parallelism for PyTorch >= 1.7.0 if it won't work on earlier versions. I would be very tempted to support that idea, but don't have enough usage information to know what the impact would be.

@stas00
Contributor Author

stas00 commented Jan 2, 2021

I guess once everybody is back next week we can start having some discussion with the HF team.

Have an easy move!

@alexorona
Contributor

@stas00 Yeah, should be able to get some input when everyone is back. In the meantime, I'm still not sure on the final form of the device_map. There are two issues left to work out:

  1. Some models don't have decoder architectures
  2. No ability to map embeddings and output layers (always on the first device), which might be just fine. I think most output layers and embeddings are going to be comparatively small relative to the attention blocks going forward, but we should confirm that. We are allowing people to create a custom device_map, which should enable them to get around any potential situations where the first device becomes overloaded.

To confirm, this looks good for decoder architectures:

device_map = {
	'encoder': {
			0: [0, 1, 2, 3, 4, 5],
			1: [6, 7, 8, 9, 10, 11]
			},
	'decoder': {
			2: [0, 1, 2],
			3: [3, 4, 5]
			}
}

Maybe we use the keys to map to the attribute? In gpt2, self.h contains the attention blocks, so:

device_map = {
	'h': {
		0: [0, 1, 2, 3, 4, 5],
		1: [6, 7, 8, 9, 10, 11]
		}
}

In trying to generalize parallelize(), we still need access to the list of all modules. For example, in GPT2LMHeadModel, we would need to know: self.lm_head, self.transformer.h, self.transformer.wte, self.transformer.wpe and self.transformer.ln_f.

@stas00
Contributor Author

stas00 commented Jan 2, 2021

I haven't looked into gpt2, yet. t5 and bart are very similar structure-wise. We probably need to map out all the different archs transformers has and then generalize.

What is emerging so far is that the device map might have various keys, none required, and each model architecture will have:

  1. its required keys
  2. its own default map generator - so that the user doesn't have to provide one, and over time it can be improved to have the smarts to create a balanced map based on "insider" information.

So if some architectures need to explicitly manage the mapping of non-block layers, because those are significantly big, rather than just assigning them by default to the "main_device", they could do that too. Otherwise, leave the main_device to all the "smallish fish" and use the other devices for "the big fish", if that makes sense. The main advantage of this "lazy" approach is that there is less device-hopping and less code needed to match the hopping.
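To illustrate item 2 above with a rough sketch (the real generator will need more smarts, e.g. putting fewer blocks on the device that also hosts the embeddings, and it may end up looking quite different):

def make_default_device_map(n_layers, devices):
    # spread n_layers blocks as evenly as possible over the given gpu ids, e.g.
    # make_default_device_map(12, [0, 1]) -> {0: [0, 1, 2, 3, 4, 5], 1: [6, 7, 8, 9, 10, 11]}
    per_device = (n_layers + len(devices) - 1) // len(devices)  # ceil division
    return {
        device: list(range(i * per_device, min((i + 1) * per_device, n_layers)))
        for i, device in enumerate(devices)
    }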

@alexorona
Contributor

alexorona commented Jan 2, 2021

Yes, that's right.

So it turns out the self._modules attribute has all of the modules. To move parallelize() to PreTrainedModel, I think all we need is a per-model module_map object to map between the device_map and the model placements. With a little work, we might be able to reduce making a model parallel to:

  1. Adding a few lines of code in the forward method per your work
  2. Modifying the validation function to check for errors in a custom device_map
  3. Creating a module_map dictionary for that model and adding it to the get_module_map() function

We can embed special placement rules where non-attention block modules need to be on the same device as another module by creating a tuple in module_map['dependent_modules']:

# Device map for GPT2LMHead. T5 would have 'encoder', 'decoder' as keys instead of 'h' and validate_device_map would
# check to see if the device_map has the right keys.
device_map = {
	'h': {
			0: [0, 1, 3, 4],
			1: [5, 6, 7, 8],
			2: [9, 10, 11, 12]
	}
}


class PreTrainedModel():
...
	# Probably use get_module_map(), but just to make it simple:
	self.module_map = {
		'h': self._modules['transformer'].h,
		'embeddings': [
					self._modules['transformer'].wte,
					self._modules['transformer'].wpe
					],
		'dependent_modules': [
					(
						self._modules['transformer'].ln_f,
						self._modules['transformer'].h[-1],
					),
					(
						self._modules['lm_head'],
						self._modules['transformer'].wte
					)
				]
	}

def parallelize(self, device_map = None):
	
	self.device_map = device_map

	# validate_device_map extended to check for valid keys for model
	
	...

	# Set all embeddings to first device
	if 'embeddings' in self.module_map:
		for layer in self.module_map['embeddings']:
			layer.to(self.first_device)


	# Assign attention blocks to the appropriate device.
	for module_group, group_map in self.device_map.items():
	    for device, layers in group_map.items():
	        for layer in layers:
	            self.module_map[module_group][layer].block_parallelize(f"cuda:{device}")


	# Some modules should always be on the same device as another module. We can express 
	# this as a tuple pair where tuple[0] needs to be on tuple[1]
	if 'dependent_modules' in self.module_map:
		 for i in self.module_map['dependent_modules']:
		 	i[0].to(i[1].device)


def block_deparallelize(self):
    self.to("cpu")
    self.model_parallel = False
Contributor


Do we need model_parallel as an attribute on the module itself?

Contributor Author


Yes, if we use the automatic input remapper, since otherwise there is no way to do if self.model_parallel. But surely we will eventually remove anything that is not being used.

Actually do we really need block_deparallelize? What would be the practical use?

self.last_device = "cuda:" + str(max(self.device_map.keys()))
device_map = init_device_map(len(self.block), device_map)
self.first_device = f"cuda:{ list(device_map.keys())[0] }"
self.last_device = f"cuda:{ list(device_map.keys())[-1] }"
Contributor

@alexorona alexorona Jan 2, 2021


Maybe we should instead put some additional work into get_device_map so it just infers based on type(model):

self.device_map = get_device_map(model) if device_map is None else device_map
validate_device_map(self.device_map)

And keep the old language on first_device and last_device so we can support putting part of the model on "cpu".

self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys())) 

Contributor Author

@stas00 stas00 Jan 2, 2021


I'm not sure we even need first/last devices.

But yes, I think eventually each model will have to implement its own default device map. See the work I started in #9384: make_default_device_map in model_parallel_utils.py.

So there will be no need to infer, since the model will just have a method that will do that.

Of course, we don't have to re-implement it for each model if the archs are similar, so perhaps we will end up with a set of default maps and each model will pick and customize the one it needs. E.g. t5 and bart are almost the same, with the exception of the name of the module list holding the layers/blocks.

wrt some layers on cpu - you believe it'd be necessary to support that? If so, then perhaps it should be handled transparently along with the gpu ids, with no special case required for it, so that you could use:

{ 'cpu': [1,2],
  0:     [3, 4, 5],
}

that's actually a cool thing since then one could play with MP with just one gpu.
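Something along these lines could resolve the mixed keys transparently (just a sketch of the idea, nothing implemented yet):

import torch

def key_to_device(key):
    # a device_map key is either the string "cpu" or an integer gpu id
    return torch.device("cpu") if key == "cpu" else torch.device(f"cuda:{key}")

device_map = {'cpu': [1, 2], 0: [3, 4, 5]}
layer_to_device = {layer: key_to_device(key) for key, layers in device_map.items() for layer in layers}
# -> {1: cpu, 2: cpu, 3: cuda:0, 4: cuda:0, 5: cuda:0}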

Member


No strong opposition to updating from get_device_map to init_device_map, but it would be nice to keep CPU support here if possible.

Contributor Author


It's not a rename, it's a change in functionality: the default and the check are merged into one function to avoid pointless code repetition, which doesn't contribute to easier understanding.

And yes, point taken on supporting cpu.

I propose the first stage is to polish the API for gpus only and then definitely include cpu, as it'd be helpful for developing/debugging MP with just one gpu.

I will start making todo items.

@stas00
Contributor Author

stas00 commented Jan 2, 2021

All, awesome suggestions that should be looked at next once the current work has been merged.

I'm going to hold off on implementing anything new, since there are already too many partial PRs that need to be carefully merged and rebased. Once that is done we can do another round of generalization, integrating your suggestions.

@stas00 stas00 added the Model Parallel (Model Parallelism Implementations) label Jan 2, 2021
Member

@LysandreJik LysandreJik left a comment


Nice changes, LGTM! Just one nit regarding CPU support.


@stas00
Contributor Author

stas00 commented Jan 4, 2021

So if there is no objection, I will merge this one, and then start integrating with #9384, which is ahead functionality-wise - so I want to sync the two, switching t5 to the improved version of the MP backend. I will implement the suggestions in that new PR.

@stas00
Contributor Author

stas00 commented Jan 7, 2021

As we have discovered, the original PR didn't make t5 work with the trainer. I have just fixed that in the last commit here, bringing some goodies over from the Bart MP PR.

So this now works:

export BS=20; rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0  ./finetune_trainer.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 200 --n_val 200 --n_test 200 --fp16 --save_steps 2 --model_parallel

But! while it's fine in the training stage, it's 10x slower on eval than w/o --model_parallel

@kznmft

kznmft commented Jan 26, 2021

Hello @stas00, kudos for all the hard work you do, especially for continuing the ambitious work on supporting parallelism.

I'm interested in doing some inference with the t5-11b model variant.
Can you provide some insights into how many gpus would be needed to achieve that?

I tried this branch with 8x V100 (16GB) on GCE.
All was good while I created the model and called parallelize, but I got an out-of-memory error at the inference step when moving inputs to the first gpu device.

Let me know if I have a wrong mental model about achieving this. Thanks again!

@stas00
Contributor Author

stas00 commented Jan 26, 2021

Thank you for the kind words, @kznmft!

Please have a look at #9765, which implements a (very inefficient in my opinion, but nevertheless working) pipeline parallelism on t5. It should be superior to this naive implementation speed-wise, but it's not quite there yet. Please read the first post carefully for all the details, and you can see the follow-up comments with the experiments that have been done. So 4x 40GB A100 gpus weren't enough for t5-11b in the initial experiments, but 5-6 of those probably should be enough.

I finally got access to a machine with 4 gpus just now, so I'm going to start looking at implementing 2D parallelism (using Pipeline with DeepSpeed ZeRO-DP), and I will post news once I get something working.

Subscribe to watch #9765 and I most likely will update that PR with new info or a link to a new PR once I have something working.


I tried this branch with 8x V100 (16GB) on GCE.
All was good while I created the model and called parallelize, but I got an out-of-memory error at the inference step when moving inputs to the first gpu device.

But you're not telling me which device map you were using. You need to spread out the layers over the 8 gpus - have you done that? Unless you were relying on the default map, which should spread things out.

The problem is that it doesn't take into account that gpu 0 is always overtaxed, so I'd always put a few layers fewer on gpu 0. And then watch nvidia-smi (later we will have better tools) to check that each GPU gets a somewhat equal memory allocation.
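For example, using the encoder/decoder-keyed map format sketched earlier in this thread (the exact format may still change, but the idea is the same), for t5-11b's 24 blocks per stack over 8 gpus one could try something like:

device_map = {
    'encoder': {
        0: list(range(0, 4)),    # deliberately fewer blocks on gpu 0, since it also hosts the embeddings and receives the inputs
        1: list(range(4, 11)),
        2: list(range(11, 18)),
        3: list(range(18, 24)),
    },
    'decoder': {
        4: list(range(0, 6)),
        5: list(range(6, 12)),
        6: list(range(12, 18)),
        7: list(range(18, 24)),
    },
}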

But if 4x40 couldn't fit it, I doubt that 8x16 will.

Remember in t5-11b you have 45GB of params, plus optimizer states plus gradients.

Also, you'll probably need to try a leaner optimizer, say Adam instead of AdamW, which needs more memory.
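Back-of-the-envelope (assuming fp32 weights and Adam-style optimizer states; the exact numbers depend on precision and setup):

params = 11e9          # t5-11b
fp32_weights = 4       # bytes per param
fp32_grads = 4         # bytes per param
adam_states = 8        # exp_avg + exp_avg_sq, bytes per param

total_gb = params * (fp32_weights + fp32_grads + adam_states) / 2**30
print(f"~{total_gb:.0f} GB")   # ~164 GB for training, before any activations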

@huggingface huggingface deleted a comment from github-actions bot Apr 15, 2021
@stas00 stas00 self-assigned this Apr 15, 2021
@stas00 stas00 added the WIP label Apr 15, 2021
@stas00
Contributor Author

stas00 commented Jun 4, 2021

too long. closing.

@stas00 stas00 closed this Jun 4, 2021