diff --git a/Pipfile b/Pipfile index 881272d..5d2a2bf 100644 --- a/Pipfile +++ b/Pipfile @@ -9,6 +9,7 @@ verify_ssl = false name = "pytorch" [packages] +tensorflow = "==2.7.0" transformers = "==4.15.0" gdown = "===4.2.0" ftfy = "==6.0.3" @@ -30,6 +31,7 @@ jupyter = "*" imageio = "==2.4.1" PyGLM = "==2.5.7" adjustText = "*" + Pillow = "*" torch = "*" torchvision = "*" @@ -48,6 +50,9 @@ mmc = {git = "https://github.com/dmarx/Multi-Modal-Comparators"} pytest = "*" pre-commit = "*" click = "==8.0.4" +#Pillow = "==7.1.2" +#pyttitools-core = {editable = true, path = "."} +deep-image-prior = {path = "./deep-image-prior"} black = "*" [requires] diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 36122ed..d4c8d4f 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -292,6 +292,9 @@ def train( if self.embedder is not None: for mb_i in range(gradient_accumulation_steps): # logger.debug(mb_i) + # logger.debug(self.image_rep.shape) + logger.debug(type(self.image_rep)) + logger.debug(z.shape) image_embeds, offsets, sizes = self.embedder(self.image_rep, input=z) t = 1 @@ -317,6 +320,8 @@ def train( for prompt in prompts } + # oh.. uh... image_losses and auglosses don't actually depend on an embedder being attached. + # Maybe this is why limited palette wasn't initializing properly? losses, losses_raw = zip( *map(unpack_dict, [prompt_losses, aug_losses, image_losses]) # *map(unpack_dict, [prompt_losses]) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index 8f267b8..cc649f6 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -6,6 +6,8 @@ import copy, re from pytti import DEVICE, fetch, parse, vram_usage_mode +from loguru import logger + class LatentLoss(MSELoss): @torch.no_grad() @@ -17,26 +19,70 @@ def __init__( name="direct target loss", image_shape=None, ): - super().__init__(comp, weight, stop, name, image_shape) + super().__init__( + comp, weight, stop, name, image_shape + ) # this really should link back to the image model... + logger.debug(type(comp)) # inits to image tensor self.pil_image = None self.has_latent = False w, h = image_shape - self.direct_loss = MSELoss( - TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape - ) + comp_adjusted = TF.resize(comp.clone(), (h, w)) + # try: + # comp_adjusted = TF.resize(comp.clone(), (h, w)) + # except: + # # comp_adjusted = comp.clone() + # # Need to convert the latent to its image form + # comp_adjusted = img_model.decode_tensor(comp.clone()) + self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape) @torch.no_grad() def set_comp(self, pil_image, device=DEVICE): + """ + sets the DIRECT loss anchor "comp" to the tensorized image. + """ + logger.debug(type(pil_image)) self.pil_image = pil_image self.has_latent = False - self.direct_loss.set_comp(pil_image.resize(self.image_shape, Image.LANCZOS)) + im_resized = pil_image.resize( + self.image_shape, Image.LANCZOS + ) # to do: ResizeRight + # self.direct_loss.set_comp(im_resized) + self.direct_loss.set_comp(im_resized) + + @classmethod + def convert_input(cls, input, img): + """ + Converts the input image tensor to the image representation of the image model. + E.g. if img is VQGAN, then the input tensor is converted to the latent representation. 
+ """ + logger.debug(type(input)) # pretty sure this is gonna be tensor + # return input # this is the default MSE loss version + return img.make_latent(input) + + @classmethod + def default_comp(cls, img_model=None, *args, **kargs): + logger.debug("default_comp") + logger.debug(type(img_model)) + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + if img_model is None: + return torch.zeros(1, 1, 1, 1, device=device) + return img_model.default_comp(*args, **kargs) @classmethod @vram_usage_mode("Latent Image Loss") @torch.no_grad() def TargetImage( - cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE + cls, + prompt_string, + image_shape, + pil_image=None, + is_path=False, + device=DEVICE, + img_model=None, ): + logger.debug( + type(pil_image) + ) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image? text, weight, stop = parse( prompt_string, r"(? from target image constructor when no input image provided + + # why is the latent comp only set here? why not in the __init__ and set_comp? if not self.has_latent: + # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation + # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image latent = img.make_latent(self.pil_image) + logger.debug(type(latent)) # EMAParametersDict + logger.debug(type(self.comp)) # torch.Tensor with torch.no_grad(): - self.comp.set_(latent.clone()) + if type(latent) == type(self.comp): + self.comp.set_(latent.clone()) + # else: + + self.has_latent = True + + l1 = super().get_loss(img.get_latent_tensor(), img) / 2 + l2 = self.direct_loss.get_loss(input, img) / 10 + return l1 + l2 + + +###################################################################### + +# fuck it, let's just make a dip latent loss from scratch. + + +# The issue we're resolving here is that by inheriting from the MSELoss, +# I can't easily set the comp to the parameters of the image model. 
+ +from pytti.LossAug.BaseLossClass import Loss +from pytti.image_models.ema import EMAImage, EMAParametersDict +from pytti.rotoscoper import Rotoscoper + +import deep_image_prior +import deep_image_prior.models +from deep_image_prior.models import ( + get_hq_skip_net, + get_non_offset_params, + get_offset_params, +) + + +def load_dip(input_depth, num_scales, offset_type, offset_groups, device): + dip_net = get_hq_skip_net( + input_depth, + skip_n33d=192, + skip_n33u=192, + skip_n11=4, + num_scales=num_scales, + offset_type=offset_type, + offset_groups=offset_groups, + ).to(device) + + return dip_net + + +class LatentLossDIP(Loss): + @torch.no_grad() + def __init__( + self, + comp, + weight=0.5, + stop=-math.inf, + name="direct target loss", + image_shape=None, + device=None, + ): + ################################################################## + super().__init__(weight, stop, name, device) + if image_shape is None: + raise + # height, width = comp.shape[-2:] + # image_shape = (width, height) + self.image_shape = image_shape + self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device)) + self.use_mask = False + ################################################################## + self.pil_image = None + self.has_latent = False + logger.debug(type(comp)) # inits to image tensor + if comp is None: + comp = self.default_comp() + if isinstance(comp, EMAParametersDict): + logger.debug("initializing loss from latent") + self.register_module("comp", comp) self.has_latent = True + else: + w, h = image_shape + comp_adjusted = TF.resize(comp.clone(), (h, w)) + # try: + # comp_adjusted = TF.resize(comp.clone(), (h, w)) + # except: + # # comp_adjusted = comp.clone() + # # Need to convert the latent to its image form + # comp_adjusted = img_model.decode_tensor(comp.clone()) + self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape) + + ################################################################## + + logger.debug(type(comp)) + + @classmethod + def default_comp(*args, **kargs): + logger.debug("default_comp") + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + net = load_dip( + input_depth=32, + num_scales=7, + offset_type="none", + offset_groups=4, + device=device, + ) + return EMAParametersDict(z=net, decay=0.99, device=device) + + ################################################################################### + + @torch.no_grad() + def set_comp(self, pil_image, device=DEVICE): + """ + sets the DIRECT loss anchor "comp" to the tensorized image. + """ + logger.debug(type(pil_image)) + self.pil_image = pil_image + self.has_latent = False + im_resized = pil_image.resize( + self.image_shape, Image.LANCZOS + ) # to do: ResizeRight + # self.direct_loss.set_comp(im_resized) + + im_tensor = ( + TF.to_tensor(pil_image) + .unsqueeze(0) + .to(device, memory_format=torch.channels_last) + ) + + if hasattr(self, "direct_loss"): + self.direct_loss.set_comp(im_tensor) + else: + self.direct_loss = MSELoss( + im_tensor, self.weight, self.stop, self.name, self.image_shape + ) + # self.direct_loss.set_comp(im_resized) + + @classmethod + def convert_input(cls, input, img): + """ + Converts the input image tensor to the image representation of the image model. + E.g. if img is VQGAN, then the input tensor is converted to the latent representation. 
+ """ + logger.debug(type(input)) # pretty sure this is gonna be tensor + # return input # this is the default MSE loss version + return img.make_latent(input) + + @classmethod + @vram_usage_mode("Latent Image Loss") + @torch.no_grad() + def TargetImage( + cls, + prompt_string, + image_shape, + pil_image=None, + is_path=False, + device=DEVICE, + img_model=None, + ): + logger.debug( + type(pil_image) + ) # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image? + text, weight, stop = parse( + prompt_string, r"(? from target image constructor when no input image provided + + # why is the latent comp only set here? why not in the __init__ and set_comp? + if not self.has_latent: + raise + # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation + # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image + latent = img.make_latent(self.pil_image) + logger.debug(type(latent)) # EMAParametersDict + logger.debug(type(self.comp)) # torch.Tensor + with torch.no_grad(): + if type(latent) == type(self.comp): + self.comp.set_(latent.clone()) + # else: + + self.has_latent = True + + estimated_image = self.comp.get_image_tensor() + l1 = super().get_loss(img.get_latent_tensor(), img) / 2 l2 = self.direct_loss.get_loss(input, img) / 10 return l1 + l2 diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index 931537e..1b8973a 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -27,8 +27,13 @@ def build_loss(weight_name, weight, name, img, pil_target): Loss = type(img).get_preferred_loss() else: Loss = LOSS_DICT[weight_name] + logger.debug(type(Loss)) + logger.debug(type(img)) out = Loss.TargetImage( - f"{weight_name} {name}:{weight}", img.image_shape, pil_target + f"{weight_name} {name}:{weight}", + img.image_shape, + pil_target, + # img_model=img, # type(img) ) out.set_enabled(pil_target is not None) return out diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index ff1e5dc..2aa21c0 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -9,6 +9,8 @@ from pytti import fetch, parse, vram_usage_mode import torch +from loguru import logger + class MSELoss(Loss): @torch.no_grad() @@ -22,7 +24,16 @@ def __init__( device=None, ): super().__init__(weight, stop, name, device) - self.register_buffer("comp", comp) + logger.debug(type(comp)) + _comp = self.default_comp() if comp is None else comp + try: + self.register_buffer("comp", comp) + except TypeError: + logger.debug(type(comp)) + # _comp = self.default_comp() if comp is None else comp #comp._container #comp.image_representational_parameters #comp.default_comp() + logger.debug(type(_comp)) + self.register_module("comp", _comp) + # self.register_buffer("comp", _comp) if image_shape is None: height, width = comp.shape[-2:] image_shape = (width, height) @@ -30,12 +41,25 @@ def __init__( self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device)) self.use_mask = False + @classmethod + def default_comp(cls, img_model=None, *args, **kargs): + logger.debug("default_comp") + # logger.debug(type(img_model)) + # device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + # if img_model is None: + # return torch.zeros(1, 1, 1, 1, device=device) + # return img_model.default_comp(*args, **kargs) + return 
torch.zeros(1, 1, 1, 1, device=device) + @classmethod @vram_usage_mode("Loss Augs") @torch.no_grad() def TargetImage( cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None ): + """ + I guess this is like an alternate constructor for the class? + """ # Why is this prompt parsing stuff here? Deprecate in favor of centralized # parsing functions (if feasible) text, weight, stop = parse( @@ -87,10 +111,16 @@ def set_mask(self, mask, inverted=False, device=None): @classmethod def convert_input(cls, input, img): + """ + does nothing? + """ return input @classmethod def make_comp(cls, pil_image, device=None): + """ + looks like this just converts a PIL image to a tensor + """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") out = ( @@ -101,6 +131,9 @@ def make_comp(cls, pil_image, device=None): return cls.convert_input(out, None) def set_comp(self, pil_image, device=None): + """ + uses make_comp to convert a PIL image to a tensor, then assigns it to self.comp + """ if device is None: device = self.device self.comp.set_(type(self).make_comp(pil_image, device=device)) diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index 74a4678..f3c329e 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -1,5 +1,7 @@ from typing import Tuple +from loguru import logger + import pytti from pytti import format_input, cat_with_pad, format_module, normalize @@ -131,6 +133,10 @@ def forward( mode=PADDING_MODES[self.border_mode], ) for cut_size, perceptor in zip(self.cut_sizes, perceptors): + logger.debug(f"cut_size: {cut_size}") # 224 + logger.debug(input.shape) # 1, 3, 512, 512 + logger.debug(side_x) # 2048 + logger.debug(side_y) # 2048 cutouts, offsets, sizes = self.make_cutouts(input, side_x, side_y, cut_size) clip_in = normalize(cutouts) image_embeds.append(perceptor.encode_image(clip_in).float().unsqueeze(0)) diff --git a/src/pytti/image_models/__init__.py b/src/pytti/image_models/__init__.py index b47a344..547b44d 100644 --- a/src/pytti/image_models/__init__.py +++ b/src/pytti/image_models/__init__.py @@ -9,3 +9,4 @@ from .pixel import PixelImage from .rgb_image import RGBImage from .vqgan import VQGANImage +from .deep_image_prior import DeepImagePrior diff --git a/src/pytti/image_models/deep_image_prior.py b/src/pytti/image_models/deep_image_prior.py new file mode 100644 index 0000000..c4c3f02 --- /dev/null +++ b/src/pytti/image_models/deep_image_prior.py @@ -0,0 +1,286 @@ +from copy import deepcopy + +from loguru import logger + +from torch import optim + +from pytti import clamp_with_grad +import torch +from torch import nn +from torchvision.transforms import functional as TF + +# from pytti.image_models import DifferentiableImage +from pytti.image_models.ema import EMAImage, EMAParametersDict +from PIL import Image +from torch.nn import functional as F + +from pytti.LossAug.MSELossClass import MSELoss + +# scavenging code snippets from: +# - https://github.com/LAION-AI/notebooks/blob/main/DALLE2-Prior%2BDeep-Image-Prior.ipynb + +import deep_image_prior +import deep_image_prior.models +from deep_image_prior.models import ( + get_hq_skip_net, + get_non_offset_params, + get_offset_params, +) + +# from deep_image_prior import get_hq_skip_net, get_non_offset_params, get_offset_params + +# foo = deep_image_prior.models + + +def load_dip(input_depth, num_scales, offset_type, offset_groups, device): + dip_net = get_hq_skip_net( + input_depth, + skip_n33d=192, + skip_n33u=192, + skip_n11=4, + 
num_scales=num_scales, + offset_type=offset_type, + offset_groups=offset_groups, + ).to(device) + + return dip_net + + +class DeepImagePrior(EMAImage): + # class DeepImagePrior(DifferentiableImage): + """ + https://github.com/nousr/deep-image-prior/ + """ + + def __init__( + self, + width, + height, + scale=1, + ########### + input_depth=32, + num_scales=7, + offset_type="none", + # offset_groups=1, + disable_deformable_convolutions=False, + lr=1e-3, + offset_lr_fac=0.1, # 1.0, + ########### + ema_val=0.99, + ########### + device="cuda", + image_encode_steps=30, # 500, # setting this low for prototyping. + **kwargs, + ): + # super(super(EMAImage)).__init__() + nn.Module.__init__(self) + super().__init__( + width=width * scale, + height=height * scale, + decay=ema_val, + device=device, + ) + net = load_dip( + input_depth=input_depth, + num_scales=num_scales, + offset_type=offset_type, + offset_groups=0 if disable_deformable_convolutions else 4, + device=device, + ) + # z = self.get_latent_tensor() + # params = [ + # {'params': get_non_offset_params(net), 'lr': lr}, + # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} + # ] + # z = torch.cat(get_non_offset_params(net), get_offset_params(net)) + # logger.debug(z.shape) + # super().__init__(width * scale, height * scale, z, ema_val) + # self.net = net + # self.tensor = self.net.params() + self.output_axes = ("n", "s", "y", "x") + self.scale = scale + self.device = device + self.image_encode_steps = image_encode_steps + + # self._net_input = torch.randn([1, input_depth, width, height], device=device) + + self.lr = lr + self.offset_lr_fac = offset_lr_fac + # self._params = [ + # {'params': get_non_offset_params(net), 'lr': lr}, + # {'params': get_offset_params(net), 'lr': lr * offset_lr_fac} + # ] + # z = { + # 'non_offset':get_non_offset_params(net), + # 'offset':get_offset_params(net), + # } + self.net = net + self._net_input = torch.randn([1, input_depth, width, height], device=device) + + # I think this is the attribute I want to use for "comp" in the latent loss + self.image_representation_parameters = EMAParametersDict( + z=self.net, decay=ema_val, device=device + ) + + # super().__init__( + # width = width * scale, + # height = height * scale, + # tensor = z, + # decay = ema_val, + # device=device, + # ) + + # def get_image_tensor(self): + def decode_tensor(self, input_latent=None): + """ + Generates the image tensor from the attached DIP representation + """ + with torch.cuda.amp.autocast(): + # out = net(net_input_noised * input_scale).float() + # logger.debug(self.net) + # logger.debug(self._net_input.shape) + out = self.net(self._net_input).float() + # logger.debug(out.shape) + width, height = self.image_shape + out = F.interpolate(out, (height, width), mode="nearest") + return clamp_with_grad(out, 0, 1) + # return out + + def get_latent_tensor(self, detach=False): + # this will get used as the "comp" downstream + # pass + net = self.net + lr = self.lr + offset_lr_fac = self.offset_lr_fac + # params = self.image_representation_parameters._container + # params = [ + # {"params": get_non_offset_params(net), "lr": lr}, + # {"params": get_offset_params(net), "lr": lr * offset_lr_fac}, + # ] + # params = torch.cat( + # get_non_offset_params(net), + # get_offset_params(net) + # ) + # return params + # return self.net.params() + # return self.net.parameters() + # return self.image_representation_parameters # throws error from LatentLossClass.get_loss() --> self.comp.set_(latent.clone()) + return self.representation_parameters + + 
def clone(self) -> "DeepImagePrior": + # dummy = VQGANImage(*self.image_shape) + # with torch.no_grad(): + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + # dummy.accum.set_(self.accum.clone()) + # dummy.biased.set_(self.biased.clone()) + # dummy.average.set_(self.average.clone()) + # dummy.decay = self.decay + # return dummy + dummy = DeepImagePrior(*self.image_shape) + with torch.no_grad(): + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + dummy.image_representation_parameters.set_( + self.image_representation_parameters.clone() + ) + return dummy # output of this function is expected to have an encode_image() method + # return dummy.image_representation_parameters + + # def clone(self): + # # dummy = super().__init__(*self.image_shape) + # # with torch.no_grad(): + # # #dummy.tensor.set_(self.tensor.clone()) + # # dummy.net.copy_(self.net) + # # dummy.accum.set_(self.accum.clone()) + # # dummy.biased.set_(self.biased.clone()) + # # dummy.average.set_(self.average.clone()) + # # dummy.decay = self.decay + # dummy = deepcopy(self) + # return dummy + + def encode_random(self): + pass + + @classmethod + def get_preferred_loss(cls): + from pytti.LossAug.LatentLossClass import LatentLoss, LatentLossDIP + + return LatentLossDIP # LatentLoss + + # it'll be stupid complicated, but I could put a closure in here... + # yeah no fuck that. I'm not adding complexity to enable deep image. I need to simplify how loss stuff works FIRST. + + def make_latent(self, pil_image): + """ + Takes a PIL image as input, + encodes it appropriately to the image representation (via .encode_image(pil_image)), + and returns the output of .get_latent_tensor(detach=True). + + NB: default behavior of .get_latent_tensor() is to just return the output of .get_image_tensor() + """ + try: + dummy = self.clone() + except NotImplementedError: + dummy = copy.deepcopy(self) + dummy.encode_image(pil_image) + # return dummy.get_latent_tensor(detach=True) + return dummy.image_representation_parameters + + @classmethod + def default_comp(*args, **kargs): + device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu" + net = load_dip( + input_depth=32, + num_scales=7, + offset_type="none", + offset_groups=4, + device=device, + ) + return EMAParametersDict(z=net, decay=0.99, device=device) + + def encode_image(self, pil_image, device="cuda"): + """ + Fits the attached DIP model representation to the input pil_image. + + :param pil_image: The image to encode + :param smart_encode: If True, the pallet will be optimized to match the image, defaults to True + (optional) + :param device: The device to run the model on + """ + width, height = self.image_shape + scale = self.scale + + mse = MSELoss.TargetImage("MSE loss", self.image_shape, pil_image) + + from pytti.ImageGuide import DirectImageGuide + + params = [ + {"params": get_non_offset_params(self.net), "lr": self.lr}, + {"params": get_offset_params(self.net), "lr": self.lr * self.offset_lr_fac}, + ] + + guide = DirectImageGuide( + self, + None, + optimizer=optim.Adam( + # self.get_latent_tensor() + params + ), + ) + # why is there a magic number here? 
+ guide.run_steps(self.image_encode_steps, [], [], [mse]) + + +############################################################################################################################## + +# round three + +# gonna implement this the way that makes sense to me, and then see if I can't square-peg-round-hole it +class DipSimpleLatentLoss(nn.Module): + def __init__( + self, + net, + image_shape, + pil_image=None, + ): + super().__init__() + self.net = net diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 79ea590..5a7b8f9 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -1,29 +1,31 @@ import copy + +import torch from torch import nn import numpy as np from PIL import Image from pytti.tensor_tools import named_rearrange -SUPPORTED_MODES = ["L", "RGB", "I", "F"] - class DifferentiableImage(nn.Module): """ Base class for defining differentiable images width: (positive integer) image width in pixels height: (positive integer) image height in pixels - pixel_format: (string) PIL image mode. Either 'L','RGB','I', or 'F' """ - def __init__(self, width: int, height: int, pixel_format: str = "RGB"): + def __init__(self, width: int, height: int, device=None): super().__init__() - if pixel_format not in SUPPORTED_MODES: - raise ValueError(f"Pixel format {pixel_format} is not supported.") self.image_shape = (width, height) - self.pixel_format = format self.output_axes = ("x", "y", "s") self.lr = 0.02 self.latent_strength = 0 + self.image_representation_parameters = ImageRepresentationalParameters( + width=width, height=height + ) + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # self.tensor = self.image_representation_parameters._new() def decode_training_tensor(self): """ @@ -38,7 +40,7 @@ def get_image_tensor(self): """ raise NotImplementedError - def clone(self): + def clone(self) -> "DifferentiableImage": raise NotImplementedError def get_latent_tensor(self, detach=False): @@ -80,6 +82,13 @@ def update(self): pass def make_latent(self, pil_image): + """ + Takes a PIL image as input, + encodes it appropriately to the image representation (via .encode_image(pil_image)), + and returns the output of .get_latent_tensor(detach=True). 
+ + NB: default behavior of .get_latent_tensor() is to just return the output of .get_image_tensor() + """ try: dummy = self.clone() except NotImplementedError: @@ -120,3 +129,51 @@ def forward(self): return self.decode_training_tensor() else: return self.decode_tensor() + + @property + def representation_parameters(self): + return self.image_representation_parameters._container + + ## yeah I should really make this class an ABC + # if not hasattr(self, "representation_parameters"): + # raise NotImplementedError + # return self.tensor + + +class ImageRepresentationalParameters(nn.Module): + """ + Base class for defining parameters of differentiable images + width: (positive integer) image width in pixels + height: (positive integer) image height in pixels + """ + + def __init__(self, width: int, height: int, z=None, device=None): + super().__init__() + self.width = width + self.height = height + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device + self._container = self._new(z) + + def _new(self, z=None): + if z is None: + # I think this can all go in the constructor and doesn't need to call .to() + z = torch.zeros(1, 3, self.height, self.width).to( + device=self.device, memory_format=torch.channels_last + ) + return nn.Parameter(z) + + +# class LatentTensor(ImageRepresentationalParameters): +# pass +# def __init__(self, z, device=None): +# super().__init__(z.shape[1], z.shape[2], device=device) +# #self._container = z +# self._z = z +# def _new(self): +# return nn.Parameter( +# torch.zeros(1, 3, height, width).to( +# device=self.device, memory_format=torch.channels_last +# ) +# ) diff --git a/src/pytti/image_models/ema.py b/src/pytti/image_models/ema.py index ea04791..1c182aa 100644 --- a/src/pytti/image_models/ema.py +++ b/src/pytti/image_models/ema.py @@ -1,17 +1,190 @@ import torch from torch import nn -from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.differentiable_image import ( + DifferentiableImage, + ImageRepresentationalParameters, +) + +from loguru import logger + + +class EMATensor(nn.Module): + """implmeneted by Katherine Crowson""" + + def __init__(self, tensor, decay): + super().__init__() + self.tensor = nn.Parameter(tensor) + self.register_buffer("biased", torch.zeros_like(tensor)) + self.register_buffer("average", torch.zeros_like(tensor)) + self.decay = decay + self.register_buffer("accum", torch.tensor(1.0)) + self.update() + + @torch.no_grad() + def update(self): + if not self.training: + raise RuntimeError("update() should only be called during training") + + self.accum *= self.decay + self.biased.mul_(self.decay) + self.biased.add_((1 - self.decay) * self.tensor) + self.average.copy_(self.biased) + self.average.div_(1 - self.accum) + + def forward(self): + if self.training: + return self.tensor + return self.average + + def clone(self): + new = EMATensor(self.tensor.clone(), self.decay) + new.accum.copy_(self.accum) + new.biased.copy_(self.biased) + new.average.copy_(self.average) + return new + + def set_(self, other): + self.tensor.set_(other.tensor) + self.accum.set_(other.accum) + self.biased.set_(other.biased) + self.average.set_(other.average) + # self.update() + + @torch.no_grad() + def reset(self): + if not self.training: + raise RuntimeError("reset() should only be called during training") + self.biased.set_(torch.zeros_like(self.biased)) + self.average.set_(torch.zeros_like(self.average)) + self.accum.set_(torch.ones_like(self.accum)) + 
self.update() + + +# class EMAParametersDict(ImageRepresentationalParameters): +class EMAParametersDict(nn.Module): + """ + LatentTensor with a singleton dimension for the EMAParameters + """ + + def __init__(self, z=None, decay=0.99, device=None): + # super(ImageRepresentationalParameters).__init__() + super().__init__() + self.decay = decay + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device + self._container = self._new(z) + + def _new(self, z=None): + if z is None: + # I think this can all go in the constructor and doesn't need to call .to() + return nn.Parameter() + # z = torch.zeros(1, 3, self.height, self.width).to( + # device=self.device, memory_format=torch.channels_last + # ) + # d_ = z + d_ = {} + if isinstance(z, EMAParametersDict): + for k, v in z.items(): + logger.debug(k) + d_[k] = EMATensor(v.tensor, self.decay) + elif isinstance(z, EMATensor): + d_["z"] = z.clone() + elif hasattr(z, "named_parameters"): + d_ = { + name: EMATensor(param, decay=self.decay) + for name, param in z.named_parameters() + } + elif isinstance(z, dict): + print(type(z)) + for k, v in z.items(): + # print(k) + # print(type(v)) + # d_[k] = EMATensor(v, self.decay) + # d_[k] = EMATensor(v.clone(), self.decay) + if isinstance(v, EMATensor): + d_[k] = v.clone() + elif isinstance(z, torch.Tensor): + d_["z"] = EMATensor(z, decay=self.decay) + else: + raise ValueError( + "z must be a dict, torch.nn.Module, EMATensor, EMAParametersDict, or torch.Tensor" + ) + + # if not isinstance(z, dict): + ##if hasattr(z, "named_parameters"): + ## d_ = {name: EMATensor(param, decay=self.decay) for name, param in z.named_parameters()} + # else: + # d_['z'] = EMATensor(z, decay=self.decay) + # else: + # #d_ = {name: EMATensor(param, self.decay) for name, param in z.items()} + # try: + # d_ = {name: EMATensor(param.data, self.decay) for name, param in z.items()} + # except AttributeError: + # d_ = {name: EMATensor(param, self.decay) for name, param in z.items()} + # logger.debug(d_.keys()) + # d_ = torch.nn.ParameterDict(d_) + # d_ = torch.nn.ModuleDict(d_) + return d_ + + def clone(self): + # d_ = {k: v.clone() for k, v in self._container.items()} + d_ = {k: v.clone() for k, v in self._container.items()} + return EMAParametersDict(z=d_, decay=self.decay, device=self.device) + + def update(self): + for param in self._container.values(): + param.update() + + @property + def average(self): + return {k: v.average for k, v in self._container.items()} + + def set_(self, d): + if isinstance(d, torch.Tensor): + logger.debug(self._container) + logger.debug(d.shape) + + d_ = d + if isinstance(d, EMAParametersDict): + d_ = d._container + logger.debug(type(d_)) + # logger.debug(d_.shape) # fuck it + logger.debug(type(self._container)) + for k, v in d_.items(): + self._container[k].set_(v) + # self._container[k].tensor.set_(v) + # self._container[k].tensor.set_(v.tensor) + # self._container[k].tensor.set_(v.tensor) + + def reset(self): + for param in self._container.values(): + param.reset() class EMAImage(DifferentiableImage): + def __init__(self, width, height, tensor=None, decay=0.99, device=None): + super().__init__(width=width, height=height, device=device) + self.image_representation_parameters = EMAParametersDict( + z=tensor, decay=decay, device=device + ) + + +class LatentTensor(EMAImage): + pass + + +class EMAImage_old(DifferentiableImage): """ Base class for differentiable images with Exponential Moving Average filtering Based on code by Katherine Crowson """ def 
__init__(self, width, height, tensor, decay): - super().__init__(width, height) - self.tensor = nn.Parameter(tensor) + # super().__init__(width, height) + super().__init__(width=width, height=height, z=tensor) + # self.representation_parameters = nn.Parameter(tensor) + # self.image_representation_parameters._container = nn.Parameter(tensor) self.register_buffer("biased", torch.zeros_like(tensor)) self.register_buffer("average", torch.zeros_like(tensor)) self.decay = decay @@ -24,7 +197,7 @@ def update(self): raise RuntimeError("update() should only be called during training") self.accum.mul_(self.decay) self.biased.mul_(self.decay) - self.biased.add_((1 - self.decay) * self.tensor) + self.biased.add_((1 - self.decay) * self.representation_parameters) self.average.copy_(self.biased) self.average.div_(1 - self.accum) @@ -38,7 +211,7 @@ def reset(self): self.update() def decode_training_tensor(self): - return self.decode(self.tensor) + return self.decode(self.representation_parameters) def decode_tensor(self): return self.decode(self.average) diff --git a/src/pytti/image_models/rgb_image.py b/src/pytti/image_models/rgb_image.py index e6b39a6..cc02d22 100644 --- a/src/pytti/image_models/rgb_image.py +++ b/src/pytti/image_models/rgb_image.py @@ -6,7 +6,7 @@ from PIL import Image from torch.nn import functional as F - +# why doesn't this inherit from EMA? class RGBImage(DifferentiableImage): """ Naive RGB image representation diff --git a/src/pytti/image_models/vqgan.py b/src/pytti/image_models/vqgan.py index 29210cc..9fe6cfd 100644 --- a/src/pytti/image_models/vqgan.py +++ b/src/pytti/image_models/vqgan.py @@ -11,7 +11,11 @@ from pytti import replace_grad, clamp_with_grad, vram_usage_mode import torch from torch.nn import functional as F -from pytti.image_models import EMAImage + +# from pytti.image_models import EMAImage +# from pytti.image_models.differentiable_image import LatentTensor +# from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage from torchvision.transforms import functional as TF from PIL import Image from omegaconf import OmegaConf @@ -185,20 +189,42 @@ def __init__( self.vqgan_decode = model.decode self.vqgan_encode = model.encode - def clone(self): + ################################# + + # self.image_representation_parameters = LatentTensor( + # width=width, + # height=height, + # z=z, + # device=self.device, + # ) + + def clone(self) -> "VQGANImage": + # dummy = VQGANImage(*self.image_shape) + # with torch.no_grad(): + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + # dummy.accum.set_(self.accum.clone()) + # dummy.biased.set_(self.biased.clone()) + # dummy.average.set_(self.average.clone()) + # dummy.decay = self.decay + # return dummy dummy = VQGANImage(*self.image_shape) with torch.no_grad(): - dummy.tensor.set_(self.tensor.clone()) - dummy.accum.set_(self.accum.clone()) - dummy.biased.set_(self.biased.clone()) - dummy.average.set_(self.average.clone()) - dummy.decay = self.decay + # dummy.representation_parameters.set_(self.representation_parameters.clone()) + dummy.image_representation_parameters.set_( + self.image_representation_parameters.clone() + ) return dummy + @property + def representation_parameters(self): + return self.image_representation_parameters._container.get("z").tensor + def get_latent_tensor(self, detach=False, device=None): if device is None: device = self.device - z = self.tensor + # z = self.representation_parameters._container.get("z") + # z = 
self.image_representation_parameters._container.get("z").tensor + z = self.representation_parameters if detach: z = z.detach() z_q = vector_quantize(z, self.vqgan_quantize_embedding).movedim(3, 1).to(device) @@ -210,6 +236,15 @@ def get_preferred_loss(cls): return LatentLoss + def decode_training_tensor(self): + return self.decode(self.representation_parameters) + + def decode_tensor(self): + # return self.decode(self.average) + # return self.decode(self.representation_parameters.average) + # return self.decode(self.image_representation_parameters.average) + return self.decode(self.image_representation_parameters.average["z"]) + def decode(self, z, device=None): if device is None: device = self.device @@ -226,8 +261,9 @@ def encode_image(self, pil_image, device=None, **kwargs): pil_image = pil_image.resize(self.image_shape, Image.LANCZOS) pil_image = TF.to_tensor(pil_image) z, *_ = self.vqgan_encode(pil_image.unsqueeze(0).to(device) * 2 - 1) - self.tensor.set_(z.movedim(1, 3)) - self.reset() + self.representation_parameters.set_(z.movedim(1, 3)) + # self.reset() + self.image_representation_parameters.reset() @torch.no_grad() def make_latent(self, pil_image, device=None): @@ -245,8 +281,9 @@ def make_latent(self, pil_image, device=None): @torch.no_grad() def encode_random(self): - self.tensor.set_(self.rand_latent()) - self.reset() + self.representation_parameters.set_(self.rand_latent()) + # self.reset() + self.image_representation_parameters.reset() def rand_latent(self, device=None, vqgan_quantize_embedding=None): if device is None: diff --git a/src/pytti/tensor_tools.py b/src/pytti/tensor_tools.py index 6019c11..c3fb353 100644 --- a/src/pytti/tensor_tools.py +++ b/src/pytti/tensor_tools.py @@ -65,6 +65,7 @@ def format_module(module, dest, *args, **kwargs) -> torch.tensor: return format_input(output, module, dest) +# https://pytorch.org/docs/stable/autograd.html#function class ReplaceGrad(torch.autograd.Function): """ returns x_forward during forward pass, but evaluates derivates as though diff --git a/src/pytti/update_func.py b/src/pytti/update_func.py index 1d3606e..b90a98f 100644 --- a/src/pytti/update_func.py +++ b/src/pytti/update_func.py @@ -135,19 +135,11 @@ def save_out( filename = f"backup/{file_namespace}/{base_name}_{n}.bak" torch.save(img.state_dict(), filename) if n > backups: - - # YOOOOOOO let's not start shell processes unnecessarily - # and then execute commands using string interpolation. - # Replace this with a pythonic folder removal, then see - # if we can't deprecate the folder removal entirely. What - # is the purpose of "backups" here? Just use the frames that - # are being written to disk. - subprocess.run( - [ - "rm", - f"backup/{file_namespace}/{base_name}_{n-backups}.bak", - ] - ) + fname = f"{base_name}_{n-backups}.bak" + fpath = Path("backup") / file_namespace / fname + fpath.unlink( + missing_ok=True + ) # delete the file. if file not found, nothing happens. 
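Two small assumptions ride along with the pathlib change above, since the rest of update_func.py is not visible in this hunk: Path has to be imported in the module, and unlink(missing_ok=True) requires Python 3.8 or newer. On older interpreters the equivalent, reusing the same variables as the hunk, would be:

from pathlib import Path  # only needed if update_func.py doesn't import it already

fpath = Path("backup") / file_namespace / f"{base_name}_{n-backups}.bak"
try:
    fpath.unlink()  # remove the stale backup frame
except FileNotFoundError:
    pass  # same behavior as missing_ok=True: nothing to delete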
j = i + 1 diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index d579831..43385d2 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -38,7 +38,7 @@ ) from pytti.rotoscoper import ROTOSCOPERS, get_frames -from pytti.image_models import PixelImage, RGBImage, VQGANImage +from pytti.image_models import PixelImage, RGBImage, VQGANImage, DeepImagePrior from pytti.ImageGuide import DirectImageGuide from pytti.Perceptor.Embedder import HDMultiClipEmbedder from pytti.Perceptor.Prompt import parse_prompt @@ -316,6 +316,7 @@ def do_run(): "Limited Palette", "Unlimited Palette", "VQGAN", + "Deep Image Prior", ) # set up image if params.image_model == "Limited Palette": @@ -349,6 +350,9 @@ def do_run(): params.width, params.height, params.pixel_size, device=_device ) img.encode_random() + elif params.image_model == "Deep Image Prior": + img = DeepImagePrior(params.width, params.height, params.pixel_size) + img.encode_random() else: logger.critical( "You should never see this message." @@ -365,6 +369,8 @@ def do_run(): # set up init image # ##################### + logger.debug("configuring init image prompts") + ( init_augs, semantic_init_prompt, @@ -383,16 +389,23 @@ def do_run(): ) # other image prompts + logger.debug("configuring other image prompts") loss_augs.extend( type(img) .get_preferred_loss() - .TargetImage(p.strip(), img.image_shape, is_path=True) + .TargetImage( + p.strip(), + img.image_shape, + is_path=True, + # img_model=type(img) + ) for p in params.direct_image_prompts.split("|") if p.strip() ) # stabilization + logger.debug("configuring stabilization losses") ( loss_augs, img, @@ -456,6 +469,8 @@ def do_run(): # img, # ) = loss_orch.configure_losses() + logger.debug("losses configured.") + # Phase 4 - setup outputs ########################## diff --git a/tests/image_models/test_deep_image_prior.py b/tests/image_models/test_deep_image_prior.py new file mode 100644 index 0000000..19725ce --- /dev/null +++ b/tests/image_models/test_deep_image_prior.py @@ -0,0 +1,43 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage +from pytti.image_models.deep_image_prior import DeepImagePrior + +### DIP ### + + +def test_dip_init(): + obj = DeepImagePrior(512, 512) + assert obj + + +def test_dip_update(): + obj = DeepImagePrior(512, 512) + obj.update() + + +def test_dip_forward(): + obj = DeepImagePrior(512, 512) + obj.forward() + + +def test_dip_decode_training_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_training_tensor() + + +def test_dip_decode_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_tensor() + + +def test_dip_clone(): + obj = DeepImagePrior(512, 512) + obj.clone() + + +def test_dip_get_latent_tensor(): + obj = DeepImagePrior(10, 10) + obj.get_latent_tensor() diff --git a/tests/image_models/test_im_vqgan.py b/tests/image_models/test_im_vqgan.py new file mode 100644 index 0000000..332dd77 --- /dev/null +++ b/tests/image_models/test_im_vqgan.py @@ -0,0 +1,49 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.vqgan import VQGANImage + +from pathlib import Path + + +def test_init_model(): + models_parent_dir = "~/.cache/vqgan/" + vqgan_model = "coco" + model_artifacts_path = Path(models_parent_dir) / "vqgan" + 
VQGANImage.init_vqgan(vqgan_model, model_artifacts_path) + img = VQGANImage(512, 512, 1) + # img.encode_random() + + +def test_init(): + obj = VQGANImage(512, 512) + assert obj + + +def test_update(): + obj = VQGANImage(512, 512) + obj.update() + + +def test_forward(): + obj = VQGANImage(512, 512) + obj.forward() + + +def test_decode_training_tensor(): + obj = VQGANImage(512, 512) + obj.decode_training_tensor() + + +def test_decode_tensor(): + obj = VQGANImage(512, 512) + obj.decode_tensor() + + +def test_clone(): + obj = VQGANImage(512, 512) + obj.clone() + + +def test_get_latent_tensor(): + obj = VQGANImage(512, 512) + obj.get_latent_tensor() diff --git a/tests/test_image_models.py b/tests/test_image_models.py new file mode 100644 index 0000000..19725ce --- /dev/null +++ b/tests/test_image_models.py @@ -0,0 +1,43 @@ +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage +from pytti.image_models.deep_image_prior import DeepImagePrior + +### DIP ### + + +def test_dip_init(): + obj = DeepImagePrior(512, 512) + assert obj + + +def test_dip_update(): + obj = DeepImagePrior(512, 512) + obj.update() + + +def test_dip_forward(): + obj = DeepImagePrior(512, 512) + obj.forward() + + +def test_dip_decode_training_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_training_tensor() + + +def test_dip_decode_tensor(): + obj = DeepImagePrior(512, 512) + obj.decode_tensor() + + +def test_dip_clone(): + obj = DeepImagePrior(512, 512) + obj.clone() + + +def test_dip_get_latent_tensor(): + obj = DeepImagePrior(10, 10) + obj.get_latent_tensor()
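The smoke tests above only cover construction, decoding, and cloning. A natural follow-up, not part of this patch and assuming a CUDA device is available (the constructor's default device for DeepImagePrior), would be a short end-to-end fit through encode_image:

from PIL import Image

from pytti.image_models.deep_image_prior import DeepImagePrior


def test_dip_encode_image():
    # keep the fit very short so the test stays cheap; 256x256 leaves room
    # for the DIP net's seven downsampling scales
    obj = DeepImagePrior(256, 256, image_encode_steps=2)
    target = Image.new("RGB", (256, 256), "red")
    obj.encode_image(target)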