
Integrating Riemannian Preconditioner #1807

Draft · wants to merge 9 commits into base: main
Conversation

@fangzhaozhang commented May 28, 2024

Paper link: https://arxiv.org/pdf/2402.02347
This is an attempt to integrate a special optimizer for LoRA training into the current Hugging Face PEFT codebase. We follow the structure of the PR that adds LoRA+ (#1509).
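
For orientation, here is a minimal usage sketch under the draft's current API. The helper name, its arguments, and the `is_lora`/`reg` conventions are taken from this PR; the import path, base model, LoRA targets, and hyperparameter values are placeholders, not confirmed choices:

    import torch
    from transformers import AutoModelForCausalLM

    from peft import LoraConfig, get_peft_model
    from peft.optimizers import create_riemannian_optimizer  # import path assumed

    base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

    # Riemannian-preconditioned AdamW over the LoRA parameters; `reg` is the
    # regularizer added before inverting lora_B.T @ lora_B (value is a placeholder)
    optimizer = create_riemannian_optimizer(
        model=model,
        optimizer_cls=torch.optim.AdamW,
        optimizer_kwargs={"lr": 2e-4},
        reg=5.0,
    )

    # one manual optimization step as a smoke test
    inputs = torch.tensor([[1, 2, 3, 4]])
    loss = model(input_ids=inputs, labels=inputs).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()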

@fangzhaozhang fangzhaozhang marked this pull request as draft May 28, 2024 23:04
@fangzhaozhang (Author)

We have added a test file in peft/tests/riemannian_test.py which uses the new optimizer to train an LLM with the Trainer class.

@BenjaminBossan (Member) left a comment

Thanks a lot for creating this draft PR to add Riemannian AdamW. I did a first review but haven't yet looked at the exact implementation details or compared them to the paper. I added some comments which, if addressed, will help me better understand what's going on.

Apart from the code comments I added, I have some more general comments:

  1. This PR contains the code from the lora+ PR. Please remove it.
  2. Could you please run make style?
  3. If some of this code is copied over from https://github.com/pilancilab/Riemannian_Preconditioned_LoRA or elsewhere, please add a comment with a reference.
  4. You added a test but it does not have the form of a proper unit test. I think it would be better to rewrite this a bit and add it to the examples/ directory, as it's more akin to an example.
  5. Regarding proper unit tests, check out the tests from the lora+ PR; a rough sketch follows below. LMK if you need more guidance.

I know that overall, this seems to be a lot of work, but I'm sure we can get this into a good shape. If you have any questions, don't hesitate to ask.
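
For reference, a rough sketch of what such a unit test could look like, modeled on the LoRA+ helper tests: the import path, helper signature, and the `is_lora` group flag are taken from this draft, while the tiny test model is a placeholder.

    import torch
    from transformers import AutoModelForCausalLM

    from peft import LoraConfig, get_peft_model
    from peft.optimizers import create_riemannian_optimizer  # import path assumed


    def test_lora_params_get_a_dedicated_group():
        base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = get_peft_model(base, LoraConfig(target_modules=["c_attn"]))
        optim = create_riemannian_optimizer(
            model=model, optimizer_cls=torch.optim.AdamW, optimizer_kwargs={"lr": 1e-4}, reg=5.0
        )
        # the draft marks the LoRA parameter group with an `is_lora` flag
        assert any(group.get("is_lora") for group in optim.param_groups)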

model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
@BenjaminBossan (Member)

Let's use the same indentation and syntax as the other parameters. Also, let's add docs for reg.

@fangzhaozhang (Author)

done

@BenjaminBossan (Member)

Hmm, indentation is still wrong. It should be:

        optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
        lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
        reg (`float`): Regularization parameter for the Riemannian preconditioner. Included for LoRA parameters only

- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
"""

"""TEST VERSION FOR ADAMW"""
@BenjaminBossan (Member)

For code comments, use # and not strings.

@fangzhaozhang (Author)

done

"""

"""TEST VERSION FOR ADAMW"""
assert optimizer_cls.__name__=='AdamW', 'TEST version only supports AdamW optimizer'
@BenjaminBossan (Member)

Let's not use assert in code (only tests). Here, it is better to raise a TypeError. Also, I wonder: does the class have to be AdamW or can it be a subclass? If the latter, you can change the check to: if not issubclass(optimizer_cls, torch.optim.AdamW).
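
For illustration, the check could look like this (a sketch; the error message is a placeholder):

    if not issubclass(optimizer_cls, torch.optim.AdamW):
        raise TypeError(
            f"create_riemannian_optimizer only supports AdamW, got {optimizer_cls.__name__}"
        )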

@fangzhaozhang (Author)

done

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # print(name, param.shape)
@BenjaminBossan (Member)

Please remove.

@fangzhaozhang (Author)

done

"""
Creates a Riemannian optimizer.
Implementation: https://github.com/pilancilab/Riemannian_Preconditioned_LoRA
Reference: https://arxiv.org/pdf/2402.02347
@BenjaminBossan (Member)

Let's mention that this only works for LoRA.

@fangzhaozhang (Author)

done


for group in self.param_groups:
    if group['is_lora']:
        for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
@BenjaminBossan (Member)

Let me try to understand this: I think we iterate over pairs of lora_A and lora_B, which is why we have the zip and the [::2]. Is that it?

I wonder if we can make the assumption that pairs of lora_A and lora_B are always following consecutively. E.g. what would happen if we have use_dora=True, could it happen that we now suddenly have triplets?

@fangzhaozhang (Author)

Your understanding is correct, and this is exactly what I'm worried about too. In our paper, for each LoRA pair (lora_A, lora_B), we use grad(lora_A) @ inverse(lora_B' lora_B) in place of the vanilla grad(lora_A). Empirically, this preconditioned gradient performs better than the vanilla gradient with respect to loss minimization. Moreover, since lora_B' lora_B has shape r x r, inverse(lora_B' lora_B) is expected to be cheap, especially for small r. Our original implementation is basic and simply iterates with [::2].

As you mentioned, I'm not sure how to pair up (lora_A, lora_B) in an error-free way. For DoRA, since we also have the magnitude term, I feel it's better to build the pairs by matching names, i.e. pairing "layer1_attentionq_lora_A" with "layer1_attentionq_lora_B". This also preserves the correct ordering, since I don't think we can assume each lora_A is immediately followed by its corresponding lora_B.

Moreover, the [::2] loop is slow compared to a simple AdamW loop, so on top of the inverse operation we also suffer loop runtime overhead. Shall we instead keep a dict each for the lora_A and lora_B parameters and look up the corresponding value by key when needed? A sketch of this idea follows below.
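
To make the name-matching idea concrete, here is a sketch of pairing by parameter name rather than by position. The lora_A/lora_B substrings follow PEFT's module naming; the surrounding code is hypothetical:

    # collect LoRA parameters keyed by their module prefix, then pair A with B by key
    lora_A, lora_B = {}, {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_A" in name:
            lora_A[name.replace("lora_A", "")] = param
        elif "lora_B" in name:
            lora_B[name.replace("lora_B", "")] = param
        # DoRA magnitude vectors and other parameters simply fall through

    if set(lora_A) != set(lora_B):
        raise ValueError("every lora_A parameter must have a matching lora_B parameter")

    pairs = [(lora_A[key], lora_B[key]) for key in lora_A]  # order-independent pairing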

for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
    grad = p1.grad
    if grad.is_sparse:
        raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
@BenjaminBossan (Member)

Suggested change
-    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+    raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")

Not sure if it makes sense to suggest SparseAdam here.

reg_I = self.defaults['reg'] * torch.eye(min(p2.shape)).to(p2.device)
scaler = torch.inverse(scaler @ scaler.T + reg_I) if p2.shape[0] < p2.shape[1] \
    else torch.inverse(scaler.T @ scaler + reg_I)
assert scaler.shape[0] == min(p2.data.shape), 'wrong dimension'
@BenjaminBossan (Member)

Again, let's not use assert but raise a proper error here (ValueError with a useful message).
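
For example, something along these lines (a sketch; the exact message is up to you):

    if scaler.shape[0] != min(p2.data.shape):
        raise ValueError(
            f"Expected a preconditioner of size {min(p2.data.shape)} but got {scaler.shape[0]}; "
            "the lora_A/lora_B pairing is likely broken."
        )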

    else torch.inverse(scaler.T @ scaler + reg_I)
assert scaler.shape[0] == min(p2.data.shape), 'wrong dimension'
except:
    print('invalid condition')
@BenjaminBossan (Member)

Remove

if group["weight_decay"] > 0.0:
p2.add_(p2, alpha=(-group["lr"] * group["weight_decay"]))

else:
@BenjaminBossan (Member)

Is this code path normal AdamW or are there changes in here too? Adding a comment would be helpful.

@BenjaminBossan (Member)

@fangzhaozhang do you still plan on working on this?

@fangzhaozhang (Author) commented Jun 28, 2024 via email

@fangzhaoz

I'm back on the implementation. Thanks so much for your detailed comments. With respect to the general points:

  1. I've removed the LoRA+ code.
  2. I've run make style.
  3. I've added a reference link to our original implementation.
  4. I've moved the prior test to examples/riemannian_lora and rewrote a test in tests/test_riemannian_lora.py following LoRA+'s tests/test_loraplus_helper.py. Let me know whether this is the desired unit test form.

I've also fixed small issues such as code comments, function names, etc. as suggested in the comments above. However, I'm not sure about the following points:

  1. Our current implementation is a rewrite of transformers' AdamW (https://github.com/huggingface/transformers/blob/v4.42.0/src/transformers/optimization.py#L558). Shall we instead follow the torch.optim.AdamW implementation, which is more complete though more complex?
  2. Our method has quite different logic from LoRA+. LoRA+ serves as an optimizer wrapper that only changes the learning rate settings, whereas we are closer to writing a new optimizer customized to LoRA, since we change the optimizer's inner workflow. LoRA+ is applicable to all optimizers such as Adam, AdamW, Adagrad, etc., while our paper only describes modifications to SGD and AdamW. Thus I'm not sure whether it's best to make our method appear in peft/optimizers in parallel with LoRA+; it feels more natural to put our optimizer in parallel with the AdamW implementation, or just pass in a parameter like lora=True to transformers' AdamW in order to switch to our method. Besides, our method is not directly applicable to bitsandbytes and other quantized forms, since torch.inverse() only supports certain dtypes. Shall we do a dtype conversion before and after we compute torch.inverse() to make it more general? (See the sketch after this list.)
  3. The iteration method also confuses me: shall we switch to dicts of lora_A/lora_B parameters and query them by key, compared to the current [::2] setting?
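
Regarding the dtype question in point 2, a minimal sketch of the round-trip, assuming scaler holds the lora_B weight and reg_I is the regularizer term from the draft:

    # upcast for torch.inverse, which only supports float/double (and their
    # complex counterparts), then cast the preconditioner back to the working dtype
    orig_dtype = scaler.dtype
    gram = (scaler.T @ scaler).float() + reg_I.float()
    precond = torch.inverse(gram).to(orig_dtype)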

I'd be glad to hear your feedback/suggestions on the above questions.

@BenjaminBossan (Member) left a comment

Thanks a lot for the updates. We're getting closer but there are still a few areas that need to be improved.

Also, note that the LoRA+ PR is now moved to #1915 with a few changes.

Thus I'm not sure whether it's best to make our method appear in peft/optimizers in parallel with LoRA+; it feels more natural to put our optimizer in parallel with the AdamW implementation, or just pass in a parameter like lora=True to transformers' AdamW in order to switch to our method

Since this is very PEFT specific, I think the best fit is indeed here. It would be quite hard to convince transformers to add this very specific change.

  2. Besides, our method is not directly applicable to bitsandbytes and other quantized forms, since torch.inverse() only supports certain dtypes. Shall we do a dtype conversion before and after we compute torch.inverse() to make it more general?

If you can implement a version that works with quantized weights, that would be great. If not, that's also okay, but then let's document this clearly.

Comment on lines +1 to +5
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
@BenjaminBossan (Member)

These lines can be removed. At the bottom of the file, add __all__ = ["create_riemannian_optimizer"]

# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
@BenjaminBossan (Member)

Suggested change
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2024-present the HuggingFace Inc. team.


Comment on lines +34 to +35
if not issubclass(optimizer_cls, torch.optim.AdamW):
    raise TypeError("TEST version only supports AdamW optimizer")
@BenjaminBossan (Member)

Since the optimizer_cls argument is not actually used, except to raise an error, how about removing it completely?

def create_riemannian_optimizer(
    model: PeftModel,
    optimizer_cls: type[Optimizer],
    optimizer_kwargs: dict,
@BenjaminBossan (Member)

Since you probably took this from the LoRA+ PR, let me refer to the comment I put there:

A suggestion: Let's remove optimizer_kwargs and just add **kwargs. IMO, that makes calling this function easier, as we can use create_riemannian_optimizer(..., weight_decay=1e-3) instead of create_riemannian_optimizer(..., optimizer_kwargs={..., "weight_decay": 1e-3}). And since lr is not optional, let's make this a normal arg of create_riemannian_optimizer.
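
Applied to this function, the suggested signature might look like the following sketch (the reg and lr_embedding defaults are placeholders, not values from the draft):

    from torch.optim import Optimizer

    from peft import PeftModel


    def create_riemannian_optimizer(
        model: PeftModel,
        optimizer_cls: type[Optimizer],
        lr: float,
        reg: float = 5.0,  # placeholder default
        lr_embedding: float = 1e-6,  # placeholder default
        **kwargs,
    ) -> Optimizer:
        # remaining keyword arguments are forwarded straight to optimizer_cls
        ...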

Comment on lines +139 to +141
for group in self.param_groups:
    if group["is_lora"]:
        for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
@BenjaminBossan (Member)

As discussed in the other comment, this is indeed error prone. For this, the logic here:

https://github.com/huggingface/peft/pull/1807/files#diff-4730f831ea49f19ef126ffa6d712865c57a477585e4098b74acb6026d3056d5aR46-R47

should be improved. I think it's better if we create two separate groups for lora_A and lora_B. After the loop there, let's also check that both groups have the same length and that the length is > 0. In the optimizer_grouped_parameters, we can set "is_lora_A": True and "is_lora_B": True accordingly.

After making this change, the line here could be simplified to:

# this works because there is exactly one lora_A and one lora_B group
lora_A_group = next(group for group in self.param_groups if group.get("is_lora_A"))
lora_B_group = next(group for group in self.param_groups if group.get("is_lora_B"))
for p1, p2 in zip(lora_A_group["params"], lora_B_group["params"]):
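
A sketch of what the corresponding grouping inside create_riemannian_optimizer could then look like (names assumed; it mirrors the suggestion above):

    # build separate, order-aligned parameter groups for lora_A and lora_B
    lora_A_params, lora_B_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_A" in name:
            lora_A_params.append(param)
        elif "lora_B" in name:
            lora_B_params.append(param)

    if not lora_A_params or len(lora_A_params) != len(lora_B_params):
        raise ValueError("expected matching, non-empty lists of lora_A and lora_B parameters")

    optimizer_grouped_parameters = [
        {"params": lora_A_params, "is_lora_A": True, "is_lora_B": False},
        {"params": lora_B_params, "is_lora_A": False, "is_lora_B": True},
    ]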

    if p2.shape[0] < p2.shape[1]
    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
@BenjaminBossan (Member)

Let's not use assert, instead raise a proper ValueError with a helpful message.

    if p1.shape[0] < p1.shape[1]
    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
@BenjaminBossan (Member)

Let's not use assert, instead raise a proper ValueError with a helpful message.

    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
except RuntimeError:
@BenjaminBossan (Member)

Could you explain why this is needed? Could we instead check the condition and do something like if valid_condition: ... else: scaler = None. Let's completely avoid printing messages.
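
One possible shape for this, assuming the try/except is guarding against a (near-)singular matrix; the rank check is a hypothetical replacement, not code from the draft:

    # guard the inverse explicitly instead of catching RuntimeError; fall back
    # to an unpreconditioned update when the regularized Gram matrix is singular
    gram = scaler @ scaler.T if p2.shape[0] < p2.shape[1] else scaler.T @ scaler
    gram = gram + reg_I
    if torch.linalg.matrix_rank(gram) == gram.shape[0]:
        scaler = torch.inverse(gram)
    else:
        scaler = None  # caller skips preconditioning for this pair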

)
assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
except RuntimeError:
    print("invalid condition")
@BenjaminBossan (Member)

Could you explain why this is needed? Could we instead check the condition and do something like if valid_condition: ... else: scaler = None. Let's completely avoid printing messages.

@kallewoof (Contributor)

Cool! We should ensure that we add documentation clarifying whether this works together with LoRA+ or whether the two are mutually exclusive for some reason.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan (Member)

@fangzhaozhang Do you still plan on finishing this PR?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
