Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify logic for detecting pixelshuffle #249

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions libs/spandrel/spandrel/architectures/ATD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -114,15 +114,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[ATD]:
upscale = 4
elif "conv_before_upsample.0.weight" in state_dict:
upsampler = "pixelshuffle"
upscale = 1
for i in range(0, 10, 2):
if f"upsample.{i}.weight" not in state_dict:
break
num_feat = state_dict[f"upsample.{i}.weight"].shape[1]

upscale *= math.isqrt(
state_dict[f"upsample.{i}.weight"].shape[0] // num_feat
)
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
elif "conv_last.weight" in state_dict:
upsampler = ""
upscale = 1
Expand Down
8 changes: 2 additions & 6 deletions libs/spandrel/spandrel/architectures/DAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -107,11 +107,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DAT]:
resi_connection = "1conv" if "conv_after_body.weight" in state_dict else "3conv"

if upsampler == "pixelshuffle":
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
num_feat = state_dict[f"upsample.{i}.weight"].shape[1]
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample")
elif upsampler == "pixelshuffledirect":
num_feat = state_dict["upsample.0.weight"].shape[1]
upscale = int(
Expand Down
21 changes: 2 additions & 19 deletions libs/spandrel/spandrel/architectures/DRCT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand All @@ -13,23 +13,6 @@
from .arch.drct_arch import DRCT


def _get_upscale_pixelshuffle(
state_dict: StateDict, key_prefix: str = "upsample"
) -> int:
upscale = 1

for i in range(0, 10, 2):
key = f"{key_prefix}.{i}.weight"
if key not in state_dict:
break

shape = state_dict[key].shape
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)

return upscale


class DRCTArch(Architecture[DRCT]):
def __init__(self) -> None:
super().__init__(
Expand Down Expand Up @@ -105,7 +88,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[DRCT]:

if "conv_last.weight" in state_dict:
upsampler = "pixelshuffle"
upscale = _get_upscale_pixelshuffle(state_dict, "upsample")
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
else:
upsampler = ""
upscale = 1
Expand Down
13 changes: 7 additions & 6 deletions libs/spandrel/spandrel/architectures/GRL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import torch
from typing_extensions import override

from spandrel.util import KeyCondition, get_scale_and_output_channels, get_seq_len
from spandrel.util import (
KeyCondition,
get_pixelshuffle_params,
get_scale_and_output_channels,
get_seq_len,
)

from ...__helpers.canonicalize import remove_common_prefix
from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict
Expand Down Expand Up @@ -50,18 +55,14 @@ def _get_output_params(state_dict: StateDict, in_channels: int):
upsampler: str
upscale: int

num_out_feats = 64 # hard-coded
if (
"conv_before_upsample.0.weight" in state_dict
and "upsample.up.0.weight" in state_dict
):
upsampler = "pixelshuffle"
out_channels = state_dict["conv_last.weight"].shape[0]

upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample.up"), 2):
shape = state_dict[f"upsample.up.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_out_feats))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample.up")
elif "upsample.up.0.weight" in state_dict:
upsampler = "pixelshuffledirect"
upscale, out_channels = get_scale_and_output_channels(
Expand Down
7 changes: 2 additions & 5 deletions libs/spandrel/spandrel/architectures/HAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -109,10 +109,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[HAT]:
embed_dim = state_dict["conv_first.weight"].shape[0]

num_feat = state_dict["conv_last.weight"].shape[1]
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample", num_feat)

window_size = int(math.sqrt(state_dict["relative_position_index_SA"].shape[0]))
overlap_ratio = _get_overlap_ratio(
Expand Down
10 changes: 2 additions & 8 deletions libs/spandrel/spandrel/architectures/RGT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -133,13 +133,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[RGT]:
)
break

upscale = 1
for i in range(0, 10, 2):
key = f"upsample.{i}.weight"
if key in state_dict:
shape = state_dict[key].shape
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")

split_size = _get_split_size(state_dict)

Expand Down
8 changes: 2 additions & 6 deletions libs/spandrel/spandrel/architectures/Swin2SR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -102,11 +102,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[Swin2SR]:
math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans)
)
else:
num_feat = 64 # hard-coded constant
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")

window_size = int(
math.sqrt(
Expand Down
12 changes: 2 additions & 10 deletions libs/spandrel/spandrel/architectures/SwinIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import nn
from typing_extensions import override

from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from ...__helpers.model_descriptor import (
Architecture,
Expand Down Expand Up @@ -84,15 +84,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SwinIR]:
for _upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle":
upsample_keys = [
x
for x in state_dict
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = state_dict[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
upscale, num_feat = get_pixelshuffle_params(state_dict, "upsample")
elif upsampler == "pixelshuffledirect":
upscale = int(
math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch)
Expand Down
34 changes: 32 additions & 2 deletions libs/spandrel/spandrel/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ def is_square(n: int) -> bool:
)


def get_pixelshuffle_params(
state_dict: Mapping[str, object],
upsample_key: str = "upsample",
default_nf: int = 64,
) -> tuple[int, int]:
"""
This will detect the upscale factor and number of features of a pixelshuffle module in the state dict.

A pixelshuffle module is a sequence of alternating up convolutions and pixelshuffle.
The class of this module is commonyl called `Upsample`.
Examples of such modules can be found in most SISR architectures, such as SwinIR, HAT, RGT, and many more.
"""
upscale = 1
num_feat = default_nf

for i in range(0, 10, 2):
key = f"{upsample_key}.{i}.weight"
if key not in state_dict:
break

tensor = state_dict[key]
# we'll assume that the state dict contains tensors
shape: tuple[int, ...] = tensor.shape # type: ignore
num_feat = shape[1]
upscale *= math.isqrt(shape[0] // num_feat)

return upscale, num_feat


def store_hyperparameters(*, extra_parameters: Mapping[str, object] = {}):
"""
Stores the hyperparameters of a class in a `hyperparameters` attribute.
Expand Down Expand Up @@ -170,9 +199,10 @@ def new_init(self: C, **kwargs):


__all__ = [
"KeyCondition",
"get_first_seq_index",
"get_seq_len",
"get_pixelshuffle_params",
"get_scale_and_output_channels",
"get_seq_len",
"KeyCondition",
"store_hyperparameters",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
SizeRequirements,
StateDict,
)
from spandrel.util import KeyCondition, get_seq_len
from spandrel.util import KeyCondition, get_pixelshuffle_params, get_seq_len

from .arch.SRFormer import SRFormer

Expand Down Expand Up @@ -76,12 +76,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[SRFormer]:
upscale = 4 # only supported scale
elif "conv_before_upsample.0.weight" in state_dict:
upsampler = "pixelshuffle"

num_feat = 64 # hard-coded constant
upscale = 1
for i in range(0, get_seq_len(state_dict, "upsample"), 2):
shape = state_dict[f"upsample.{i}.weight"].shape[0]
upscale *= int(math.sqrt(shape // num_feat))
upscale, _ = get_pixelshuffle_params(state_dict, "upsample")
elif "upsample.0.weight" in state_dict:
upsampler = "pixelshuffledirect"
upscale = int(
Expand Down
Loading