From e59295a0d91e6c31dd9b8ac1b426ec16f590d10b Mon Sep 17 00:00:00 2001 From: Bardia Shahrestani Date: Mon, 20 Jun 2022 03:52:21 -0400 Subject: [PATCH 01/29] Fixed broken backup remover Replaced subprocess.run with python's native os.remove() --- src/pytti/update_func.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/pytti/update_func.py b/src/pytti/update_func.py index 1d3606e..95cd4f8 100644 --- a/src/pytti/update_func.py +++ b/src/pytti/update_func.py @@ -142,12 +142,7 @@ def save_out( # if we can't deprecate the folder removal entirely. What # is the purpose of "backups" here? Just use the frames that # are being written to disk. - subprocess.run( - [ - "rm", - f"backup/{file_namespace}/{base_name}_{n-backups}.bak", - ] - ) + os.remove(f"backup/{file_namespace}/{base_name}_{n-backups}.bak") j = i + 1 From 238ccbf16adea76450a74a730e33de9a9f8fc217 Mon Sep 17 00:00:00 2001 From: Bardia Shahrestani Date: Mon, 20 Jun 2022 04:40:26 -0400 Subject: [PATCH 02/29] Made backup removal path cross-platform replaced hardcoded path with os.path.join(backup_path) for automatic backup removal --- src/pytti/update_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytti/update_func.py b/src/pytti/update_func.py index 95cd4f8..a411519 100644 --- a/src/pytti/update_func.py +++ b/src/pytti/update_func.py @@ -142,7 +142,7 @@ def save_out( # if we can't deprecate the folder removal entirely. What # is the purpose of "backups" here? Just use the frames that # are being written to disk. - os.remove(f"backup/{file_namespace}/{base_name}_{n-backups}.bak") + os.remove(os.path.join("backup", f"{file_namespace}",f"{base_name}_{n-backups}.bak")) j = i + 1 From ade92abbb3a3875a77ff0dd207688279076c40cc Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 20 Jun 2022 10:50:18 -0700 Subject: [PATCH 03/29] modified backup removal to use pathlib --- src/pytti/update_func.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/pytti/update_func.py b/src/pytti/update_func.py index a411519..b90a98f 100644 --- a/src/pytti/update_func.py +++ b/src/pytti/update_func.py @@ -135,14 +135,11 @@ def save_out( filename = f"backup/{file_namespace}/{base_name}_{n}.bak" torch.save(img.state_dict(), filename) if n > backups: - - # YOOOOOOO let's not start shell processes unnecessarily - # and then execute commands using string interpolation. - # Replace this with a pythonic folder removal, then see - # if we can't deprecate the folder removal entirely. What - # is the purpose of "backups" here? Just use the frames that - # are being written to disk. - os.remove(os.path.join("backup", f"{file_namespace}",f"{base_name}_{n-backups}.bak")) + fname = f"{base_name}_{n-backups}.bak" + fpath = Path("backup") / file_namespace / fname + fpath.unlink( + missing_ok=True + ) # delete the file. if file not found, nothing happens. j = i + 1 From 24cceeeec8d370e5c5c77e446f05fba5f8dc3a91 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 18:35:35 -0700 Subject: [PATCH 04/29] added proxy property for improved ema --- src/pytti/image_models/ema.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index ea04791..d923cf6 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -45,3 +45,7 @@ def decode_tensor(self): def decode(self, tensor): raise NotImplementedError + + @property + def representation_parameters(self): + return self.tensor From a96d0811bd4178bfafae75a27774090d4f972791 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 18:38:16 -0700 Subject: [PATCH 05/29] migrated vqgan image to proxy --- src/pytti/image_models/vqgan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 29210cc..4e8473b 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -188,7 +188,7 @@ def __init__( def clone(self): dummy = VQGANImage(*self.image_shape) with torch.no_grad(): - dummy.tensor.set_(self.tensor.clone()) + dummy.representation_parameters.set_(self.representation_parameters.clone()) dummy.accum.set_(self.accum.clone()) dummy.biased.set_(self.biased.clone()) dummy.average.set_(self.average.clone()) @@ -198,7 +198,7 @@ def clone(self): def get_latent_tensor(self, detach=False, device=None): if device is None: device = self.device - z = self.tensor + z = self.representation_parameters if detach: z = z.detach() z_q = vector_quantize(z, self.vqgan_quantize_embedding).movedim(3, 1).to(device) @@ -226,7 +226,7 @@ def encode_image(self, pil_image, device=None, **kwargs): pil_image = pil_image.resize(self.image_shape, Image.LANCZOS) pil_image = TF.to_tensor(pil_image) z, *_ = self.vqgan_encode(pil_image.unsqueeze(0).to(device) * 2 - 1) - self.tensor.set_(z.movedim(1, 3)) + self.representation_parameters.set_(z.movedim(1, 3)) self.reset() @torch.no_grad() @@ -245,7 +245,7 @@ def make_latent(self, pil_image, device=None): @torch.no_grad() def encode_random(self): - self.tensor.set_(self.rand_latent()) + self.representation_parameters.set_(self.rand_latent()) self.reset() def rand_latent(self, device=None, vqgan_quantize_embedding=None): From 08bee596bd62618288f117f52227b79000091c91 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 18:47:10 -0700 Subject: [PATCH 06/29] pushed proxy property down to base class --- src/pytti/image_models/differentiable_image.py | 7 +++++++ src/pytti/image_models/ema.py | 4 ---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 79ea590..7159c65 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -120,3 +120,10 @@ def forward(self): return self.decode_training_tensor() else: return self.decode_tensor() + + @property + def representation_parameters(self): + # yeah I should really make this class an ABC + if not hasattr(self, "representation_parameters"): + raise NotImplementedError + return self.tensor diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index d923cf6..ea04791 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -45,7 +45,3 @@ def decode_tensor(self): def decode(self, tensor): raise NotImplementedError - - @property - def representation_parameters(self): - return self.tensor From 0a75d8611ef5182a904c854e585ebbc36148dfb8 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 20:06:24 -0700 Subject: [PATCH 07/29] let's be honest: probably just making things worse here. pretty sure this is unnecessarily complicated. --- .../image_models/differentiable_image.py | 55 +++++++++++++++++-- src/pytti/image_models/ema.py | 6 +- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 7159c65..9bc3a44 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -1,4 +1,6 @@ import copy + +import torch from torch import nn import numpy as np from PIL import Image @@ -24,6 +26,10 @@ def __init__(self, width: int, height: int, pixel_format: str = "RGB"): self.output_axes = ("x", "y", "s") self.lr = 0.02 self.latent_strength = 0 + self.image_representation_parameters = ImageRepresentationalParameters( + width, height + ) + # self.tensor = self.image_representation_parameters._new() def decode_training_tensor(self): """ @@ -123,7 +129,48 @@ def forward(self): @property def representation_parameters(self): - # yeah I should really make this class an ABC - if not hasattr(self, "representation_parameters"): - raise NotImplementedError - return self.tensor + return self.image_representation_parameters._container + + ## yeah I should really make this class an ABC + # if not hasattr(self, "representation_parameters"): + # raise NotImplementedError + # return self.tensor + + +class ImageRepresentationalParameters(nn.Module): + """ + Base class for defining parameters of differentiable images + width: (positive integer) image width in pixels + height: (positive integer) image height in pixels + pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' + """ + + def __init__(self, width: int, height: int): + super().__init__() + self.image_shape = (width, height) + self._container = self._new() + + def _new(self): + width, height = self.image_shape + return nn.Parameter( + torch.zeros(1, 3, height, width).to( + device=self.device, memory_format=torch.channels_last + ) + ) + # self.tensor = ImageRepresentationalParameters(width, height, pixel_format) + + ################ + def mul_(): + pass + + def add_(): + pass + + def copy(): + pass + + def div_(): + pass + + def set_(): + pass diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index ea04791..9c92c10 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -11,7 +11,7 @@ class EMAImage(DifferentiableImage): def __init__(self, width, height, tensor, decay): super().__init__(width, height) - self.tensor = nn.Parameter(tensor) + # self.representation_parameters = nn.Parameter(tensor) self.register_buffer("biased", torch.zeros_like(tensor)) self.register_buffer("average", torch.zeros_like(tensor)) self.decay = decay @@ -24,7 +24,7 @@ def update(self): raise RuntimeError("update() should only be called during training") self.accum.mul_(self.decay) self.biased.mul_(self.decay) - self.biased.add_((1 - self.decay) * self.tensor) + self.biased.add_((1 - self.decay) * self.representation_parameters) self.average.copy_(self.biased) self.average.div_(1 - self.accum) @@ -38,7 +38,7 @@ def reset(self): self.update() def decode_training_tensor(self): - return self.decode(self.tensor) + return self.decode(self.representation_parameters) def decode_tensor(self): return self.decode(self.average) From 47f7304a7d56a140e77246fc7ffcfef0db531619 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 20:25:01 -0700 Subject: [PATCH 08/29] rgb and pixel models working, vqgan still broken --- src/pytti/image_models/differentiable_image.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 9bc3a44..d084fe2 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -17,7 +17,7 @@ class DifferentiableImage(nn.Module): pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' """ - def __init__(self, width: int, height: int, pixel_format: str = "RGB"): + def __init__(self, width: int, height: int, pixel_format: str = "RGB", device=None): super().__init__() if pixel_format not in SUPPORTED_MODES: raise ValueError(f"Pixel format {pixel_format} is not supported.") @@ -29,6 +29,8 @@ def __init__(self, width: int, height: int, pixel_format: str = "RGB"): self.image_representation_parameters = ImageRepresentationalParameters( width, height ) + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.tensor = self.image_representation_parameters._new() def decode_training_tensor(self): @@ -145,9 +147,12 @@ class ImageRepresentationalParameters(nn.Module): pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' """ - def __init__(self, width: int, height: int): + def __init__(self, width: int, height: int, device=None): super().__init__() self.image_shape = (width, height) + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device self._container = self._new() def _new(self): From d0827f59eb04a7e37309dd9334b02d9d96db6dc3 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 23:01:16 -0700 Subject: [PATCH 09/29] getting closer --- .../image_models/differentiable_image.py | 59 +++++++-------- src/pytti/image_models/ema.py | 72 ++++++++++++++++++- src/pytti/image_models/vqgan.py | 10 +++ 3 files changed, 105 insertions(+), 36 deletions(-) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index d084fe2..95b5f08 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -6,28 +6,22 @@ from PIL import Image from pytti.tensor_tools import named_rearrange -SUPPORTED_MODES = ["L", "RGB", "I", "F"] - class DifferentiableImage(nn.Module): """ Base class for defining differentiable images width: (positive integer) image width in pixels height: (positive integer) image height in pixels - pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' """ - def __init__(self, width: int, height: int, pixel_format: str = "RGB", device=None): + def __init__(self, width: int, height: int, device=None): super().__init__() - if pixel_format not in SUPPORTED_MODES: - raise ValueError(f"Pixel format {pixel_format} is not supported.") self.image_shape = (width, height) - self.pixel_format = format self.output_axes = ("x", "y", "s") self.lr = 0.02 self.latent_strength = 0 self.image_representation_parameters = ImageRepresentationalParameters( - width, height + width=width, height=height ) if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -144,38 +138,35 @@ class ImageRepresentationalParameters(nn.Module): Base class for defining parameters of differentiable images width: (positive integer) image width in pixels height: (positive integer) image height in pixels - pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' """ - def __init__(self, width: int, height: int, device=None): + def __init__(self, width: int, height: int, z=None, device=None): super().__init__() - self.image_shape = (width, height) + self.width = width + self.height = height if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = device - self._container = self._new() + self._container = self._new(z) - def _new(self): - width, height = self.image_shape - return nn.Parameter( - torch.zeros(1, 3, height, width).to( + def _new(self, z=None): + if z is None: + # I think this can all go in the constructor and doesn't need to call .to() + z = torch.zeros(1, 3, self.height, self.width).to( device=self.device, memory_format=torch.channels_last ) - ) - # self.tensor = ImageRepresentationalParameters(width, height, pixel_format) - - ################ - def mul_(): - pass - - def add_(): - pass - - def copy(): - pass - - def div_(): - pass - - def set_(): - pass + return nn.Parameter(z) + + +class LatentTensor(ImageRepresentationalParameters): + pass + # def __init__(self, z, device=None): + # super().__init__(z.shape[1], z.shape[2], device=device) + # #self._container = z + # self._z = z + # def _new(self): + # return nn.Parameter( + # torch.zeros(1, 3, height, width).to( + # device=self.device, memory_format=torch.channels_last + # ) + # ) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 9c92c10..89f1cae 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -1,17 +1,85 @@ import torch from torch import nn -from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.differentiable_image import ( + DifferentiableImage, + ImageRepresentationalParameters, +) + + +class EMATensor(nn.Module): + """implmeneted by Katherine Crowson""" + + def __init__(self, tensor, decay): + super().__init__() + self.tensor = nn.Parameter(tensor) + self.register_buffer("biased", torch.zeros_like(tensor)) + self.register_buffer("average", torch.zeros_like(tensor)) + self.decay = decay + self.register_buffer("accum", torch.tensor(1.0)) + self.update() + + @torch.no_grad() + def update(self): + if not self.training: + raise RuntimeError("update() should only be called during training") + + self.accum *= self.decay + self.biased.mul_(self.decay) + self.biased.add_((1 - self.decay) * self.tensor) + self.average.copy_(self.biased) + self.average.div_(1 - self.accum) + + def forward(self): + if self.training: + return self.tensor + return self.average + + +class EMAParametersDict(ImageRepresentationalParameters): + """ + LatentTensor with a singleton dimension for the EMAParameters + """ + + def __init__(self, z=None, decay=0.99, device=None): + super(ImageRepresentationalParameters).__init__() + self.decay = decay + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device + self._container = self._new(z) + + def _new(self, z=None): + if z is None: + # I think this can all go in the constructor and doesn't need to call .to() + z = torch.zeros(1, 3, self.height, self.width).to( + device=self.device, memory_format=torch.channels_last + ) + d_ = z + if not isinstance(z, dict): + if hasattr(z, "named_parameters"): + d_ = {name: EMATensor(param) for name, param in z.named_parameters()} + return d_ class EMAImage(DifferentiableImage): + def __init__(self, width, height, tensor, decay, device=None): + super().__init__(width=width, height=height, device=device) + self.image_representation_parameters = EMAParametersDict( + z=tensor, decay=decay, device=device + ) + + +class EMAImage_old(DifferentiableImage): """ Base class for differentiable images with Exponential Moving Average filtering Based on code by Katherine Crowson """ def __init__(self, width, height, tensor, decay): - super().__init__(width, height) + # super().__init__(width, height) + super().__init__(width=width, height=height, z=tensor) # self.representation_parameters = nn.Parameter(tensor) + # self.image_representation_parameters._container = nn.Parameter(tensor) self.register_buffer("biased", torch.zeros_like(tensor)) self.register_buffer("average", torch.zeros_like(tensor)) self.decay = decay diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 4e8473b..fb90dc6 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -12,6 +12,7 @@ import torch from torch.nn import functional as F from pytti.image_models import EMAImage +from pytti.image_models.differentiable_image import LatentTensor from torchvision.transforms import functional as TF from PIL import Image from omegaconf import OmegaConf @@ -185,6 +186,15 @@ def __init__( self.vqgan_decode = model.decode self.vqgan_encode = model.encode + ################################# + + self.image_representation_parameters = LatentTensor( + width=width, + height=height, + z=z, + device=self.device, + ) + def clone(self): dummy = VQGANImage(*self.image_shape) with torch.no_grad(): From 41897656783076d8ecfeb48a9a6f50d082a6c3b5 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 23:27:08 -0700 Subject: [PATCH 10/29] almost there --- .../image_models/differentiable_image.py | 24 ++++----- src/pytti/image_models/ema.py | 12 +++++ src/pytti/image_models/vqgan.py | 25 +++++++--- tests/image_models/test_im_vqgan.py | 49 +++++++++++++++++++ 4 files changed, 90 insertions(+), 20 deletions(-) create mode 100644 tests/image_models/test_im_vqgan.py diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 95b5f08..674130d 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -158,15 +158,15 @@ def _new(self, z=None): return nn.Parameter(z) -class LatentTensor(ImageRepresentationalParameters): - pass - # def __init__(self, z, device=None): - # super().__init__(z.shape[1], z.shape[2], device=device) - # #self._container = z - # self._z = z - # def _new(self): - # return nn.Parameter( - # torch.zeros(1, 3, height, width).to( - # device=self.device, memory_format=torch.channels_last - # ) - # ) +# class LatentTensor(ImageRepresentationalParameters): +# pass +# def __init__(self, z, device=None): +# super().__init__(z.shape[1], z.shape[2], device=device) +# #self._container = z +# self._z = z +# def _new(self): +# return nn.Parameter( +# torch.zeros(1, 3, height, width).to( +# device=self.device, memory_format=torch.channels_last +# ) +# ) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 89f1cae..7ffff0b 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -60,6 +60,14 @@ def _new(self, z=None): d_ = {name: EMATensor(param) for name, param in z.named_parameters()} return d_ + def clone(self): + d_ = {k: v.clone() for k, v in self._container.items()} + return EMAParametersDict(z=d_, decay=self.decay, device=self.device) + + def update(self): + for param in self._container.values(): + param.update() + class EMAImage(DifferentiableImage): def __init__(self, width, height, tensor, decay, device=None): @@ -69,6 +77,10 @@ def __init__(self, width, height, tensor, decay, device=None): ) +class LatentTensor(EMAImage): + pass + + class EMAImage_old(DifferentiableImage): """ Base class for differentiable images with Exponential Moving Average filtering diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index fb90dc6..b541cc3 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -11,8 +11,11 @@ from pytti import replace_grad, clamp_with_grad, vram_usage_mode import torch from torch.nn import functional as F -from pytti.image_models import EMAImage -from pytti.image_models.differentiable_image import LatentTensor + +# from pytti.image_models import EMAImage +# from pytti.image_models.differentiable_image import LatentTensor +# from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage from torchvision.transforms import functional as TF from PIL import Image from omegaconf import OmegaConf @@ -188,12 +191,12 @@ def __init__( ################################# - self.image_representation_parameters = LatentTensor( - width=width, - height=height, - z=z, - device=self.device, - ) + # self.image_representation_parameters = LatentTensor( + # width=width, + # height=height, + # z=z, + # device=self.device, + # ) def clone(self): dummy = VQGANImage(*self.image_shape) @@ -220,6 +223,12 @@ def get_preferred_loss(cls): return LatentLoss + def decode_training_tensor(self): + return self.decode(self.representation_parameters) + + def decode_tensor(self): + return self.decode(self.average) + def decode(self, z, device=None): if device is None: device = self.device diff --git a/tests/image_models/test_im_vqgan.py b/tests/image_models/test_im_vqgan.py new file mode 100644 index 0000000..332dd77 --- /dev/null +++ b/tests/image_models/test_im_vqgan.py @@ -0,0 +1,49 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.vqgan import VQGANImage + +from pathlib import Path + + +def test_init_model(): + models_parent_dir = "~/.cache/vqgan/" + vqgan_model = "coco" + model_artifacts_path = Path(models_parent_dir) / "vqgan" + VQGANImage.init_vqgan(vqgan_model, model_artifacts_path) + img = VQGANImage(512, 512, 1) + # img.encode_random() + + +def test_init(): + obj = VQGANImage(512, 512) + assert obj + + +def test_update(): + obj = VQGANImage(512, 512) + obj.update() + + +def test_forward(): + obj = VQGANImage(512, 512) + obj.forward() + + +def test_decode_training_tensor(): + obj = VQGANImage(512, 512) + obj.decode_training_tensor() + + +def test_decode_tensor(): + obj = VQGANImage(512, 512) + obj.decode_tensor() + + +def test_clone(): + obj = VQGANImage(512, 512) + obj.clone() + + +def test_get_latent_tensor(): + obj = VQGANImage(512, 512) + obj.get_latent_tensor() From 8f195be05553a39b2e31d6ec5c03e4a11babd910 Mon Sep 17 00:00:00 2001 From: David Marx Date: Sun, 12 Jun 2022 23:32:58 -0700 Subject: [PATCH 11/29] one last coverage test to pass --- src/pytti/image_models/vqgan.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index b541cc3..06e223b 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -199,13 +199,17 @@ def __init__( # ) def clone(self): + # dummy = VQGANImage(*self.image_shape) + # with torch.no_grad(): + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + # dummy.accum.set_(self.accum.clone()) + # dummy.biased.set_(self.biased.clone()) + # dummy.average.set_(self.average.clone()) + # dummy.decay = self.decay + # return dummy dummy = VQGANImage(*self.image_shape) with torch.no_grad(): dummy.representation_parameters.set_(self.representation_parameters.clone()) - dummy.accum.set_(self.accum.clone()) - dummy.biased.set_(self.biased.clone()) - dummy.average.set_(self.average.clone()) - dummy.decay = self.decay return dummy def get_latent_tensor(self, detach=False, device=None): From 866dab492efd84239b81c53ce0dd679b6535186e Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 01:19:20 -0700 Subject: [PATCH 12/29] almost there.... --- src/pytti/image_models/ema.py | 71 +++++++++++++++++++++++++++++++-- src/pytti/image_models/vqgan.py | 15 ++++++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 7ffff0b..242e16c 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -34,6 +34,20 @@ def forward(self): return self.tensor return self.average + def clone(self): + new = EMATensor(self.tensor.clone(), self.decay) + new.accum.copy_(self.accum) + new.biased.copy_(self.biased) + new.average.copy_(self.average) + return new + + def set_(self, other): + self.tensor.set_(other.tensor) + self.accum.set_(other.accum) + self.biased.set_(other.biased) + self.average.set_(other.average) + # self.update() + class EMAParametersDict(ImageRepresentationalParameters): """ @@ -54,13 +68,49 @@ def _new(self, z=None): z = torch.zeros(1, 3, self.height, self.width).to( device=self.device, memory_format=torch.channels_last ) - d_ = z - if not isinstance(z, dict): - if hasattr(z, "named_parameters"): - d_ = {name: EMATensor(param) for name, param in z.named_parameters()} + # d_ = z + d_ = {} + if isinstance(z, EMAParametersDict): + for k, v in z.items(): + d_[k] = EMATensor(v.tensor, self.decay) + elif isinstance(z, EMATensor): + d_["z"] = z.clone() + elif hasattr(z, "named_parameters"): + d_ = { + name: EMATensor(param, decay=self.decay) + for name, param in z.named_parameters() + } + elif isinstance(z, dict): + print(type(z)) + for k, v in z.items(): + # print(k) + # print(type(v)) + # d_[k] = EMATensor(v, self.decay) + # d_[k] = EMATensor(v.clone(), self.decay) + if isinstance(v, EMATensor): + d_[k] = v.clone() + elif isinstance(z, torch.Tensor): + d_["z"] = EMATensor(z, decay=self.decay) + else: + raise ValueError( + "z must be a dict, torch.nn.Module, EMATensor, EMAParametersDict, or torch.Tensor" + ) + + # if not isinstance(z, dict): + ##if hasattr(z, "named_parameters"): + ## d_ = {name: EMATensor(param, decay=self.decay) for name, param in z.named_parameters()} + # else: + # d_['z'] = EMATensor(z, decay=self.decay) + # else: + # #d_ = {name: EMATensor(param, self.decay) for name, param in z.items()} + # try: + # d_ = {name: EMATensor(param.data, self.decay) for name, param in z.items()} + # except AttributeError: + # d_ = {name: EMATensor(param, self.decay) for name, param in z.items()} return d_ def clone(self): + # d_ = {k: v.clone() for k, v in self._container.items()} d_ = {k: v.clone() for k, v in self._container.items()} return EMAParametersDict(z=d_, decay=self.decay, device=self.device) @@ -68,6 +118,19 @@ def update(self): for param in self._container.values(): param.update() + def average(self): + return {k: v.average() for k, v in self._container.items()} + + def set_(self, d): + d_ = d + if isinstance(d, EMAParametersDict): + d_ = d._container + for k, v in d_.items(): + self._container[k].set_(v) + # self._container[k].tensor.set_(v) + # self._container[k].tensor.set_(v.tensor) + # self._container[k].tensor.set_(v.tensor) + class EMAImage(DifferentiableImage): def __init__(self, width, height, tensor, decay, device=None): diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 06e223b..8c18adf 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -209,12 +209,21 @@ def clone(self): # return dummy dummy = VQGANImage(*self.image_shape) with torch.no_grad(): - dummy.representation_parameters.set_(self.representation_parameters.clone()) + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + dummy.image_representation_parameters.set_( + self.image_representation_parameters.clone() + ) return dummy + @property + def representation_parameters(self): + return self.image_representation_parameters._container.get("z").tensor + def get_latent_tensor(self, detach=False, device=None): if device is None: device = self.device + # z = self.representation_parameters._container.get("z") + # z = self.image_representation_parameters._container.get("z").tensor z = self.representation_parameters if detach: z = z.detach() @@ -231,7 +240,9 @@ def decode_training_tensor(self): return self.decode(self.representation_parameters) def decode_tensor(self): - return self.decode(self.average) + # return self.decode(self.average) + # return self.decode(self.representation_parameters.average) + return self.decode(self.image_representation_parameters.average) def decode(self, z, device=None): if device is None: From 1398493dfa28af9760908159eb707471b7d928ac Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 01:31:23 -0700 Subject: [PATCH 13/29] passes coverage --- src/pytti/image_models/ema.py | 3 ++- src/pytti/image_models/vqgan.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 242e16c..d062254 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -118,8 +118,9 @@ def update(self): for param in self._container.values(): param.update() + @property def average(self): - return {k: v.average() for k, v in self._container.items()} + return {k: v.average for k, v in self._container.items()} def set_(self, d): d_ = d diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 8c18adf..92a624f 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -242,7 +242,8 @@ def decode_training_tensor(self): def decode_tensor(self): # return self.decode(self.average) # return self.decode(self.representation_parameters.average) - return self.decode(self.image_representation_parameters.average) + # return self.decode(self.image_representation_parameters.average) + return self.decode(self.image_representation_parameters.average["z"]) def decode(self, z, device=None): if device is None: From 9091891a644be6d4c8ac998fb94601123fd53f1f Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 01:49:02 -0700 Subject: [PATCH 14/29] vqgan tests passing --- src/pytti/image_models/ema.py | 13 +++++++++++++ src/pytti/image_models/vqgan.py | 6 ++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index d062254..5e09a4f 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -48,6 +48,15 @@ def set_(self, other): self.average.set_(other.average) # self.update() + @torch.no_grad() + def reset(self): + if not self.training: + raise RuntimeError("reset() should only be called during training") + self.biased.set_(torch.zeros_like(self.biased)) + self.average.set_(torch.zeros_like(self.average)) + self.accum.set_(torch.ones_like(self.accum)) + self.update() + class EMAParametersDict(ImageRepresentationalParameters): """ @@ -132,6 +141,10 @@ def set_(self, d): # self._container[k].tensor.set_(v.tensor) # self._container[k].tensor.set_(v.tensor) + def reset(self): + for param in self._container.values(): + param.reset() + class EMAImage(DifferentiableImage): def __init__(self, width, height, tensor, decay, device=None): diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 92a624f..11b2d8e 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -262,7 +262,8 @@ def encode_image(self, pil_image, device=None, **kwargs): pil_image = TF.to_tensor(pil_image) z, *_ = self.vqgan_encode(pil_image.unsqueeze(0).to(device) * 2 - 1) self.representation_parameters.set_(z.movedim(1, 3)) - self.reset() + # self.reset() + self.image_representation_parameters.reset() @torch.no_grad() def make_latent(self, pil_image, device=None): @@ -281,7 +282,8 @@ def make_latent(self, pil_image, device=None): @torch.no_grad() def encode_random(self): self.representation_parameters.set_(self.rand_latent()) - self.reset() + # self.reset() + self.image_representation_parameters.reset() def rand_latent(self, device=None, vqgan_quantize_embedding=None): if device is None: From 34c1a6f719e101c4f8b45a71849c5b6044762ec3 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 19:46:36 -0700 Subject: [PATCH 15/29] skeleton DIP image model and tests --- src/pytti/image_models/__init__.py | 1 + src/pytti/image_models/deep_image_prior.py | 122 +++++++++++++++++++++ src/pytti/image_models/rgb_image.py | 2 +- tests/test_image_models.py | 48 ++++++++ 4 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 src/pytti/image_models/deep_image_prior.py create mode 100644 tests/test_image_models.py diff --git a/src/pytti/image_models/__init__.py b/src/pytti/image_models/__init__.py index b47a344..547b44d 100644 --- a/src/pytti/image_models/__init__.py +++ b/src/pytti/image_models/__init__.py @@ -9,3 +9,4 @@ from .pixel import PixelImage from .rgb_image import RGBImage from .vqgan import VQGANImage +from .deep_image_prior import DeepImagePrior diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py new file mode 100644 index 0000000..da37a13 --- /dev/null +++ b/src/pytti/image_models/deep_image_prior.py @@ -0,0 +1,122 @@ +from copy import deepcopy + +from loguru import logger + +from pytti import clamp_with_grad +import torch +from torch import nn +from torchvision.transforms import functional as TF +from pytti.image_models import DifferentiableImage, EMAImage +from PIL import Image +from torch.nn import functional as F + +# scavenging code snippets from: +# - https://github.com/LAION-AI/notebooks/blob/main/DALLE2-Prior%2BDeep-Image-Prior.ipynb + +import deep_image_prior +import deep_image_prior.models +from deep_image_prior.models import ( + get_hq_skip_net, + get_non_offset_params, + get_offset_params, +) + +# from deep_image_prior import get_hq_skip_net, get_non_offset_params, get_offset_params + +# foo = deep_image_prior.models + + +def load_dip(input_depth, num_scales, offset_type, offset_groups, device): + dip_net = get_hq_skip_net( + input_depth, + skip_n33d=192, + skip_n33u=192, + skip_n11=4, + num_scales=num_scales, + offset_type=offset_type, + offset_groups=offset_groups, + ).to(device) + + return dip_net + + +# class DeepImagePrior(EMAImage): +class DeepImagePrior(DifferentiableImage): + """ + https://github.com/nousr/deep-image-prior/ + """ + + def __init__( + self, + width, + height, + scale=1, + ########### + input_depth=3, + num_scales=7, + offset_type="none", + # offset_groups=1, + disable_deformable_convolutions=False, + lr=1e-3, + offset_lr_fac=1.0, + ########### + device="cuda", + **kwargs, + ): + super().__init__(width * scale, height * scale) + self.net = load_dip( + input_depth=input_depth, + num_scales=num_scales, + offset_type=offset_type, + offset_groups=0 if disable_deformable_convolutions else 4, + device=device, + ) + # self.tensor = self.net.params() + self.output_axes = ("n", "s", "y", "x") + self.scale = scale + self.device = device + + self._net_input = torch.randn([1, input_depth, width, height], device=device) + + self.lr = lr + self.offset_lr_fac = offset_lr_fac + # self._params = [ + # {'params': get_non_offset_params(net), 'lr': lr}, + # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} + # ] + + def get_image_tensor(self): + with torch.cuda.amp.autocast(): + # out = net(net_input_noised * input_scale).float() + logger.debug(self.net) + logger.debug(self._net_input.shape) + out = self.net(self._net_input).float() + logger.debug(out.shape) + return out + + def get_latent_tensor(self, detach=False): + # pass + net = self.net + lr = self.lr + offset_lr_fac = self.offset_lr_fac + params = [ + {"params": get_non_offset_params(net), "lr": lr}, + {"params": get_offset_params(net), "lr": lr * offset_lr_fac}, + ] + # params = torch.cat( + # get_non_offset_params(net), + # get_offset_params(net) + # ) + return params + + def clone(self): + # dummy = super().__init__(*self.image_shape) + # with torch.no_grad(): + # #dummy.tensor.set_(self.tensor.clone()) + # dummy.net.copy_(self.net) + # dummy.accum.set_(self.accum.clone()) + # dummy.biased.set_(self.biased.clone()) + # dummy.average.set_(self.average.clone()) + # dummy.decay = self.decay + dummy = deepcopy(self) + return dummy diff --git a/src/pytti/image_models/rgb_image.py b/src/pytti/image_models/rgb_image.py index e6b39a6..cc02d22 100644 --- a/src/pytti/image_models/rgb_image.py +++ b/src/pytti/image_models/rgb_image.py @@ -6,7 +6,7 @@ from PIL import Image from torch.nn import functional as F - +# why doesn't this inherit from EMA? class RGBImage(DifferentiableImage): """ Naive RGB image representation diff --git a/tests/test_image_models.py b/tests/test_image_models.py new file mode 100644 index 0000000..2388f5c --- /dev/null +++ b/tests/test_image_models.py @@ -0,0 +1,48 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage +from pytti.image_models.deep_image_prior import DeepImagePrior + +### DIP ### + + +def test_dip_init(): + obj = DeepImagePrior(10, 10) + assert obj + + +def test_dip_update(): + obj = DeepImagePrior(10, 10) + obj.update() + + +def test_dip_forward(): + obj = DeepImagePrior(10, 10) + obj.forward() + + +def test_dip_decode_training_tensor(): + obj = DeepImagePrior(10, 10) + obj.decode_training_tensor() + + +def test_dip_decode_tensor(): + obj = DeepImagePrior(10, 10) + obj.decode_tensor() + + +def test_dip_clone(): + obj = DeepImagePrior(10, 10) + obj.clone() + + +def test_dip_get_latent_tensor(): + obj = DeepImagePrior(10, 10) + obj.get_latent_tensor() + + +def test_dip_get_image_tensor(): + obj = DeepImagePrior(512, 512) + obj.get_image_tensor() From c188a2ce86ecf1f78913161ab133656ab7b85bec Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 20:20:33 -0700 Subject: [PATCH 16/29] dip passes minimal unit tests --- src/pytti/image_models/deep_image_prior.py | 16 ++++++++++++++-- tests/test_image_models.py | 17 ++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index da37a13..c5c3166 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -60,17 +60,28 @@ def __init__( lr=1e-3, offset_lr_fac=1.0, ########### + ema_val=0.99, + ########### device="cuda", **kwargs, ): super().__init__(width * scale, height * scale) - self.net = load_dip( + net = load_dip( input_depth=input_depth, num_scales=num_scales, offset_type=offset_type, offset_groups=0 if disable_deformable_convolutions else 4, device=device, ) + # z = self.get_latent_tensor() + # params = [ + # {'params': get_non_offset_params(net), 'lr': lr}, + # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} + # ] + # z = torch.cat(get_non_offset_params(net), get_offset_params(net)) + # logger.debug(z.shape) + # super().__init__(width * scale, height * scale, z, ema_val) + self.net = net # self.tensor = self.net.params() self.output_axes = ("n", "s", "y", "x") self.scale = scale @@ -85,7 +96,8 @@ def __init__( # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} # ] - def get_image_tensor(self): + # def get_image_tensor(self): + def decode_tensor(self): with torch.cuda.amp.autocast(): # out = net(net_input_noised * input_scale).float() logger.debug(self.net) diff --git a/tests/test_image_models.py b/tests/test_image_models.py index 2388f5c..19725ce 100644 --- a/tests/test_image_models.py +++ b/tests/test_image_models.py @@ -9,40 +9,35 @@ def test_dip_init(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) assert obj def test_dip_update(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) obj.update() def test_dip_forward(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) obj.forward() def test_dip_decode_training_tensor(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) obj.decode_training_tensor() def test_dip_decode_tensor(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) obj.decode_tensor() def test_dip_clone(): - obj = DeepImagePrior(10, 10) + obj = DeepImagePrior(512, 512) obj.clone() def test_dip_get_latent_tensor(): obj = DeepImagePrior(10, 10) obj.get_latent_tensor() - - -def test_dip_get_image_tensor(): - obj = DeepImagePrior(512, 512) - obj.get_image_tensor() From be84527556f7d5dd32af0ee06f89d35d3e836ac0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 21:22:34 -0700 Subject: [PATCH 17/29] works with basic params, needs pixel_size=1 and reencode_each_frame=false --- src/pytti/ImageGuide.py | 3 +++ src/pytti/Perceptor/Embedder.py | 6 ++++++ src/pytti/image_models/deep_image_prior.py | 11 ++++++++++- src/pytti/workhorse.py | 6 +++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 36122ed..ee81b82 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -292,6 +292,9 @@ def train( if self.embedder is not None: for mb_i in range(gradient_accumulation_steps): # logger.debug(mb_i) + # logger.debug(self.image_rep.shape) + logger.debug(type(self.image_rep)) + logger.debug(z.shape) image_embeds, offsets, sizes = self.embedder(self.image_rep, input=z) t = 1 diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index 74a4678..f3c329e 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -1,5 +1,7 @@ from typing import Tuple +from loguru import logger + import pytti from pytti import format_input, cat_with_pad, format_module, normalize @@ -131,6 +133,10 @@ def forward( mode=PADDING_MODES[self.border_mode], ) for cut_size, perceptor in zip(self.cut_sizes, perceptors): + logger.debug(f"cut_size: {cut_size}") # 224 + logger.debug(input.shape) # 1, 3, 512, 512 + logger.debug(side_x) # 2048 + logger.debug(side_y) # 2048 cutouts, offsets, sizes = self.make_cutouts(input, side_x, side_y, cut_size) clip_in = normalize(cutouts) image_embeds.append(perceptor.encode_image(clip_in).float().unsqueeze(0)) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index c5c3166..d84b8d7 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -100,7 +100,7 @@ def __init__( def decode_tensor(self): with torch.cuda.amp.autocast(): # out = net(net_input_noised * input_scale).float() - logger.debug(self.net) + # logger.debug(self.net) logger.debug(self._net_input.shape) out = self.net(self._net_input).float() logger.debug(out.shape) @@ -132,3 +132,12 @@ def clone(self): # dummy.decay = self.decay dummy = deepcopy(self) return dummy + + def encode_random(self): + pass + + @classmethod + def get_preferred_loss(cls): + from pytti.LossAug.LatentLossClass import LatentLoss + + return LatentLoss diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index d579831..e82ccb8 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -38,7 +38,7 @@ ) from pytti.rotoscoper import ROTOSCOPERS, get_frames -from pytti.image_models import PixelImage, RGBImage, VQGANImage +from pytti.image_models import PixelImage, RGBImage, VQGANImage, DeepImagePrior from pytti.ImageGuide import DirectImageGuide from pytti.Perceptor.Embedder import HDMultiClipEmbedder from pytti.Perceptor.Prompt import parse_prompt @@ -316,6 +316,7 @@ def do_run(): "Limited Palette", "Unlimited Palette", "VQGAN", + "Deep Image Prior", ) # set up image if params.image_model == "Limited Palette": @@ -349,6 +350,9 @@ def do_run(): params.width, params.height, params.pixel_size, device=_device ) img.encode_random() + elif params.image_model == "Deep Image Prior": + img = DeepImagePrior(params.width, params.height, params.pixel_size) + img.encode_random() else: logger.critical( "You should never see this message." From 63e9b4fa78f2f37011aca0566aad9ac120db1aad Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 21:45:14 -0700 Subject: [PATCH 18/29] added encode_image for DIP --- src/pytti/image_models/deep_image_prior.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index d84b8d7..31f1d2c 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -2,6 +2,8 @@ from loguru import logger +from torch import optim + from pytti import clamp_with_grad import torch from torch import nn @@ -10,6 +12,8 @@ from PIL import Image from torch.nn import functional as F +from pytti.LossAug.MSELossClass import MSELoss + # scavenging code snippets from: # - https://github.com/LAION-AI/notebooks/blob/main/DALLE2-Prior%2BDeep-Image-Prior.ipynb @@ -141,3 +145,25 @@ def get_preferred_loss(cls): from pytti.LossAug.LatentLossClass import LatentLoss return LatentLoss + + def encode_image(self, pil_image, device="cuda"): + """ + Encodes the image into a tensor. + + :param pil_image: The image to encode + :param smart_encode: If True, the pallet will be optimized to match the image, defaults to True + (optional) + :param device: The device to run the model on + """ + width, height = self.image_shape + scale = self.scale + + mse = MSELoss.TargetImage("HSV loss", self.image_shape, pil_image) + + from pytti.ImageGuide import DirectImageGuide + + guide = DirectImageGuide( + self, None, optimizer=optim.Adam(self.get_latent_tensor()) + ) + # why is there a magic number here? + guide.run_steps(201, [], [], [mse]) From f89fe617e5ff3736e170454940416a9ac2bc56fe Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 21:56:28 -0700 Subject: [PATCH 19/29] fixed DIP pixel_size scaling. --- src/pytti/image_models/deep_image_prior.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 31f1d2c..0867643 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -108,7 +108,10 @@ def decode_tensor(self): logger.debug(self._net_input.shape) out = self.net(self._net_input).float() logger.debug(out.shape) - return out + width, height = self.image_shape + out = F.interpolate(out, (height, width), mode="nearest") + return clamp_with_grad(out, 0, 1) + # return out def get_latent_tensor(self, detach=False): # pass From c3a3b47280071cb5557868cc6c687f9f42dd27c0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 30 May 2022 22:23:02 -0700 Subject: [PATCH 20/29] added 1/10 offset params lr scaling on DIP by default. --- src/pytti/image_models/deep_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 0867643..1ccac3d 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -62,7 +62,7 @@ def __init__( # offset_groups=1, disable_deformable_convolutions=False, lr=1e-3, - offset_lr_fac=1.0, + offset_lr_fac=0.1, # 1.0, ########### ema_val=0.99, ########### From c2e89842f82b33e2e9a0c1a9cb641085c5d9cb4c Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 31 May 2022 10:17:30 -0700 Subject: [PATCH 21/29] fixed default dip input depth --- src/pytti/image_models/deep_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 1ccac3d..68a1abe 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -56,7 +56,7 @@ def __init__( height, scale=1, ########### - input_depth=3, + input_depth=32, num_scales=7, offset_type="none", # offset_groups=1, From 0d4e555d54599e44a0842b9b023f08ccfab3eed1 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 10 Jun 2022 09:01:21 -0700 Subject: [PATCH 22/29] added dip tests --- tests/image_models/test_deep_image_prior.py | 43 +++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/image_models/test_deep_image_prior.py diff --git a/tests/image_models/test_deep_image_prior.py b/tests/image_models/test_deep_image_prior.py new file mode 100644 index 0000000..19725ce --- /dev/null +++ b/tests/image_models/test_deep_image_prior.py @@ -0,0 +1,43 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage +from pytti.image_models.deep_image_prior import DeepImagePrior + +### DIP ### + + +def test_dip_init(): + obj = DeepImagePrior(512, 512) + assert obj + + +def test_dip_update(): + obj = DeepImagePrior(512, 512) + obj.update() + + +def test_dip_forward(): + obj = DeepImagePrior(512, 512) + obj.forward() + + +def test_dip_decode_training_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_training_tensor() + + +def test_dip_decode_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_tensor() + + +def test_dip_clone(): + obj = DeepImagePrior(512, 512) + obj.clone() + + +def test_dip_get_latent_tensor(): + obj = DeepImagePrior(10, 10) + obj.get_latent_tensor() From ae35920f730a150828256e4939a51800841d68af Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 10 Jun 2022 09:04:38 -0700 Subject: [PATCH 23/29] added pipfile to sate the git gods --- Pipfile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Pipfile b/Pipfile index 881272d..5d2a2bf 100644 --- a/Pipfile +++ b/Pipfile @@ -9,6 +9,7 @@ verify_ssl = false name = "pytorch" [packages] +tensorflow = "==2.7.0" transformers = "==4.15.0" gdown = "===4.2.0" ftfy = "==6.0.3" @@ -30,6 +31,7 @@ jupyter = "*" imageio = "==2.4.1" PyGLM = "==2.5.7" adjustText = "*" + Pillow = "*" torch = "*" torchvision = "*" @@ -48,6 +50,9 @@ mmc = {git = "https://github.com/dmarx/Multi-Modal-Comparators"} pytest = "*" pre-commit = "*" click = "==8.0.4" +#Pillow = "==7.1.2" +#pyttitools-core = {editable = true, path = "."} +deep-image-prior = {path = "./deep-image-prior"} black = "*" [requires] From a6496c68fa7e2026f75dc6da96c5956c3d8120d4 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 05:27:06 -0700 Subject: [PATCH 24/29] incorporated new EMA into DIP --- src/pytti/image_models/deep_image_prior.py | 72 +++++++++++++++++----- src/pytti/image_models/ema.py | 15 +++-- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 68a1abe..4a904b1 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -8,7 +8,9 @@ import torch from torch import nn from torchvision.transforms import functional as TF -from pytti.image_models import DifferentiableImage, EMAImage + +# from pytti.image_models import DifferentiableImage +from pytti.image_models.ema import EMAImage, EMAParametersDict from PIL import Image from torch.nn import functional as F @@ -44,8 +46,8 @@ def load_dip(input_depth, num_scales, offset_type, offset_groups, device): return dip_net -# class DeepImagePrior(EMAImage): -class DeepImagePrior(DifferentiableImage): +class DeepImagePrior(EMAImage): + # class DeepImagePrior(DifferentiableImage): """ https://github.com/nousr/deep-image-prior/ """ @@ -69,7 +71,14 @@ def __init__( device="cuda", **kwargs, ): - super().__init__(width * scale, height * scale) + # super(super(EMAImage)).__init__() + nn.Module.__init__(self) + super().__init__( + width=width * scale, + height=height * scale, + decay=ema_val, + device=device, + ) net = load_dip( input_depth=input_depth, num_scales=num_scales, @@ -85,13 +94,13 @@ def __init__( # z = torch.cat(get_non_offset_params(net), get_offset_params(net)) # logger.debug(z.shape) # super().__init__(width * scale, height * scale, z, ema_val) - self.net = net + # self.net = net # self.tensor = self.net.params() self.output_axes = ("n", "s", "y", "x") self.scale = scale self.device = device - self._net_input = torch.randn([1, input_depth, width, height], device=device) + # self._net_input = torch.randn([1, input_depth, width, height], device=device) self.lr = lr self.offset_lr_fac = offset_lr_fac @@ -99,6 +108,24 @@ def __init__( # {'params': get_non_offset_params(net), 'lr': lr}, # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} # ] + # z = { + # 'non_offset':get_non_offset_params(net), + # 'offset':get_offset_params(net), + # } + self.net = net + self._net_input = torch.randn([1, input_depth, width, height], device=device) + + self.image_representation_parameters = EMAParametersDict( + z=self.net, decay=ema_val, device=device + ) + + # super().__init__( + # width = width * scale, + # height = height * scale, + # tensor = z, + # decay = ema_val, + # device=device, + # ) # def get_image_tensor(self): def decode_tensor(self): @@ -129,17 +156,34 @@ def get_latent_tensor(self, detach=False): return params def clone(self): - # dummy = super().__init__(*self.image_shape) + # dummy = VQGANImage(*self.image_shape) # with torch.no_grad(): - # #dummy.tensor.set_(self.tensor.clone()) - # dummy.net.copy_(self.net) - # dummy.accum.set_(self.accum.clone()) - # dummy.biased.set_(self.biased.clone()) - # dummy.average.set_(self.average.clone()) - # dummy.decay = self.decay - dummy = deepcopy(self) + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + # dummy.accum.set_(self.accum.clone()) + # dummy.biased.set_(self.biased.clone()) + # dummy.average.set_(self.average.clone()) + # dummy.decay = self.decay + # return dummy + dummy = DeepImagePrior(*self.image_shape) + with torch.no_grad(): + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + dummy.image_representation_parameters.set_( + self.image_representation_parameters.clone() + ) return dummy + # def clone(self): + # # dummy = super().__init__(*self.image_shape) + # # with torch.no_grad(): + # # #dummy.tensor.set_(self.tensor.clone()) + # # dummy.net.copy_(self.net) + # # dummy.accum.set_(self.accum.clone()) + # # dummy.biased.set_(self.biased.clone()) + # # dummy.average.set_(self.average.clone()) + # # dummy.decay = self.decay + # dummy = deepcopy(self) + # return dummy + def encode_random(self): pass diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 5e09a4f..df5c8e7 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -58,13 +58,15 @@ def reset(self): self.update() -class EMAParametersDict(ImageRepresentationalParameters): +# class EMAParametersDict(ImageRepresentationalParameters): +class EMAParametersDict(nn.Module): """ LatentTensor with a singleton dimension for the EMAParameters """ def __init__(self, z=None, decay=0.99, device=None): - super(ImageRepresentationalParameters).__init__() + # super(ImageRepresentationalParameters).__init__() + super().__init__() self.decay = decay if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -74,9 +76,10 @@ def __init__(self, z=None, decay=0.99, device=None): def _new(self, z=None): if z is None: # I think this can all go in the constructor and doesn't need to call .to() - z = torch.zeros(1, 3, self.height, self.width).to( - device=self.device, memory_format=torch.channels_last - ) + return nn.Parameter() + # z = torch.zeros(1, 3, self.height, self.width).to( + # device=self.device, memory_format=torch.channels_last + # ) # d_ = z d_ = {} if isinstance(z, EMAParametersDict): @@ -147,7 +150,7 @@ def reset(self): class EMAImage(DifferentiableImage): - def __init__(self, width, height, tensor, decay, device=None): + def __init__(self, width, height, tensor=None, decay=0.99, device=None): super().__init__(width=width, height=height, device=device) self.image_representation_parameters = EMAParametersDict( z=tensor, decay=decay, device=device From 4795e03b588cc2aff0dfbbf2bd94c89cffe98b83 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 05:46:28 -0700 Subject: [PATCH 25/29] suppress logging messages --- src/pytti/image_models/deep_image_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 4a904b1..885390d 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -132,9 +132,9 @@ def decode_tensor(self): with torch.cuda.amp.autocast(): # out = net(net_input_noised * input_scale).float() # logger.debug(self.net) - logger.debug(self._net_input.shape) + # logger.debug(self._net_input.shape) out = self.net(self._net_input).float() - logger.debug(out.shape) + # logger.debug(out.shape) width, height = self.image_shape out = F.interpolate(out, (height, width), mode="nearest") return clamp_with_grad(out, 0, 1) From 80bad6e5782cfb792ed692ee172cb1f0f46f98a5 Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 13 Jun 2022 15:57:26 -0700 Subject: [PATCH 26/29] shit is super broken. I think there's no more avoiding cleaning up how image models and losses work. --- src/pytti/LossAug/LatentLossClass.py | 122 +++++++++++++++++- src/pytti/LossAug/LossOrchestratorClass.py | 7 +- src/pytti/LossAug/MSELossClass.py | 35 ++++- src/pytti/image_models/deep_image_prior.py | 68 ++++++++-- .../image_models/differentiable_image.py | 9 +- src/pytti/image_models/ema.py | 11 ++ src/pytti/image_models/vqgan.py | 2 +- src/pytti/workhorse.py | 8 +- 8 files changed, 246 insertions(+), 16 deletions(-) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index 8f267b8..e3e2146 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -6,8 +6,126 @@ import copy, re from pytti import DEVICE, fetch, parse, vram_usage_mode +from loguru import logger + class LatentLoss(MSELoss): + @torch.no_grad() + def __init__( + self, + comp, + weight=0.5, + stop=-math.inf, + name="direct target loss", + image_shape=None, + ): + super().__init__( + comp, weight, stop, name, image_shape + ) # this really should link back to the image model... + logger.debug(type(comp)) # inits to image tensor + self.pil_image = None + self.has_latent = False + w, h = image_shape + try: + comp_adjusted = TF.resize(comp.clone(), (h, w)) + except: + # comp_adjusted = comp.clone() + # Need to convert the latent to its image form + comp_adjusted = img_model.decode_tensor(comp.clone()) + self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape) + + @torch.no_grad() + def set_comp(self, pil_image, device=DEVICE): + logger.debug(type(pil_image)) + self.pil_image = pil_image + self.has_latent = False + im_resized = pil_image.resize( + self.image_shape, Image.LANCZOS + ) # to do: ResizeRight + # self.direct_loss.set_comp(im_resized) + self.direct_loss.set_comp(im_resized) + + @classmethod + def convert_input(cls, input, img): + logger.debug(type(input)) # pretty sure this is gonna be tensor + # return input # this is the default MSE loss version + return img.make_latent(input) + + @classmethod + def default_comp(cls, img_model=None, *args, **kargs): + logger.debug("default_comp") + logger.debug(type(img_model)) + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + if img_model is None: + return torch.zeros(1, 1, 1, 1, device=device) + return img_model.default_comp(*args, **kargs) + + @classmethod + @vram_usage_mode("Latent Image Loss") + @torch.no_grad() + def TargetImage( + cls, + prompt_string, + image_shape, + pil_image=None, + is_path=False, + device=DEVICE, + img_model=None, + ): + logger.debug( + type(pil_image) + ) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image? + text, weight, stop = parse( + prompt_string, r"(? from target image constructor when no input image provided + if not self.has_latent: + # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation + # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image + latent = img.make_latent(self.pil_image) + logger.debug(type(latent)) # EMAParametersDict + logger.debug(type(self.comp)) # torch.Tensor + with torch.no_grad(): + self.comp.set_(latent.clone()) + self.has_latent = True + l1 = super().get_loss(img.get_latent_tensor(), img) / 2 + l2 = self.direct_loss.get_loss(input, img) / 10 + return l1 + l2 + + +###################################################################### + + +class LatentLossGeneric(LatentLoss): + # class LatentLoss(MSELoss): @torch.no_grad() def __init__( self, @@ -29,7 +147,9 @@ def __init__( def set_comp(self, pil_image, device=DEVICE): self.pil_image = pil_image self.has_latent = False - self.direct_loss.set_comp(pil_image.resize(self.image_shape, Image.LANCZOS)) + self.direct_loss.set_comp( + pil_image.resize(self.image_shape, Image.LANCZOS) + ) # to do: ResizeRight @classmethod @vram_usage_mode("Latent Image Loss") diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index 931537e..a9d45bb 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -27,8 +27,13 @@ def build_loss(weight_name, weight, name, img, pil_target): Loss = type(img).get_preferred_loss() else: Loss = LOSS_DICT[weight_name] + logger.debug(type(Loss)) + logger.debug(type(img)) out = Loss.TargetImage( - f"{weight_name} {name}:{weight}", img.image_shape, pil_target + f"{weight_name} {name}:{weight}", + img.image_shape, + pil_target, + img_model=img, # type(img) ) out.set_enabled(pil_target is not None) return out diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index ff1e5dc..10a2ecd 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -9,6 +9,8 @@ from pytti import fetch, parse, vram_usage_mode import torch +from loguru import logger + class MSELoss(Loss): @torch.no_grad() @@ -22,7 +24,16 @@ def __init__( device=None, ): super().__init__(weight, stop, name, device) - self.register_buffer("comp", comp) + logger.debug(type(comp)) + _comp = self.default_comp() if comp is None else comp + try: + self.register_buffer("comp", comp) + except TypeError: + logger.debug(type(comp)) + # _comp = self.default_comp() if comp is None else comp #comp._container #comp.image_representational_parameters #comp.default_comp() + logger.debug(type(_comp)) + self.register_module("comp", _comp) + # self.register_buffer("comp", _comp) if image_shape is None: height, width = comp.shape[-2:] image_shape = (width, height) @@ -30,12 +41,25 @@ def __init__( self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device)) self.use_mask = False + @classmethod + def default_comp(cls, img_model=None, *args, **kargs): + # logger.debug("default_comp") + # logger.debug(type(img_model)) + # device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + # if img_model is None: + # return torch.zeros(1, 1, 1, 1, device=device) + # return img_model.default_comp(*args, **kargs) + return torch.zeros(1, 1, 1, 1, device=device) + @classmethod @vram_usage_mode("Loss Augs") @torch.no_grad() def TargetImage( cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None ): + """ + I guess this is like an alternate constructor for the class? + """ # Why is this prompt parsing stuff here? Deprecate in favor of centralized # parsing functions (if feasible) text, weight, stop = parse( @@ -87,10 +111,16 @@ def set_mask(self, mask, inverted=False, device=None): @classmethod def convert_input(cls, input, img): + """ + does nothing? + """ return input @classmethod def make_comp(cls, pil_image, device=None): + """ + looks like this just converts a PIL image to a tensor + """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") out = ( @@ -101,6 +131,9 @@ def make_comp(cls, pil_image, device=None): return cls.convert_input(out, None) def set_comp(self, pil_image, device=None): + """ + uses make_comp to convert a PIL image to a tensor, then assigns it to self.comp + """ if device is None: device = self.device self.comp.set_(type(self).make_comp(pil_image, device=device)) diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 885390d..303323f 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -69,6 +69,7 @@ def __init__( ema_val=0.99, ########### device="cuda", + image_encode_steps=30, # 500, # setting this low for prototyping. **kwargs, ): # super(super(EMAImage)).__init__() @@ -99,6 +100,7 @@ def __init__( self.output_axes = ("n", "s", "y", "x") self.scale = scale self.device = device + self.image_encode_steps = image_encode_steps # self._net_input = torch.randn([1, input_depth, width, height], device=device) @@ -115,6 +117,7 @@ def __init__( self.net = net self._net_input = torch.randn([1, input_depth, width, height], device=device) + # I think this is the attribute I want to use for "comp" in the latent loss self.image_representation_parameters = EMAParametersDict( z=self.net, decay=ema_val, device=device ) @@ -141,21 +144,27 @@ def decode_tensor(self): # return out def get_latent_tensor(self, detach=False): + # this will get used as the "comp" downstream # pass net = self.net lr = self.lr offset_lr_fac = self.offset_lr_fac - params = [ - {"params": get_non_offset_params(net), "lr": lr}, - {"params": get_offset_params(net), "lr": lr * offset_lr_fac}, - ] + # params = self.image_representation_parameters._container + # params = [ + # {"params": get_non_offset_params(net), "lr": lr}, + # {"params": get_offset_params(net), "lr": lr * offset_lr_fac}, + # ] # params = torch.cat( # get_non_offset_params(net), # get_offset_params(net) # ) - return params + # return params + # return self.net.params() + # return self.net.parameters() + # return self.image_representation_parameters # throws error from LatentLossClass.get_loss() --> self.comp.set_(latent.clone()) + return self.representation_parameters - def clone(self): + def clone(self) -> "DeepImagePrior": # dummy = VQGANImage(*self.image_shape) # with torch.no_grad(): # dummy.representation_parameters.set_(self.representation_parameters.clone()) @@ -170,7 +179,8 @@ def clone(self): dummy.image_representation_parameters.set_( self.image_representation_parameters.clone() ) - return dummy + return dummy # output of this function is expected to have an encode_image() method + # return dummy.image_representation_parameters # def clone(self): # # dummy = super().__init__(*self.image_shape) @@ -193,6 +203,34 @@ def get_preferred_loss(cls): return LatentLoss + def make_latent(self, pil_image): + """ + Takes a PIL image as input, + encodes it appropriately to the image representation (via .encode_image(pil_image)), + and returns the output of .get_latent_tensor(detach=True). + + NB: default behavior of .get_latent_tensor() is to just return the output of .get_image_tensor() + """ + try: + dummy = self.clone() + except NotImplementedError: + dummy = copy.deepcopy(self) + dummy.encode_image(pil_image) + # return dummy.get_latent_tensor(detach=True) + return dummy.image_representation_parameters + + @classmethod + def default_comp(*args, **kargs): + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + net = load_dip( + input_depth=32, + num_scales=7, + offset_type="none", + offset_groups=4, + device=device, + ) + return EMAParametersDict(z=net, decay=0.99, device=device) + def encode_image(self, pil_image, device="cuda"): """ Encodes the image into a tensor. @@ -205,12 +243,22 @@ def encode_image(self, pil_image, device="cuda"): width, height = self.image_shape scale = self.scale - mse = MSELoss.TargetImage("HSV loss", self.image_shape, pil_image) + mse = MSELoss.TargetImage("MSE loss", self.image_shape, pil_image) from pytti.ImageGuide import DirectImageGuide + params = [ + {"params": get_non_offset_params(self.net), "lr": self.lr}, + {"params": get_offset_params(self.net), "lr": self.lr * self.offset_lr_fac}, + ] + guide = DirectImageGuide( - self, None, optimizer=optim.Adam(self.get_latent_tensor()) + self, + None, + optimizer=optim.Adam( + # self.get_latent_tensor() + params + ), ) # why is there a magic number here? - guide.run_steps(201, [], [], [mse]) + guide.run_steps(self.image_encode_steps, [], [], [mse]) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 674130d..5a7b8f9 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -40,7 +40,7 @@ def get_image_tensor(self): """ raise NotImplementedError - def clone(self): + def clone(self) -> "DifferentiableImage": raise NotImplementedError def get_latent_tensor(self, detach=False): @@ -82,6 +82,13 @@ def update(self): pass def make_latent(self, pil_image): + """ + Takes a PIL image as input, + encodes it appropriately to the image representation (via .encode_image(pil_image)), + and returns the output of .get_latent_tensor(detach=True). + + NB: default behavior of .get_latent_tensor() is to just return the output of .get_image_tensor() + """ try: dummy = self.clone() except NotImplementedError: diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index df5c8e7..1033737 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -5,6 +5,8 @@ ImageRepresentationalParameters, ) +from loguru import logger + class EMATensor(nn.Module): """implmeneted by Katherine Crowson""" @@ -84,6 +86,7 @@ def _new(self, z=None): d_ = {} if isinstance(z, EMAParametersDict): for k, v in z.items(): + logger.debug(k) d_[k] = EMATensor(v.tensor, self.decay) elif isinstance(z, EMATensor): d_["z"] = z.clone() @@ -119,6 +122,9 @@ def _new(self, z=None): # d_ = {name: EMATensor(param.data, self.decay) for name, param in z.items()} # except AttributeError: # d_ = {name: EMATensor(param, self.decay) for name, param in z.items()} + # logger.debug(d_.keys()) + # d_ = torch.nn.ParameterDict(d_) + # d_ = torch.nn.ModuleDict(d_) return d_ def clone(self): @@ -135,9 +141,14 @@ def average(self): return {k: v.average for k, v in self._container.items()} def set_(self, d): + if isinstance(d, torch.Tensor): + logger.debug(self._container) + d_ = d if isinstance(d, EMAParametersDict): d_ = d._container + logger.debug(type(d_)) + logger.debug(d_.shape) # fuck it for k, v in d_.items(): self._container[k].set_(v) # self._container[k].tensor.set_(v) diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 11b2d8e..9fe6cfd 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -198,7 +198,7 @@ def __init__( # device=self.device, # ) - def clone(self): + def clone(self) -> "VQGANImage": # dummy = VQGANImage(*self.image_shape) # with torch.no_grad(): # dummy.representation_parameters.set_(self.representation_parameters.clone()) diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index e82ccb8..b50dae0 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -369,6 +369,8 @@ def do_run(): # set up init image # ##################### + logger.debug("configuring init image prompts") + ( init_augs, semantic_init_prompt, @@ -387,16 +389,18 @@ def do_run(): ) # other image prompts + logger.debug("configuring other image prompts") loss_augs.extend( type(img) .get_preferred_loss() - .TargetImage(p.strip(), img.image_shape, is_path=True) + .TargetImage(p.strip(), img.image_shape, is_path=True, img_model=type(img)) for p in params.direct_image_prompts.split("|") if p.strip() ) # stabilization + logger.debug("configuring stabilization losses") ( loss_augs, img, @@ -460,6 +464,8 @@ def do_run(): # img, # ) = loss_orch.configure_losses() + logger.debug("losses configured.") + # Phase 4 - setup outputs ########################## From 755f14b019bf862c9be2cbc1b7f11417c732b2b0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 14 Jun 2022 10:34:26 -0700 Subject: [PATCH 27/29] added note --- src/pytti/ImageGuide.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index ee81b82..d4c8d4f 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -320,6 +320,8 @@ def train( for prompt in prompts } + # oh.. uh... image_losses and auglosses don't actually depend on an embedder being attached. + # Maybe this is why limited palette wasn't initializing properly? losses, losses_raw = zip( *map(unpack_dict, [prompt_losses, aug_losses, image_losses]) # *map(unpack_dict, [prompt_losses]) From 02a983182f863c164668e7e545c2bb7ae97b2342 Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 14 Jun 2022 10:47:23 -0700 Subject: [PATCH 28/29] added note --- src/pytti/tensor_tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytti/tensor_tools.py b/src/pytti/tensor_tools.py index 6019c11..c3fb353 100644 --- a/src/pytti/tensor_tools.py +++ b/src/pytti/tensor_tools.py @@ -65,6 +65,7 @@ def format_module(module, dest, *args, **kwargs) -> torch.tensor: return format_input(output, module, dest) +# https://pytorch.org/docs/stable/autograd.html#function class ReplaceGrad(torch.autograd.Function): """ returns x_forward during forward pass, but evaluates derivates as though From 15520142dc0e4c37f92a46939674b5d16e723a7c Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 14 Jun 2022 15:28:02 -0700 Subject: [PATCH 29/29] h8 this. saving progress, but I think I need to just backtrack and simplify how losses and image_models work first, then come back to this afterwards. --- src/pytti/LossAug/LatentLossClass.py | 210 +++++++++++++++++++-- src/pytti/LossAug/LossOrchestratorClass.py | 2 +- src/pytti/LossAug/MSELossClass.py | 2 +- src/pytti/image_models/deep_image_prior.py | 30 ++- src/pytti/image_models/ema.py | 4 +- src/pytti/workhorse.py | 7 +- 6 files changed, 227 insertions(+), 28 deletions(-) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index e3e2146..cc649f6 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -26,16 +26,20 @@ def __init__( self.pil_image = None self.has_latent = False w, h = image_shape - try: - comp_adjusted = TF.resize(comp.clone(), (h, w)) - except: - # comp_adjusted = comp.clone() - # Need to convert the latent to its image form - comp_adjusted = img_model.decode_tensor(comp.clone()) + comp_adjusted = TF.resize(comp.clone(), (h, w)) + # try: + # comp_adjusted = TF.resize(comp.clone(), (h, w)) + # except: + # # comp_adjusted = comp.clone() + # # Need to convert the latent to its image form + # comp_adjusted = img_model.decode_tensor(comp.clone()) self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape) @torch.no_grad() def set_comp(self, pil_image, device=DEVICE): + """ + sets the DIRECT loss anchor "comp" to the tensorized image. + """ logger.debug(type(pil_image)) self.pil_image = pil_image self.has_latent = False @@ -47,6 +51,10 @@ def set_comp(self, pil_image, device=DEVICE): @classmethod def convert_input(cls, input, img): + """ + Converts the input image tensor to the image representation of the image model. + E.g. if img is VQGAN, then the input tensor is converted to the latent representation. + """ logger.debug(type(input)) # pretty sure this is gonna be tensor # return input # this is the default MSE loss version return img.make_latent(input) @@ -107,6 +115,8 @@ def get_loss(self, input, img): logger.debug( self.comp.shape ) # [1 1 1 1] -> from target image constructor when no input image provided + + # why is the latent comp only set here? why not in the __init__ and set_comp? if not self.has_latent: # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image @@ -114,8 +124,12 @@ def get_loss(self, input, img): logger.debug(type(latent)) # EMAParametersDict logger.debug(type(self.comp)) # torch.Tensor with torch.no_grad(): - self.comp.set_(latent.clone()) + if type(latent) == type(self.comp): + self.comp.set_(latent.clone()) + # else: + self.has_latent = True + l1 = super().get_loss(img.get_latent_tensor(), img) / 2 l2 = self.direct_loss.get_loss(input, img) / 10 return l1 + l2 @@ -123,9 +137,40 @@ def get_loss(self, input, img): ###################################################################### +# fuck it, let's just make a dip latent loss from scratch. + + +# The issue we're resolving here is that by inheriting from the MSELoss, +# I can't easily set the comp to the parameters of the image model. + +from pytti.LossAug.BaseLossClass import Loss +from pytti.image_models.ema import EMAImage, EMAParametersDict +from pytti.rotoscoper import Rotoscoper + +import deep_image_prior +import deep_image_prior.models +from deep_image_prior.models import ( + get_hq_skip_net, + get_non_offset_params, + get_offset_params, +) -class LatentLossGeneric(LatentLoss): - # class LatentLoss(MSELoss): + +def load_dip(input_depth, num_scales, offset_type, offset_groups, device): + dip_net = get_hq_skip_net( + input_depth, + skip_n33d=192, + skip_n33u=192, + skip_n11=4, + num_scales=num_scales, + offset_type=offset_type, + offset_groups=offset_groups, + ).to(device) + + return dip_net + + +class LatentLossDIP(Loss): @torch.no_grad() def __init__( self, @@ -134,29 +179,109 @@ def __init__( stop=-math.inf, name="direct target loss", image_shape=None, + device=None, ): - super().__init__(comp, weight, stop, name, image_shape) + ################################################################## + super().__init__(weight, stop, name, device) + if image_shape is None: + raise + # height, width = comp.shape[-2:] + # image_shape = (width, height) + self.image_shape = image_shape + self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device)) + self.use_mask = False + ################################################################## self.pil_image = None self.has_latent = False - w, h = image_shape - self.direct_loss = MSELoss( - TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape + logger.debug(type(comp)) # inits to image tensor + if comp is None: + comp = self.default_comp() + if isinstance(comp, EMAParametersDict): + logger.debug("initializing loss from latent") + self.register_module("comp", comp) + self.has_latent = True + else: + w, h = image_shape + comp_adjusted = TF.resize(comp.clone(), (h, w)) + # try: + # comp_adjusted = TF.resize(comp.clone(), (h, w)) + # except: + # # comp_adjusted = comp.clone() + # # Need to convert the latent to its image form + # comp_adjusted = img_model.decode_tensor(comp.clone()) + self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape) + + ################################################################## + + logger.debug(type(comp)) + + @classmethod + def default_comp(*args, **kargs): + logger.debug("default_comp") + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + net = load_dip( + input_depth=32, + num_scales=7, + offset_type="none", + offset_groups=4, + device=device, ) + return EMAParametersDict(z=net, decay=0.99, device=device) + + ################################################################################### @torch.no_grad() def set_comp(self, pil_image, device=DEVICE): + """ + sets the DIRECT loss anchor "comp" to the tensorized image. + """ + logger.debug(type(pil_image)) self.pil_image = pil_image self.has_latent = False - self.direct_loss.set_comp( - pil_image.resize(self.image_shape, Image.LANCZOS) + im_resized = pil_image.resize( + self.image_shape, Image.LANCZOS ) # to do: ResizeRight + # self.direct_loss.set_comp(im_resized) + + im_tensor = ( + TF.to_tensor(pil_image) + .unsqueeze(0) + .to(device, memory_format=torch.channels_last) + ) + + if hasattr(self, "direct_loss"): + self.direct_loss.set_comp(im_tensor) + else: + self.direct_loss = MSELoss( + im_tensor, self.weight, self.stop, self.name, self.image_shape + ) + # self.direct_loss.set_comp(im_resized) + + @classmethod + def convert_input(cls, input, img): + """ + Converts the input image tensor to the image representation of the image model. + E.g. if img is VQGAN, then the input tensor is converted to the latent representation. + """ + logger.debug(type(input)) # pretty sure this is gonna be tensor + # return input # this is the default MSE loss version + return img.make_latent(input) @classmethod @vram_usage_mode("Latent Image Loss") @torch.no_grad() def TargetImage( - cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE + cls, + prompt_string, + image_shape, + pil_image=None, + is_path=False, + device=DEVICE, + img_model=None, ): + logger.debug( + type(pil_image) + ) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image? text, weight, stop = parse( prompt_string, r"(? from target image constructor when no input image provided + + # why is the latent comp only set here? why not in the __init__ and set_comp? if not self.has_latent: + raise + # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation + # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image latent = img.make_latent(self.pil_image) + logger.debug(type(latent)) # EMAParametersDict + logger.debug(type(self.comp)) # torch.Tensor with torch.no_grad(): - self.comp.set_(latent.clone()) + if type(latent) == type(self.comp): + self.comp.set_(latent.clone()) + # else: + self.has_latent = True + + estimated_image = self.comp.get_image_tensor() + l1 = super().get_loss(img.get_latent_tensor(), img) / 2 l2 = self.direct_loss.get_loss(input, img) / 10 return l1 + l2 diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index a9d45bb..1b8973a 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -33,7 +33,7 @@ def build_loss(weight_name, weight, name, img, pil_target): f"{weight_name} {name}:{weight}", img.image_shape, pil_target, - img_model=img, # type(img) + # img_model=img, # type(img) ) out.set_enabled(pil_target is not None) return out diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index 10a2ecd..2aa21c0 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -43,7 +43,7 @@ def __init__( @classmethod def default_comp(cls, img_model=None, *args, **kargs): - # logger.debug("default_comp") + logger.debug("default_comp") # logger.debug(type(img_model)) # device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" # if img_model is None: diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py index 303323f..c4c3f02 100644 --- a/src/pytti/image_models/deep_image_prior.py +++ b/src/pytti/image_models/deep_image_prior.py @@ -131,7 +131,10 @@ def __init__( # ) # def get_image_tensor(self): - def decode_tensor(self): + def decode_tensor(self, input_latent=None): + """ + Generates the image tensor from the attached DIP representation + """ with torch.cuda.amp.autocast(): # out = net(net_input_noised * input_scale).float() # logger.debug(self.net) @@ -199,9 +202,12 @@ def encode_random(self): @classmethod def get_preferred_loss(cls): - from pytti.LossAug.LatentLossClass import LatentLoss + from pytti.LossAug.LatentLossClass import LatentLoss, LatentLossDIP + + return LatentLossDIP # LatentLoss - return LatentLoss + # it'll be stupid complicated, but I could put a closure in here... + # yeah no fuck that. I'm not adding complexity to enable deep image. I need to simplify how loss stuff works FIRST. def make_latent(self, pil_image): """ @@ -233,7 +239,7 @@ def default_comp(*args, **kargs): def encode_image(self, pil_image, device="cuda"): """ - Encodes the image into a tensor. + Fits the attached DIP model representation to the input pil_image. :param pil_image: The image to encode :param smart_encode: If True, the pallet will be optimized to match the image, defaults to True @@ -262,3 +268,19 @@ def encode_image(self, pil_image, device="cuda"): ) # why is there a magic number here? guide.run_steps(self.image_encode_steps, [], [], [mse]) + + +############################################################################################################################## + +# round three + +# gonna implement this the way that makes sense to me, and then see if I can't square-peg-round-hole it +class DipSimpleLatentLoss(nn.Module): + def __init__( + self, + net, + image_shape, + pil_image=None, + ): + super().__init__() + self.net = net diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index 1033737..1c182aa 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -143,12 +143,14 @@ def average(self): def set_(self, d): if isinstance(d, torch.Tensor): logger.debug(self._container) + logger.debug(d.shape) d_ = d if isinstance(d, EMAParametersDict): d_ = d._container logger.debug(type(d_)) - logger.debug(d_.shape) # fuck it + # logger.debug(d_.shape) # fuck it + logger.debug(type(self._container)) for k, v in d_.items(): self._container[k].set_(v) # self._container[k].tensor.set_(v) diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index b50dae0..43385d2 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -394,7 +394,12 @@ def do_run(): loss_augs.extend( type(img) .get_preferred_loss() - .TargetImage(p.strip(), img.image_shape, is_path=True, img_model=type(img)) + .TargetImage( + p.strip(), + img.image_shape, + is_path=True, + # img_model=type(img) + ) for p in params.direct_image_prompts.split("|") if p.strip() )