Add Norm Module to lora ext and add "bias" support #12503

Merged (3 commits) on Aug 13, 2023
Changes from all commits
7 changes: 5 additions & 2 deletions extensions-builtin/Lora/network.py
@@ -133,7 +133,7 @@ def calc_scale(self):

         return 1.0

-    def finalize_updown(self, updown, orig_weight, output_shape):
+    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
             updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@ def finalize_updown(self, updown, orig_weight, output_shape):
         if orig_weight.size().numel() == updown.size().numel():
             updown = updown.reshape(orig_weight.shape)

-        return updown * self.calc_scale() * self.multiplier()
+        if ex_bias is not None:
+            ex_bias = ex_bias * self.multiplier()
+
+        return updown * self.calc_scale() * self.multiplier(), ex_bias

     def calc_updown(self, target):
         raise NotImplementedError()
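With this change, finalize_updown returns an (updown, ex_bias) pair instead of a single tensor, so callers of a module's calc_updown now unpack an optional bias delta alongside the weight delta. Below is a minimal sketch of how a module type fits the new contract; MyModule and its zero tensor are hypothetical placeholders, and only the finalize_updown call mirrors the PR.

import torch

import network  # the extension's module that defines NetworkModule


class MyModule(network.NetworkModule):
    def calc_updown(self, orig_weight):
        # Placeholder delta; a real module type derives this from its stored weights.
        updown = torch.zeros_like(orig_weight)
        # Module types that carry no bias term simply pass None.
        ex_bias = None
        # finalize_updown applies calc_scale()/multiplier() and returns (updown, ex_bias).
        return self.finalize_updown(updown, orig_weight, orig_weight.shape, ex_bias)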
28 changes: 28 additions & 0 deletions extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,28 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
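NetworkModuleNorm carries no low-rank factors at all: w_norm and b_norm are used directly as the weight and bias deltas for a normalization layer. A rough sketch of the end effect on a target layer, using a plain LayerNorm stand-in; the key names w_norm/b_norm come from the PR, while the shapes, values and multiplier below are purely illustrative.

import torch

layer = torch.nn.LayerNorm(320)        # stand-in for a GroupNorm/LayerNorm inside the model
w_norm = torch.full((320,), 0.01)      # delta stored under "w_norm" (illustrative values)
b_norm = torch.full((320,), 0.02)      # delta stored under "b_norm"
multiplier = 0.8                       # per-network strength

with torch.no_grad():
    # roughly what networks.py does with the (updown, ex_bias) pair from this module
    layer.weight += w_norm * multiplier
    layer.bias += b_norm * multiplier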
64 changes: 55 additions & 9 deletions extensions-builtin/Lora/networks.py
@@ -7,6 +7,7 @@
 import network_ia3
 import network_lokr
 import network_full
+import network_norm

 import torch
 from typing import Union
@@ -19,6 +20,7 @@
     network_ia3.ModuleTypeIa3(),
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
+    network_norm.ModuleTypeNorm(),
 ]


@@ -31,6 +33,8 @@
     "resnets": {
         "conv1": "in_layers_2",
         "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
         "time_emb_proj": "emb_layers_1",
         "conv_shortcut": "skip_connection",
     }
@@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
+    bias_backup = getattr(self, "network_bias_backup", None)

-    if weights_backup is None:
+    if weights_backup is None and bias_backup is None:
         return

-    if isinstance(self, torch.nn.MultiheadAttention):
-        self.in_proj_weight.copy_(weights_backup[0])
-        self.out_proj.weight.copy_(weights_backup[1])
-    else:
-        self.weight.copy_(weights_backup)
+    if weights_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.in_proj_weight.copy_(weights_backup[0])
+            self.out_proj.weight.copy_(weights_backup[1])
+        else:
+            self.weight.copy_(weights_backup)
+
+    if bias_backup is not None:
+        self.bias.copy_(bias_backup)


-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
@@ -294,20 +303,27 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

         self.network_weights_backup = weights_backup

+    bias_backup = getattr(self, "network_bias_backup", None)
+    if bias_backup is None and getattr(self, 'bias', None) is not None:
+        bias_backup = self.bias.to(devices.cpu, copy=True)
+        self.network_bias_backup = bias_backup
+
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)

     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
         if module is not None and hasattr(self, 'weight'):
             with torch.no_grad():
-                updown = module.calc_updown(self.weight)
+                updown, ex_bias = module.calc_updown(self.weight)

                 if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                     # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                 self.weight += updown
+                if ex_bias is not None and getattr(self, 'bias', None) is not None:
+                    self.bias += ex_bias
                 continue

         module_q = net.modules.get(network_layer_name + "_q_proj", None)
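network_apply_weights now backs up a layer's original bias to CPU as network_bias_backup, mirroring the existing weight backup, and network_restore_weights_from_backup puts both back whenever the wanted set of networks changes. A condensed sketch of that lifecycle on a toy layer; the LayerNorm and the 0.1 deltas are placeholders, not code from the PR.

import torch

layer = torch.nn.LayerNorm(4)

# First application: keep CPU copies of the originals.
weight_backup = layer.weight.to("cpu", copy=True)   # stored as network_weights_backup
bias_backup = layer.bias.to("cpu", copy=True)       # stored as network_bias_backup

with torch.no_grad():
    layer.weight += 0.1                              # updown from the active networks
    layer.bias += 0.1                                # ex_bias from the active networks

# The selected networks changed: restore the originals before applying the new set.
with torch.no_grad():
    layer.weight.copy_(weight_backup)
    layer.bias.copy_(bias_backup)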
@@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
     return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


+def network_GroupNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)

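The new GroupNorm/LayerNorm wrappers follow the same pattern as the existing Linear, Conv2d and MultiheadAttention ones: unless the lora_functional option routes the call through network_forward, they first sync the layer's weights with the currently selected networks and then delegate to the saved original forward. A stripped-down sketch of that wrapper shape; names ending in _demo are hypothetical, and only the control flow mirrors the extension.

import torch

# Pretend the original forward was saved before patching, as lora_script.py does.
LayerNorm_forward_original_demo = torch.nn.LayerNorm.forward


def network_LayerNorm_forward_demo(self, input):
    # In the extension this point calls networks.network_apply_weights(self), which
    # updates self.weight/self.bias in place before the layer runs.
    return LayerNorm_forward_original_demo(self, input)


# Installation (what lora_script.py does for the real wrapper):
# torch.nn.LayerNorm.forward = network_LayerNorm_forward_demo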
16 changes: 16 additions & 0 deletions extensions-builtin/Lora/scripts/lora_script.py
@@ -40,6 +40,18 @@ def before_ui():
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
     torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
 if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
     torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward

@@ -50,6 +62,10 @@ def before_ui():
 torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
 torch.nn.Conv2d.forward = networks.network_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
 torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
 torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

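The hasattr guards make the patching idempotent: the true original forward/_load_from_state_dict is stored on torch.nn only the first time the script runs, so a later run (for example after a UI reload, presumably) does not overwrite the saved original with an already-patched function. A small sketch of that guard pattern with a hypothetical attribute name.

import torch


def install_layernorm_patch_demo(patched_forward):
    # Save the real original only once, no matter how many times this is called.
    if not hasattr(torch.nn, 'LayerNorm_forward_before_network_demo'):
        torch.nn.LayerNorm_forward_before_network_demo = torch.nn.LayerNorm.forward

    torch.nn.LayerNorm.forward = patched_forward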