Add support for HVI-CIDNet (#271)
* Add support for HVI-CIDNet

* Shorter link

* Fix category
RunDevelopment committed May 29, 2024
1 parent 2f71a10 commit cb2f034
Showing 11 changed files with 663 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -148,6 +148,7 @@ Spandrel currently supports a limited amount of network architectures. If the ar
#### Low-light Enhancement

- [RetinexFormer](https://github.com/caiyuanhao1998/Retinexformer) | [Models](https://drive.google.com/drive/folders/1ynK5hfQachzc8y96ZumhkPPDXzHJwaQV?usp=drive_link)
- [HVI-CIDNet](https://github.com/Fediory/HVI-CIDNet) | [Models](https://github.com/Fediory/HVI-CIDNet/#weights-and-results-)

(All architectures marked with a `+` are only part of `spandrel_extra_arches`.)

2 changes: 2 additions & 0 deletions libs/spandrel/spandrel/__helpers/main_registry.py
@@ -22,6 +22,7 @@
DnCNN,
DRUNet,
FFTformer,
HVICIDNet,
KBNet,
LaMa,
MixDehazeNet,
@@ -82,4 +83,5 @@
ArchSupport.from_architecture(ESRGAN.ESRGANArch()),
ArchSupport.from_architecture(PLKSR.PLKSRArch()),
ArchSupport.from_architecture(RetinexFormer.RetinexFormerArch()),
ArchSupport.from_architecture(HVICIDNet.HVICIDNetArch()),
)
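
With the `ArchSupport` entry registered above, `ModelLoader` can detect HVI-CIDNet checkpoints automatically from the state-dict keys listed in the `detect` condition in the next file. A minimal sketch (not part of this commit; the checkpoint file name is a placeholder for any of the official weights linked in the README):

```python
import torch
from spandrel import ModelLoader

# placeholder name; use any HVI-CIDNet checkpoint from the weights link above
state_dict = torch.load("hvi_cidnet_lolv1.pth", map_location="cpu")
descriptor = ModelLoader().load_from_state_dict(state_dict)
print(descriptor.architecture.id)  # "HVICIDNet" if the keys match the detect condition
```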
94 changes: 94 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/__init__.py
@@ -0,0 +1,94 @@
from typing_extensions import override

from spandrel.util import KeyCondition

from ...__helpers.model_descriptor import (
Architecture,
ImageModelDescriptor,
ModelTiling,
SizeRequirements,
StateDict,
)
from .arch.cidnet import CIDNet as HVICIDNet


class HVICIDNetArch(Architecture[HVICIDNet]):
def __init__(self) -> None:
super().__init__(
id="HVICIDNet",
name="HVI-CIDNet",
detect=KeyCondition.has_all(
"HVE_block0.1.weight",
"HVE_block1.prelu.weight",
"HVE_block1.down.0.weight",
"HVE_block3.down.0.weight",
"HVD_block3.prelu.weight",
"HVD_block3.up_scale.0.weight",
"HVD_block3.up.weight",
"HVD_block1.up.weight",
"HVD_block0.1.weight",
"IE_block0.1.weight",
"IE_block1.prelu.weight",
"IE_block1.down.0.weight",
"ID_block1.up.weight",
"ID_block0.1.weight",
"HV_LCA1.gdfn.project_in.weight",
"HV_LCA1.gdfn.dwconv.weight",
"HV_LCA1.gdfn.dwconv1.weight",
"HV_LCA1.gdfn.dwconv2.weight",
"HV_LCA1.gdfn.project_out.weight",
"HV_LCA1.norm.weight",
"HV_LCA1.ffn.temperature",
"HV_LCA1.ffn.q.weight",
"HV_LCA1.ffn.q_dwconv.weight",
"HV_LCA1.ffn.project_out.weight",
"HV_LCA2.gdfn.project_in.weight",
"HV_LCA6.gdfn.project_in.weight",
"I_LCA1.gdfn.project_in.weight",
"I_LCA6.ffn.project_out.weight",
"trans.density_k",
),
)

@override
def load(self, state_dict: StateDict) -> ImageModelDescriptor[HVICIDNet]:
        # Defaults; the actual values are detected from the state dict below.
        channels = [36, 36, 72, 144]
        heads = [1, 2, 4, 8]
        norm = False

channels = [
state_dict["HVE_block1.down.0.weight"].shape[1],
state_dict["HVE_block1.down.0.weight"].shape[0],
state_dict["HVE_block2.down.0.weight"].shape[0],
state_dict["HVE_block3.down.0.weight"].shape[0],
]

heads = [
1, # unused
state_dict["HV_LCA1.ffn.temperature"].shape[0],
state_dict["HV_LCA2.ffn.temperature"].shape[0],
state_dict["HV_LCA3.ffn.temperature"].shape[0],
]

norm = "HVE_block1.norm.weight" in state_dict

model = HVICIDNet(
channels=channels,
heads=heads,
norm=norm,
)

return ImageModelDescriptor(
model,
state_dict,
architecture=self,
purpose="Restoration",
tags=[],
supports_half=False, # TODO: verify
supports_bfloat16=True,
scale=1,
input_channels=3, # hard-coded
output_channels=3, # hard-coded
size_requirements=SizeRequirements(multiple_of=8),
tiling=ModelTiling.DISCOURAGED,
)
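
The returned descriptor is callable. Since the scale is 1 and the size requirements demand spatial dimensions divisible by 8, inference looks like this (a sketch, continuing from a descriptor loaded as in the registry example above):

```python
import torch

descriptor.eval()
lowlight = torch.rand(1, 3, 256, 256)  # NCHW, values in [0, 1]; 256 % 8 == 0
with torch.no_grad():
    enhanced = descriptor(lowlight)    # same 1x3x256x256 shape (scale=1, 3 -> 3 channels)
```

Tiling is discouraged for this architecture, so prefer whole-image inference when memory allows.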
135 changes: 135 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/arch/HVI_transform.py
@@ -0,0 +1,135 @@
from math import pi

import torch
import torch.nn as nn


class RGB_HVI(nn.Module):
def __init__(self):
super().__init__()
        self.density_k = torch.nn.Parameter(
            torch.full([1], 0.2)
        )  # this k is the reciprocal of the k used in the paper
self.gated = False
self.gated2 = False
self.alpha = 1.0
self.this_k = 0

def HVIT(self, img):
eps = 1e-8
device = img.device
dtypes = img.dtype
        # Uninitialized hue map; every element is assigned below, piecewise by
        # which RGB channel attains the maximum (standard RGB -> HSV hue).
        hue = (
            torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
        )
value = img.max(1)[0].to(dtypes)
img_min = img.min(1)[0].to(dtypes)
hue[img[:, 2] == value] = (
4.0
+ ((img[:, 0] - img[:, 1]) / (value - img_min + eps))[img[:, 2] == value]
)
hue[img[:, 1] == value] = (
2.0
+ ((img[:, 2] - img[:, 0]) / (value - img_min + eps))[img[:, 1] == value]
)
hue[img[:, 0] == value] = (
0.0
+ ((img[:, 1] - img[:, 2]) / (value - img_min + eps))[img[:, 0] == value]
) % 6

hue[img.min(1)[0] == value] = 0.0
hue = hue / 6.0

saturation = (value - img_min) / (value + eps)
saturation[value == 0] = 0

hue = hue.unsqueeze(1)
saturation = saturation.unsqueeze(1)
value = value.unsqueeze(1)

k = self.density_k
self.this_k = k.item()
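        # this_k is cached so that PHVIT can later invert the transform with the same k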

        # Intensity-dependent gain (sin(pi/2 * I) + eps)^k scales the chroma plane:
        # (X, Y) = gain * S * (cos 2*pi*H, sin 2*pi*H), Z = I.
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
cx = (2.0 * pi * hue).cos()
cy = (2.0 * pi * hue).sin()
X = color_sensitive * saturation * cx
Y = color_sensitive * saturation * cy
Z = value
xyz = torch.cat([X, Y, Z], dim=1)
return xyz

def PHVIT(self, img):
eps = 1e-8
H, V, I = img[:, 0, :, :], img[:, 1, :, :], img[:, 2, :, :] # noqa: E741

# clip
H = torch.clamp(H, -1, 1)
V = torch.clamp(V, -1, 1)
I = torch.clamp(I, 0, 1) # noqa: E741

v = I
k = self.this_k
color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = H / (color_sensitive + eps)
        V = V / (color_sensitive + eps)
H = torch.clamp(H, -1, 1)
V = torch.clamp(V, -1, 1)
h = torch.atan2(V, H) / (2 * pi)
h = h % 1
s = torch.sqrt(H**2 + V**2)

if self.gated:
s = s * 1.3

s = torch.clamp(s, 0, 1)
v = torch.clamp(v, 0, 1)

        # Standard HSV -> RGB conversion from here on.
        r = torch.zeros_like(h)
g = torch.zeros_like(h)
b = torch.zeros_like(h)

hi = torch.floor(h * 6.0)
f = h * 6.0 - hi
p = v * (1.0 - s)
q = v * (1.0 - (f * s))
t = v * (1.0 - ((1.0 - f) * s))

hi0 = hi == 0
hi1 = hi == 1
hi2 = hi == 2
hi3 = hi == 3
hi4 = hi == 4
hi5 = hi == 5

r[hi0] = v[hi0]
g[hi0] = t[hi0]
b[hi0] = p[hi0]

r[hi1] = q[hi1]
g[hi1] = v[hi1]
b[hi1] = p[hi1]

r[hi2] = p[hi2]
g[hi2] = v[hi2]
b[hi2] = t[hi2]

r[hi3] = p[hi3]
g[hi3] = q[hi3]
b[hi3] = v[hi3]

r[hi4] = t[hi4]
g[hi4] = p[hi4]
b[hi4] = v[hi4]

r[hi5] = v[hi5]
g[hi5] = p[hi5]
b[hi5] = q[hi5]

r = r.unsqueeze(1)
g = g.unsqueeze(1)
b = b.unsqueeze(1)
rgb = torch.cat([r, g, b], dim=1)
if self.gated2:
rgb = rgb * self.alpha
return rgb
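
For intuition, a round-trip sketch (not part of the commit): HVIT maps RGB into the HVI space and PHVIT approximately inverts it, up to eps and clamping error. PHVIT reads self.this_k, which HVIT caches, so HVIT must run first.

```python
import torch

trans = RGB_HVI()
rgb = torch.rand(2, 3, 64, 64)        # RGB in [0, 1], NCHW
hvi = trans.HVIT(rgb)                 # channels (H, V, I), with I = max(R, G, B)
rgb_back = trans.PHVIT(hvi)
print((rgb - rgb_back).abs().max())   # small reconstruction error
```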
133 changes: 133 additions & 0 deletions libs/spandrel/spandrel/architectures/HVICIDNet/arch/LCA.py
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from einops import rearrange

from .transformer_utils import LayerNorm


# Cross Attention Block
class CAB(nn.Module):
def __init__(self, dim, num_heads, bias):
super().__init__()
self.num_heads = num_heads
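        # learned per-head temperature that scales the channel-attention logits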
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.q_dwconv = nn.Conv2d(
dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias
)
self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
self.kv_dwconv = nn.Conv2d(
dim * 2,
dim * 2,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 2,
bias=bias,
)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

def forward(self, x, y):
_, _, h, w = x.shape

q = self.q_dwconv(self.q(x))
kv = self.kv_dwconv(self.kv(y))
k, v = kv.chunk(2, dim=1)

q = rearrange(q, "b (head c) h w -> b head c (h w)", head=self.num_heads)
k = rearrange(k, "b (head c) h w -> b head c (h w)", head=self.num_heads)
v = rearrange(v, "b (head c) h w -> b head c (h w)", head=self.num_heads)

q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)

        # Transposed attention: similarity is computed across channels (c x c)
        # rather than across spatial positions, so cost stays linear in h*w.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)

out = attn @ v

out = rearrange(
out, "b head c (h w) -> b (head c) h w", head=self.num_heads, h=h, w=w
)

out = self.project_out(out)
return out


# Intensity Enhancement Layer
class IEL(nn.Module):
def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
super().__init__()

hidden_features = int(dim * ffn_expansion_factor)

self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

self.dwconv = nn.Conv2d(
hidden_features * 2,
hidden_features * 2,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features * 2,
bias=bias,
)
self.dwconv1 = nn.Conv2d(
hidden_features,
hidden_features,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features,
bias=bias,
)
self.dwconv2 = nn.Conv2d(
hidden_features,
hidden_features,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features,
bias=bias,
)

self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

self.Tanh = nn.Tanh()

def forward(self, x):
x = self.project_in(x)
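        # Split into two depth-wise branches, refine each with a tanh residual,
        # then fuse them by element-wise gating.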
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x1 = self.Tanh(self.dwconv1(x1)) + x1
x2 = self.Tanh(self.dwconv2(x2)) + x2
x = x1 * x2
x = self.project_out(x)
return x


# Lightweight Cross Attention
class HV_LCA(nn.Module):
def __init__(self, dim, num_heads, bias=False):
super().__init__()
self.gdfn = IEL(dim) # IEL and CDL have same structure
self.norm = LayerNorm(dim)
self.ffn = CAB(dim, num_heads, bias)

def forward(self, x, y):
x = x + self.ffn(self.norm(x), self.norm(y))
x = self.gdfn(self.norm(x))
return x


class I_LCA(nn.Module):
def __init__(self, dim, num_heads, bias=False):
super().__init__()
self.norm = LayerNorm(dim)
self.gdfn = IEL(dim)
self.ffn = CAB(dim, num_heads, bias=bias)

def forward(self, x, y):
x = x + self.ffn(self.norm(x), self.norm(y))
x = x + self.gdfn(self.norm(x))
return x
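
Both blocks take two same-shaped feature maps and return one. Note the asymmetry: I_LCA keeps a residual connection around the gated feed-forward (gdfn), while HV_LCA replaces its input with the gdfn output. A shape sketch (not part of the commit):

```python
import torch

dim, heads = 36, 1                     # e.g. the smallest detected channel count
x = torch.rand(1, dim, 32, 32)         # HV-branch features
y = torch.rand(1, dim, 32, 32)         # intensity-branch features
print(HV_LCA(dim, heads)(x, y).shape)  # torch.Size([1, 36, 32, 32])
print(I_LCA(dim, heads)(x, y).shape)   # torch.Size([1, 36, 32, 32])
```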