Skip to content

deep image prior #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: test
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e59295a
Fixed broken backup remover
Bardia323 Jun 20, 2022
238ccbf
Made backup removal path cross-platform
Bardia323 Jun 20, 2022
288ab13
Merge pull request #207 from Bardia323/fix_backup
dmarx Jun 20, 2022
ade92ab
modified backup removal to use pathlib
dmarx Jun 20, 2022
4a82e67
Merge pull request #208 from pytti-tools/fix_backup_xplatform
dmarx Jun 20, 2022
24cceee
added proxy property for improved ema
dmarx Jun 13, 2022
a96d081
migrated vqgan image to proxy
dmarx Jun 13, 2022
08bee59
pushed proxy property down to base class
dmarx Jun 13, 2022
0a75d86
let's be honest: probably just making things worse here. pretty sure …
dmarx Jun 13, 2022
47f7304
rgb and pixel models working, vqgan still broken
dmarx Jun 13, 2022
d0827f5
getting closer
dmarx Jun 13, 2022
4189765
almost there
dmarx Jun 13, 2022
8f195be
one last coverage test to pass
dmarx Jun 13, 2022
866dab4
almost there....
dmarx Jun 13, 2022
1398493
passes coverage
dmarx Jun 13, 2022
9091891
vqgan tests passing
dmarx Jun 13, 2022
34c1a6f
skeleton DIP image model and tests
dmarx May 31, 2022
c188a2c
dip passes minimal unit tests
dmarx May 31, 2022
be84527
works with basic params, needs pixel_size=1 and reencode_each_frame=f…
dmarx May 31, 2022
63e9b4f
added encode_image for DIP
dmarx May 31, 2022
f89fe61
fixed DIP pixel_size scaling.
dmarx May 31, 2022
c3a3b47
added 1/10 offset params lr scaling on DIP by default.
dmarx May 31, 2022
c2e8984
fixed default dip input depth
dmarx May 31, 2022
0d4e555
added dip tests
dmarx Jun 10, 2022
ae35920
added pipfile to sate the git gods
dmarx Jun 10, 2022
a6496c6
incorporated new EMA into DIP
dmarx Jun 13, 2022
4795e03
suppress logging messages
dmarx Jun 13, 2022
80bad6e
shit is super broken. I think there's no more avoiding cleaning up ho…
dmarx Jun 13, 2022
755f14b
added note
dmarx Jun 14, 2022
02a9831
added note
dmarx Jun 14, 2022
1552014
h8 this. saving progress, but I think I need to just backtrack and si…
dmarx Jun 14, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ verify_ssl = false
name = "pytorch"

[packages]
tensorflow = "==2.7.0"
transformers = "==4.15.0"
gdown = "===4.2.0"
ftfy = "==6.0.3"
Expand All @@ -30,6 +31,7 @@ jupyter = "*"
imageio = "==2.4.1"
PyGLM = "==2.5.7"
adjustText = "*"

Pillow = "*"
torch = "*"
torchvision = "*"
Expand All @@ -48,6 +50,9 @@ mmc = {git = "https://github.com/dmarx/Multi-Modal-Comparators"}
pytest = "*"
pre-commit = "*"
click = "==8.0.4"
#Pillow = "==7.1.2"
#pyttitools-core = {editable = true, path = "."}
deep-image-prior = {path = "./deep-image-prior"}
black = "*"

[requires]
Expand Down
5 changes: 5 additions & 0 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ def train(
if self.embedder is not None:
for mb_i in range(gradient_accumulation_steps):
# logger.debug(mb_i)
# logger.debug(self.image_rep.shape)
logger.debug(type(self.image_rep))
logger.debug(z.shape)
image_embeds, offsets, sizes = self.embedder(self.image_rep, input=z)

t = 1
Expand All @@ -317,6 +320,8 @@ def train(
for prompt in prompts
}

# oh.. uh... image_losses and auglosses don't actually depend on an embedder being attached.
# Maybe this is why limited palette wasn't initializing properly?
losses, losses_raw = zip(
*map(unpack_dict, [prompt_losses, aug_losses, image_losses])
# *map(unpack_dict, [prompt_losses])
Expand Down
306 changes: 298 additions & 8 deletions src/pytti/LossAug/LatentLossClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import copy, re
from pytti import DEVICE, fetch, parse, vram_usage_mode

from loguru import logger


class LatentLoss(MSELoss):
@torch.no_grad()
Expand All @@ -17,26 +19,70 @@ def __init__(
name="direct target loss",
image_shape=None,
):
super().__init__(comp, weight, stop, name, image_shape)
super().__init__(
comp, weight, stop, name, image_shape
) # this really should link back to the image model...
logger.debug(type(comp)) # inits to image tensor
self.pil_image = None
self.has_latent = False
w, h = image_shape
self.direct_loss = MSELoss(
TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape
)
comp_adjusted = TF.resize(comp.clone(), (h, w))
# try:
# comp_adjusted = TF.resize(comp.clone(), (h, w))
# except:
# # comp_adjusted = comp.clone()
# # Need to convert the latent to its image form
# comp_adjusted = img_model.decode_tensor(comp.clone())
self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)

@torch.no_grad()
def set_comp(self, pil_image, device=DEVICE):
"""
sets the DIRECT loss anchor "comp" to the tensorized image.
"""
logger.debug(type(pil_image))
self.pil_image = pil_image
self.has_latent = False
self.direct_loss.set_comp(pil_image.resize(self.image_shape, Image.LANCZOS))
im_resized = pil_image.resize(
self.image_shape, Image.LANCZOS
) # to do: ResizeRight
# self.direct_loss.set_comp(im_resized)
self.direct_loss.set_comp(im_resized)

@classmethod
def convert_input(cls, input, img):
"""
Converts the input image tensor to the image representation of the image model.
E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
"""
logger.debug(type(input)) # pretty sure this is gonna be tensor
# return input # this is the default MSE loss version
return img.make_latent(input)

@classmethod
def default_comp(cls, img_model=None, *args, **kargs):
logger.debug("default_comp")
logger.debug(type(img_model))
device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
if img_model is None:
return torch.zeros(1, 1, 1, 1, device=device)
return img_model.default_comp(*args, **kargs)

@classmethod
@vram_usage_mode("Latent Image Loss")
@torch.no_grad()
def TargetImage(
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
cls,
prompt_string,
image_shape,
pil_image=None,
is_path=False,
device=DEVICE,
img_model=None,
):
logger.debug(
type(pil_image)
) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image?
text, weight, stop = parse(
prompt_string, r"(?<!^http)(?<!s):|:(?!/)", ["", "1", "-inf"]
)
Expand All @@ -48,7 +94,8 @@ def TargetImage(
comp = (
MSELoss.make_comp(pil_image)
if pil_image is not None
else torch.zeros(1, 1, 1, 1, device=device)
# else torch.zeros(1, 1, 1, 1, device=device)
else cls.default_comp(img_model=img_model)
)
out = cls(comp, weight, stop, text + " (latent)", image_shape)
if pil_image is not None:
Expand All @@ -61,11 +108,254 @@ def set_mask(self, mask, inverted=False):
super().set_mask(mask, inverted)

def get_loss(self, input, img):
logger.debug(type(input)) # Tensor
logger.debug(input.shape) # this is an image tensor
logger.debug(type(img)) # DIPImage
logger.debug(type(self.comp)) # Tensor
logger.debug(
self.comp.shape
) # [1 1 1 1] -> from target image constructor when no input image provided

# why is the latent comp only set here? why not in the __init__ and set_comp?
if not self.has_latent:
# make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
# if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
latent = img.make_latent(self.pil_image)
logger.debug(type(latent)) # EMAParametersDict
logger.debug(type(self.comp)) # torch.Tensor
with torch.no_grad():
self.comp.set_(latent.clone())
if type(latent) == type(self.comp):
self.comp.set_(latent.clone())
# else:

self.has_latent = True

l1 = super().get_loss(img.get_latent_tensor(), img) / 2
l2 = self.direct_loss.get_loss(input, img) / 10
return l1 + l2


######################################################################

# fuck it, let's just make a dip latent loss from scratch.


# The issue we're resolving here is that by inheriting from the MSELoss,
# I can't easily set the comp to the parameters of the image model.

from pytti.LossAug.BaseLossClass import Loss
from pytti.image_models.ema import EMAImage, EMAParametersDict
from pytti.rotoscoper import Rotoscoper

import deep_image_prior
import deep_image_prior.models
from deep_image_prior.models import (
get_hq_skip_net,
get_non_offset_params,
get_offset_params,
)


def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
    """
    Construct a high-quality skip-net Deep Image Prior and move it to `device`.

    The channel widths (192/192/4) are the deep-image-prior defaults used
    throughout this file; only depth, scale count, and offset behavior vary.
    """
    network = get_hq_skip_net(
        input_depth,
        skip_n33d=192,
        skip_n33u=192,
        skip_n11=4,
        num_scales=num_scales,
        offset_type=offset_type,
        offset_groups=offset_groups,
    )
    return network.to(device)


class LatentLossDIP(Loss):
    """
    Latent-space target loss tailored to the Deep Image Prior (DIP) image model.

    Unlike LatentLoss (which inherits from MSELoss and stores its anchor
    "comp" as a plain tensor), this class inherits from the bare Loss base so
    that "comp" can be an EMAParametersDict holding DIP network parameters.
    A secondary pixel-space MSELoss ("direct_loss") is kept alongside the
    latent comparison; it may be created lazily in set_comp().
    """

    @torch.no_grad()
    def __init__(
        self,
        comp,
        weight=0.5,
        stop=-math.inf,
        name="direct target loss",
        image_shape=None,
        device=None,
    ):
        ##################################################################
        # Boilerplate duplicated from the MSELoss constructor (mask buffer
        # etc.), since this class deliberately does not inherit from MSELoss.
        super().__init__(weight, stop, name, device)
        if image_shape is None:
            # NOTE(review): bare `raise` with no active exception produces a
            # RuntimeError; an explicit ValueError("image_shape is required")
            # would be clearer.
            raise
        # height, width = comp.shape[-2:]
        # image_shape = (width, height)
        self.image_shape = image_shape
        self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
        self.use_mask = False
        ##################################################################
        self.pil_image = None
        self.has_latent = False
        logger.debug(type(comp))  # inits to image tensor
        if comp is None:
            comp = self.default_comp()
        if isinstance(comp, EMAParametersDict):
            # comp is already a latent (DIP parameter EMA): register it as a
            # submodule and skip the lazy conversion in get_loss().
            logger.debug("initializing loss from latent")
            self.register_module("comp", comp)
            self.has_latent = True
        else:
            # comp is an image tensor: only the pixel-space direct loss is
            # set up here. NOTE(review): in this branch `self.comp` is never
            # registered, so get_loss() would fail on self.comp until it is
            # assigned elsewhere — confirm this path is actually exercised.
            w, h = image_shape
            comp_adjusted = TF.resize(comp.clone(), (h, w))
            # try:
            #     comp_adjusted = TF.resize(comp.clone(), (h, w))
            # except:
            #     # comp_adjusted = comp.clone()
            #     # Need to convert the latent to its image form
            #     comp_adjusted = img_model.decode_tensor(comp.clone())
            self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)

        ##################################################################

        logger.debug(type(comp))

    @classmethod
    def default_comp(*args, **kargs):
        """
        Build a default latent anchor: a freshly initialized DIP network
        wrapped in an EMAParametersDict.

        NOTE(review): the signature omits an explicit `cls`; because of
        @classmethod the class object arrives as the first element of *args
        and is unused. Extra kwargs (e.g. img_model=... from TargetImage)
        are silently ignored.
        """
        logger.debug("default_comp")
        # NOTE(review): operator precedence groups this as
        #     (kargs.get("device", "cuda")) if torch.cuda.is_available() else "cpu"
        # so an explicit `device` kwarg is ignored on CPU-only hosts —
        # probably intended as kargs.get("device", "cuda" if ... else "cpu").
        device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
        net = load_dip(
            input_depth=32,
            num_scales=7,
            offset_type="none",
            offset_groups=4,
            device=device,
        )
        return EMAParametersDict(z=net, decay=0.99, device=device)

    ###################################################################################

    @torch.no_grad()
    def set_comp(self, pil_image, device=DEVICE):
        """
        sets the DIRECT loss anchor "comp" to the tensorized image.

        Stores the PIL target and invalidates the cached-latent flag; creates
        the direct pixel-space loss lazily if __init__ took the
        EMAParametersDict branch and never built one.
        """
        logger.debug(type(pil_image))
        self.pil_image = pil_image
        self.has_latent = False
        # NOTE(review): im_resized is computed but never used — the direct
        # loss below is fed the UN-resized image tensor instead. Confirm
        # which target size is intended.
        im_resized = pil_image.resize(
            self.image_shape, Image.LANCZOS
        )  # to do: ResizeRight
        # self.direct_loss.set_comp(im_resized)

        im_tensor = (
            TF.to_tensor(pil_image)
            .unsqueeze(0)
            .to(device, memory_format=torch.channels_last)
        )

        if hasattr(self, "direct_loss"):
            self.direct_loss.set_comp(im_tensor)
        else:
            # direct_loss does not exist yet when __init__ received a latent
            # comp — build it here from the current target image.
            self.direct_loss = MSELoss(
                im_tensor, self.weight, self.stop, self.name, self.image_shape
            )
        # self.direct_loss.set_comp(im_resized)

    @classmethod
    def convert_input(cls, input, img):
        """
        Converts the input image tensor to the image representation of the image model.
        E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
        """
        logger.debug(type(input))  # pretty sure this is gonna be tensor
        # return input # this is the default MSE loss version
        return img.make_latent(input)

    @classmethod
    @vram_usage_mode("Latent Image Loss")
    @torch.no_grad()
    def TargetImage(
        cls,
        prompt_string,
        image_shape,
        pil_image=None,
        is_path=False,
        device=DEVICE,
        img_model=None,
    ):
        """
        Alternate constructor: parse a pytti prompt string of the form
        "text:weight_mask:stop", optionally fetch the target image from a
        path/URL carried in the text, and build the loss instance.
        """
        logger.debug(
            type(pil_image)
        )  # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image?
        text, weight, stop = parse(
            prompt_string, r"(?<!^http)(?<!s):|:(?!/)", ["", "1", "-inf"]
        )
        weight, mask = parse(weight, r"_", ["1", ""])
        text = text.strip()
        mask = mask.strip()
        if pil_image is None and text != "" and is_path:
            pil_image = Image.open(fetch(text)).convert("RGB")
        comp = (
            MSELoss.make_comp(pil_image)
            if pil_image is not None
            # else torch.zeros(1, 1, 1, 1, device=device)
            else cls.default_comp(img_model=img_model)
        )
        out = cls(comp, weight, stop, text + " (latent)", image_shape)
        if pil_image is not None:
            out.set_comp(pil_image)
        if (
            mask
        ):  # this will break if there's no pil_image since the direct_loss won't be initialized
            out.set_mask(mask)
        return out

    def set_mask(self, mask, inverted=False):
        """
        Apply a weighting mask to both the direct loss and this loss's mask
        buffer. Accepts a path/URL string (with optional leading "-" for
        inversion, or a .mp4 handled by the rotoscoper), a PIL image, or a
        tensor.
        """
        # NOTE(review): the raw (possibly string-form) mask is forwarded to
        # the direct loss BEFORE the local parsing below — confirm
        # MSELoss.set_mask accepts the same string forms.
        self.direct_loss.set_mask(mask, inverted)
        # super().set_mask(mask, inverted)
        # if device is None:
        device = self.device
        if isinstance(mask, str) and mask != "":
            if mask[0] == "-":
                # Leading "-" means "use the inverse of this mask".
                mask = mask[1:]
                inverted = True
            if mask.strip()[-4:] == ".mp4":
                # Video masks are applied frame-by-frame by the rotoscoper;
                # nothing more to do synchronously.
                r = Rotoscoper(mask, self)
                r.update(0)
                return
            mask = Image.open(fetch(mask)).convert("L")
        if isinstance(mask, Image.Image):
            with vram_usage_mode("Masks"):
                mask = (
                    TF.to_tensor(mask)
                    .unsqueeze(0)
                    .to(device, memory_format=torch.channels_last)
                )
        if mask not in ["", None]:
            self.mask.set_(mask if not inverted else (1 - mask))
        self.use_mask = mask not in ["", None]

    def get_loss(self, input, img):
        """
        Combined loss: latent-space term (weight 1/2) plus pixel-space direct
        term (weight 1/10). Requires the latent comp to already be set
        (has_latent True).
        """
        logger.debug(type(input))  # Tensor
        logger.debug(input.shape)  # this is an image tensor
        logger.debug(type(img))  # DIPImage
        logger.debug(type(self.comp))  # EMAParametersDict
        # logger.debug(
        #     self.comp.shape
        # ) # [1 1 1 1] -> from target image constructor when no input image provided

        # why is the latent comp only set here? why not in the __init__ and set_comp?
        if not self.has_latent:
            # NOTE(review): bare `raise` outside an except block raises
            # RuntimeError, so everything below it in this branch is
            # unreachable dead code retained from the LatentLoss version.
            raise
            # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
            # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
            latent = img.make_latent(self.pil_image)
            logger.debug(type(latent))  # EMAParametersDict
            logger.debug(type(self.comp))  # torch.Tensor
            with torch.no_grad():
                if type(latent) == type(self.comp):
                    self.comp.set_(latent.clone())
                # else:

            self.has_latent = True

        # NOTE(review): estimated_image is computed but never used; remove
        # once it is confirmed get_image_tensor() has no needed side effects.
        estimated_image = self.comp.get_image_tensor()

        l1 = super().get_loss(img.get_latent_tensor(), img) / 2
        l2 = self.direct_loss.get_loss(input, img) / 10
        return l1 + l2
Loading