From 0305a57700169d8f7dc07c2c044709ca3489213c Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 11:06:25 -0700 Subject: [PATCH 1/2] added some image model unit tests --- tests/test_image_models.py | 76 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/test_image_models.py diff --git a/tests/test_image_models.py b/tests/test_image_models.py new file mode 100644 index 0000000..f574432 --- /dev/null +++ b/tests/test_image_models.py @@ -0,0 +1,76 @@ +import pytest +from loguru import logger +import torch + +import pytti.image_models +from pytti.image_models.differentiable_image import DifferentiableImage +from pytti.image_models.ema import EMAImage +from pytti.image_models.pixel import PixelImage +from pytti.image_models.rgb_image import RGBImage +from pytti.image_models.vqgan import VQGANImage + + +@pytest.mark.parametrize( + "ImageModel", + [DifferentiableImage, RGBImage], +) +def test_simple_image_models(ImageModel): + """ + Test that the image models can be instantiated + """ + image = ImageModel( + width=10, + height=10, + ) + assert image + + +def test_ema_image(): + """ + Test that the EMAImage can be instantiated + """ + image = EMAImage( + width=10, + height=10, + tensor=torch.zeros(10, 10), + decay=0.5, + ) + assert image + + +def test_pixel_image(): + """ + Test that the PixelImage can be instantiated + """ + image = PixelImage( + width=10, + height=10, + scale=1, + pallet_size=1, + n_pallets=1, + ) + assert image + + +# def test_vqgan_image_valid(): +# """ +# Test that the VQGANImage can be instantiated +# """ +# image = VQGANImage( +# width=10, +# height=10, +# model=SOME_VQGAN_MODEL, +# ) +# assert image + + +def test_vqgan_image_invalid_string(): + """ + Test that the VQGANImage can be instantiated + """ + with pytest.raises(AttributeError): + image = VQGANImage( + width=10, + height=10, + model="this isn't actually a valid value for this field", + ) From 8ae22dde60c7e748633d9bb7d38579d1f7a66193 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 11:19:37 -0700 Subject: [PATCH 2/2] cleaned up tests some, added some logging for my own curiosity --- tests/test_image_models.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/tests/test_image_models.py b/tests/test_image_models.py index f574432..8c8b1c2 100644 --- a/tests/test_image_models.py +++ b/tests/test_image_models.py @@ -10,21 +10,36 @@ from pytti.image_models.vqgan import VQGANImage -@pytest.mark.parametrize( - "ImageModel", - [DifferentiableImage, RGBImage], -) -def test_simple_image_models(ImageModel): +## simple models ## + + +def test_differentiabble_image_model(): + """ + Test that the DifferentiableImage can be instantiated + """ + image = DifferentiableImage( + width=10, + height=10, + ) + logger.debug(image.output_axes) # x y s + assert image + + +def test_rgb_image_model(): """ - Test that the image models can be instantiated + Test that the RGBImage can be instantiated """ - image = ImageModel( + image = RGBImage( width=10, height=10, ) + logger.debug(image.output_axes) # n x y s ... when does n != 1? assert image +## more complex models ## + + def test_ema_image(): """ Test that the EMAImage can be instantiated @@ -35,6 +50,7 @@ def test_ema_image(): tensor=torch.zeros(10, 10), decay=0.5, ) + logger.debug(image.output_axes) # x y s assert image @@ -49,6 +65,7 @@ def test_pixel_image(): pallet_size=1, n_pallets=1, ) + logger.debug(image.output_axes) # n s y x ... uh ok, sure. assert image @@ -61,6 +78,7 @@ def test_pixel_image(): # height=10, # model=SOME_VQGAN_MODEL, # ) +# logger.debug(image.output_axes) # assert image @@ -74,3 +92,4 @@ def test_vqgan_image_invalid_string(): height=10, model="this isn't actually a valid value for this field", ) + logger.debug(image.output_axes)