From 0305a57700169d8f7dc07c2c044709ca3489213c Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 11:06:25 -0700 Subject: [PATCH 1/8] added some image model unit tests --- tests/test_image_models.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/test_image_models.py diff --git a/tests/test_image_models.py b/tests/test_image_models.py new file mode 100644 index 0000000..f574432 --- /dev/null +++ b/tests/test_image_models.py @@ -0,0 +1,76 @@ +import pytest +from loguru import logger +import torch + +import pytti.image_models +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage + + +@pytest.mark.parametrize( + "ImageModel", + [DifferentiableImage, RGBImage], +) +def test_simple_image_models(ImageModel): + """ + Test that the image models can be instantiated + """ + image = ImageModel( + width=10, + height=10, + ) + assert image + + +def test_ema_image(): + """ + Test that the EMAImage can be instantiated + """ + image = EMAImage( + width=10, + height=10, + tensor=torch.zeros(10, 10), + decay=0.5, + ) + assert image + + +def test_pixel_image(): + """ + Test that the PixelImage can be instantiated + """ + image = PixelImage( + width=10, + height=10, + scale=1, + pallet_size=1, + n_pallets=1, + ) + assert image + + +# def test_vqgan_image_valid(): +# """ +# Test that the VQGANImage can be instantiated +# """ +# image = VQGANImage( +# width=10, +# height=10, +# model=SOME_VQGAN_MODEL, +# ) +# assert image + + +def test_vqgan_image_invalid_string(): + """ + Test that the VQGANImage can be instantiated + """ + with pytest.raises(AttributeError): + image = VQGANImage( + width=10, + height=10, + model="this isn't actually a valid value for this field", + ) From 8ae22dde60c7e748633d9bb7d38579d1f7a66193 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 11:19:37 -0700 Subject: [PATCH 2/8] cleaned up tests some, added some logging for my own curiosity --- tests/test_image_models.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/tests/test_image_models.py b/tests/test_image_models.py index f574432..8c8b1c2 100644 --- a/tests/test_image_models.py +++ b/tests/test_image_models.py @@ -10,21 +10,36 @@ from pytti.image_models.vqgan import VQGANImage -@pytest.mark.parametrize( - "ImageModel", - [DifferentiableImage, RGBImage], -) -def test_simple_image_models(ImageModel): +## simple models ## + + +def test_differentiabble_image_model(): + """ + Test that the DifferentiableImage can be instantiated + """ + image = DifferentiableImage( + width=10, + height=10, + ) + logger.debug(image.output_axes) # x y s + assert image + + +def test_rgb_image_model(): """ - Test that the image models can be instantiated + Test that the RGBImage can be instantiated """ - image = ImageModel( + image = RGBImage( width=10, height=10, ) + logger.debug(image.output_axes) # n x y s ... when does n != 1? assert image +## more complex models ## + + def test_ema_image(): """ Test that the EMAImage can be instantiated @@ -35,6 +50,7 @@ def test_ema_image(): tensor=torch.zeros(10, 10), decay=0.5, ) + logger.debug(image.output_axes) # x y s assert image @@ -49,6 +65,7 @@ def test_pixel_image(): pallet_size=1, n_pallets=1, ) + logger.debug(image.output_axes) # n s y x ... uh ok, sure. assert image @@ -61,6 +78,7 @@ def test_pixel_image(): # height=10, # model=SOME_VQGAN_MODEL, # ) +# logger.debug(image.output_axes) # assert image @@ -74,3 +92,4 @@ def test_vqgan_image_invalid_string(): height=10, model="this isn't actually a valid value for this field", ) + logger.debug(image.output_axes) From c52b0a8015f1d4ce11314b850366e4242ccda3a5 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 26 May 2022 19:33:56 -0700 Subject: [PATCH 3/8] isolated cutout factory --- src/pytti/Perceptor/Embedder.py | 91 ++++++------------ src/pytti/Perceptor/cutouts/__init__.py | 0 src/pytti/Perceptor/samplers.py | 117 ++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 62 deletions(-) create mode 100644 src/pytti/Perceptor/cutouts/__init__.py create mode 100644 src/pytti/Perceptor/samplers.py diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index 3af72cf..9123422 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -12,6 +12,8 @@ import kornia.augmentation as K +from .samplers import pytti_classic + PADDING_MODES = { "mirror": "reflect", "smear": "replicate", @@ -64,69 +66,34 @@ def __init__( self.border_mode = border_mode def make_cutouts( - self, input: torch.Tensor, side_x, side_y, cut_size, device=DEVICE + self, + input: torch.Tensor, + side_x, + side_y, + cut_size, + #### + # padding, + # cutn, + # cut_pow, + # border_mode, + # augs, + # noise_fac, + #### + device=DEVICE, ) -> Tuple[list, list, list]: - min_size = min(side_x, side_y, cut_size) - max_size = min(side_x, side_y) - paddingx = min(round(side_x * self.padding), side_x) - paddingy = min(round(side_y * self.padding), side_y) - cutouts = [] - offsets = [] - sizes = [] - for _ in range(self.cutn): - # mean is 0.8 - # varience is 0.3 - size = int( - max_size - * ( - torch.zeros( - 1, - ) - .normal_(mean=0.8, std=0.3) - .clip(cut_size / max_size, 1.0) - ** self.cut_pow - ) - ) - offsetx_max = side_x - size + 1 - offsety_max = side_y - size + 1 - if self.border_mode == "clamp": - offsetx = torch.clamp( - (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx) - .floor() - .int(), - 0, - offsetx_max, - ) - offsety = torch.clamp( - (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy) - .floor() - .int(), - 0, - offsety_max, - ) - cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] - else: - px = min(size, paddingx) - py = min(size, paddingy) - offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int() - offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int() - cutout = input[ - :, - :, - paddingy + offsety : paddingy + offsety + size, - paddingx + offsetx : paddingx + offsetx + size, - ] - cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size)) - offsets.append( - torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device) - ) - sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device)) - cutouts = self.augs(torch.cat(cutouts)) - offsets = torch.cat(offsets) - sizes = torch.cat(sizes) - if self.noise_fac: - facs = cutouts.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) - cutouts.add_(facs * torch.randn_like(cutouts)) + cutouts, offsets, sizes = pytti_classic( + input=input, + side_x=side_x, + side_y=side_y, + cut_size=cut_size, + padding=self.padding, + cutn=self.cutn, + cut_pow=self.cut_pow, + border_mode=self.border_mode, + augs=self.augs, + noise_fac=self.noise_fac, + device=DEVICE, + ) return cutouts, offsets, sizes def forward( diff --git a/src/pytti/Perceptor/cutouts/__init__.py b/src/pytti/Perceptor/cutouts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pytti/Perceptor/samplers.py b/src/pytti/Perceptor/samplers.py new file mode 100644 index 0000000..086680d --- /dev/null +++ b/src/pytti/Perceptor/samplers.py @@ -0,0 +1,117 @@ +""" +Methods for obtaining cutouts, agnostic to augmentations. + +Cutout choices have a significant impact on the performance of the perceptors and the +overall look of the image. + +The objects defined here probably are only being used in pytti.Perceptor.cutouts.Embedder.HDMultiClipEmbedder, but they +should be sufficiently general for use in notebooks without pyttitools otherwise in use. +""" + +import torch +from typing import Tuple +from torch.nn import functional as F + +PADDING_MODES = { + "mirror": "reflect", + "smear": "replicate", + "wrap": "circular", + "black": "constant", +} + +# ( +# cut_size = 64 +# cut_pow = 0.5 +# noise_fac = 0.0 +# cutn = 8 +# border_mode = "clamp" +# augs = None +# return Cutout( +# cut_size=cut_size, +# cut_pow=cut_pow, +# noise_fac=noise_fac, +# cutn=cutn, +# border_mode=border_mode, +# augs=augs, +# ) + + +def pytti_classic( + # self, + input: torch.Tensor, + side_x, + side_y, + cut_size, + padding, + cutn, + cut_pow, + border_mode, + augs, + noise_fac, + device, +) -> Tuple[list, list, list]: + """ + This is the cutout method that was already in use in the original pytti. + """ + min_size = min(side_x, side_y, cut_size) + max_size = min(side_x, side_y) + paddingx = min(round(side_x * padding), side_x) + paddingy = min(round(side_y * padding), side_y) + cutouts = [] + offsets = [] + sizes = [] + for _ in range(cutn): + # mean is 0.8 + # varience is 0.3 + size = int( + max_size + * ( + torch.zeros( + 1, + ) + .normal_(mean=0.8, std=0.3) + .clip(cut_size / max_size, 1.0) + ** cut_pow + ) + ) + offsetx_max = side_x - size + 1 + offsety_max = side_y - size + 1 + if border_mode == "clamp": + offsetx = torch.clamp( + (torch.rand([]) * (offsetx_max + 2 * paddingx) - paddingx) + .floor() + .int(), + 0, + offsetx_max, + ) + offsety = torch.clamp( + (torch.rand([]) * (offsety_max + 2 * paddingy) - paddingy) + .floor() + .int(), + 0, + offsety_max, + ) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + else: + px = min(size, paddingx) + py = min(size, paddingy) + offsetx = (torch.rand([]) * (offsetx_max + 2 * px) - px).floor().int() + offsety = (torch.rand([]) * (offsety_max + 2 * py) - py).floor().int() + cutout = input[ + :, + :, + paddingy + offsety : paddingy + offsety + size, + paddingx + offsetx : paddingx + offsetx + size, + ] + cutouts.append(F.adaptive_avg_pool2d(cutout, cut_size)) + offsets.append( + torch.as_tensor([[offsetx / side_x, offsety / side_y]]).to(device) + ) + sizes.append(torch.as_tensor([[size / side_x, size / side_y]]).to(device)) + cutouts = augs(torch.cat(cutouts)) + offsets = torch.cat(offsets) + sizes = torch.cat(sizes) + if noise_fac: + facs = cutouts.new_empty([cutn, 1, 1, 1]).uniform_(0, noise_fac) + cutouts.add_(facs * torch.randn_like(cutouts)) + return cutouts, offsets, sizes From 512e6515e7f2315f2ea829dfb46dbc37e3af8df0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 26 May 2022 19:44:28 -0700 Subject: [PATCH 4/8] refactor --- src/pytti/Perceptor/Embedder.py | 2 +- src/pytti/Perceptor/{ => cutouts}/samplers.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/pytti/Perceptor/{ => cutouts}/samplers.py (100%) diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index 9123422..fd2d17e 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -12,7 +12,7 @@ import kornia.augmentation as K -from .samplers import pytti_classic +from .cutouts.samplers import pytti_classic PADDING_MODES = { "mirror": "reflect", diff --git a/src/pytti/Perceptor/samplers.py b/src/pytti/Perceptor/cutouts/samplers.py similarity index 100% rename from src/pytti/Perceptor/samplers.py rename to src/pytti/Perceptor/cutouts/samplers.py From 288d1609003a9542eabe0afca7d28348debb45ad Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 26 May 2022 20:16:35 -0700 Subject: [PATCH 5/8] refactor cutout augs --- src/pytti/Perceptor/Embedder.py | 24 ++++++++---------------- src/pytti/Perceptor/cutouts/augs.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 16 deletions(-) create mode 100644 src/pytti/Perceptor/cutouts/augs.py diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index fd2d17e..b44b91a 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -10,9 +10,13 @@ from torch import nn from torch.nn import functional as F -import kornia.augmentation as K -from .cutouts.samplers import pytti_classic +# import .cutouts +# import .cutouts as cutouts +# import cutouts + +from .cutouts import augs as cutouts_augs +from .cutouts import samplers as cutouts_samplers PADDING_MODES = { "mirror": "reflect", @@ -45,19 +49,7 @@ def __init__( self.cut_sizes = [p.visual.input_resolution for p in perceptors] self.cutn = cutn self.noise_fac = noise_fac - self.augs = nn.Sequential( - K.RandomHorizontalFlip(p=0.3), - K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), - K.RandomPerspective( - 0.2, - p=0.4, - ), - K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), - K.RandomErasing( - scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7 - ), - nn.Identity(), - ) + self.augs = cutouts_augs.pytti_classic() self.input_axes = ("n", "s", "y", "x") self.output_axes = ("c", "n", "i") self.perceptors = perceptors @@ -81,7 +73,7 @@ def make_cutouts( #### device=DEVICE, ) -> Tuple[list, list, list]: - cutouts, offsets, sizes = pytti_classic( + cutouts, offsets, sizes = cutouts_samplers.pytti_classic( input=input, side_x=side_x, side_y=side_y, diff --git a/src/pytti/Perceptor/cutouts/augs.py b/src/pytti/Perceptor/cutouts/augs.py new file mode 100644 index 0000000..32873dd --- /dev/null +++ b/src/pytti/Perceptor/cutouts/augs.py @@ -0,0 +1,18 @@ +import kornia.augmentation as K +from torch import nn + + +def pytti_classic(): + return nn.Sequential( + K.RandomHorizontalFlip(p=0.3), + K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), + K.RandomPerspective( + 0.2, + p=0.4, + ), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + K.RandomErasing( + scale=(0.1, 0.4), ratio=(0.3, 1 / 0.3), same_on_batch=False, p=0.7 + ), + nn.Identity(), + ) From 4f2e5294ef84f45dae9bb2ea549bef798b996039 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 11:55:37 -0700 Subject: [PATCH 6/8] added output type hints to diffimage methods --- .../image_models/differentiable_image.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/pytti/image_models/differentiable_image.py b/src/pytti/image_models/differentiable_image.py index 79ea590..c1e35b9 100644 --- a/src/pytti/image_models/differentiable_image.py +++ b/src/pytti/image_models/differentiable_image.py @@ -4,6 +4,10 @@ from PIL import Image from pytti.tensor_tools import named_rearrange +# for typing +import torch +from pytti.LossAug.BaseLossClass import Loss + SUPPORTED_MODES = ["L", "RGB", "I", "F"] @@ -25,13 +29,13 @@ def __init__(self, width: int, height: int, pixel_format: str = "RGB"): self.lr = 0.02 self.latent_strength = 0 - def decode_training_tensor(self): + def decode_training_tensor(self) -> torch.Tensor: """ returns a decoded tensor of this image for training """ return self.decode_tensor() - def get_image_tensor(self): + def get_image_tensor(self) -> torch.Tensor: """ optional method: returns an [n x w_i x h_i] tensor representing the local image data those data will be used for animation if afforded @@ -41,26 +45,26 @@ def get_image_tensor(self): def clone(self): raise NotImplementedError - def get_latent_tensor(self, detach=False): + def get_latent_tensor(self, detach=False) -> torch.Tensor: if detach: return self.get_image_tensor().detach() else: return self.get_image_tensor() - def set_image_tensor(self, tensor): + def set_image_tensor(self, tensor: torch.Tensor): """ optional method: accepts an [n x w_i x h_i] tensor representing the local image data those data will be by the animation system """ raise NotImplementedError - def decode_tensor(self): + def decode_tensor(self) -> torch.Tensor: """ returns a decoded tensor of this image """ raise NotImplementedError - def encode_image(self, pil_image): + def encode_image(self, pil_image: Image): """ overwrites this image with the input image pil_image: (Image) input image @@ -79,7 +83,7 @@ def update(self): """ pass - def make_latent(self, pil_image): + def make_latent(self, pil_image: Image) -> torch.Tensor: try: dummy = self.clone() except NotImplementedError: @@ -88,7 +92,7 @@ def make_latent(self, pil_image): return dummy.get_latent_tensor(detach=True) @classmethod - def get_preferred_loss(cls): + def get_preferred_loss(cls) -> Loss: from pytti.LossAug.HSVLossClass import HSVLoss return HSVLoss @@ -96,7 +100,7 @@ def get_preferred_loss(cls): def image_loss(self): return [] - def decode_image(self): + def decode_image(self) -> Image: """ render a PIL Image version of this image """ @@ -112,7 +116,7 @@ def decode_image(self): ) return Image.fromarray(array) - def forward(self): + def forward(self) -> torch.Tensor: """ returns a decoded tensor of this image """ From 805ef7e0b60fc06111b4057beed1c19deeca1b7b Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 12:01:36 -0700 Subject: [PATCH 7/8] added logging preferred loss introspection --- tests/test_image_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_image_models.py b/tests/test_image_models.py index 8c8b1c2..4bb6172 100644 --- a/tests/test_image_models.py +++ b/tests/test_image_models.py @@ -17,11 +17,15 @@ def test_differentiabble_image_model(): """ Test that the DifferentiableImage can be instantiated """ + logger.debug( + DifferentiableImage.get_preferred_loss() + ) # pytti.LossAug.HSVLossClass.HSVLoss image = DifferentiableImage( width=10, height=10, ) logger.debug(image.output_axes) # x y s + # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -29,11 +33,13 @@ def test_rgb_image_model(): """ Test that the RGBImage can be instantiated """ + logger.debug(RGBImage.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss image = RGBImage( width=10, height=10, ) logger.debug(image.output_axes) # n x y s ... when does n != 1? + # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -44,6 +50,7 @@ def test_ema_image(): """ Test that the EMAImage can be instantiated """ + logger.debug(EMAImage.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss image = EMAImage( width=10, height=10, @@ -51,6 +58,7 @@ def test_ema_image(): decay=0.5, ) logger.debug(image.output_axes) # x y s + # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -58,6 +66,7 @@ def test_pixel_image(): """ Test that the PixelImage can be instantiated """ + logger.debug(PixelImage.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss image = PixelImage( width=10, height=10, @@ -66,6 +75,7 @@ def test_pixel_image(): n_pallets=1, ) logger.debug(image.output_axes) # n s y x ... uh ok, sure. + # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -86,6 +96,9 @@ def test_vqgan_image_invalid_string(): """ Test that the VQGANImage can be instantiated """ + logger.debug( + VQGANImage.get_preferred_loss() + ) # pytti.LossAug.LatentLossClass.LatentLoss with pytest.raises(AttributeError): image = VQGANImage( width=10, From 42817cefc0455d15c690c095d9e6922455031d09 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 14:36:52 -0700 Subject: [PATCH 8/8] added image_model.lr logging. vqgan is only model with non-standard default lr? --- tests/test_image_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_image_models.py b/tests/test_image_models.py index 4bb6172..4d779ec 100644 --- a/tests/test_image_models.py +++ b/tests/test_image_models.py @@ -25,6 +25,7 @@ def test_differentiabble_image_model(): height=10, ) logger.debug(image.output_axes) # x y s + logger.debug(image.lr) # 0.02 # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -39,6 +40,7 @@ def test_rgb_image_model(): height=10, ) logger.debug(image.output_axes) # n x y s ... when does n != 1? + logger.debug(image.lr) # 0.02 # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -58,6 +60,7 @@ def test_ema_image(): decay=0.5, ) logger.debug(image.output_axes) # x y s + logger.debug(image.lr) # 0.02 # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -75,6 +78,7 @@ def test_pixel_image(): n_pallets=1, ) logger.debug(image.output_axes) # n s y x ... uh ok, sure. + logger.debug(image.lr) # 0.02 # logger.debug(image.get_preferred_loss()) # pytti.LossAug.HSVLossClass.HSVLoss assert image @@ -89,6 +93,7 @@ def test_pixel_image(): # model=SOME_VQGAN_MODEL, # ) # logger.debug(image.output_axes) +# logger.debug(image.lr) ### self.lr = 0.15 if VQGAN_IS_GUMBEL else 0.1 # assert image