[PoC] Add KFold - External Loop. #8715

Closed
wants to merge 19 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))


- Added `KFoldLoop` example ([#8715](https://github.com/PyTorchLightning/pytorch-lightning/pull/8715))


### Changed

- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))
216 changes: 216 additions & 0 deletions pl_examples/loops_customisation/k_fold.py
@@ -0,0 +1,216 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
WARNING: Loop customization is in pre-alpha and the API is likely to change significantly.
Please open issues with your particular requests so the Lightning team can progressively converge on a great API.
"""

from typing import Any, Dict, List, Optional, Type

import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, Subset

from pytorch_lightning import _logger as log
from pytorch_lightning import LightningDataModule, seed_everything
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.loops.external_loop import ExternalLoop
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset

seed_everything(42)
Review comment (Member): rather not seed anything globally.
Suggested change: remove the `seed_everything(42)` call.



class BaseDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.non_picklable = None
Review comment (Member): curious, what was the idea here? Seems left over :)

self.checkpoint_state: Optional[str] = None
Review comment (Member): same question here :)


self._train_dataset: Optional[Dataset] = None
self._val_dataset: Optional[Dataset] = None
self._test_dataset: Optional[Dataset] = None
self._predict_dataset: Optional[Dataset] = None

self._processed_train_dataset: Optional[Dataset] = None
self._processed_val_dataset: Optional[Dataset] = None
self._processed_test_dataset: Optional[Dataset] = None
self._processed_predict_dataset: Optional[Dataset] = None

@property
def train_dataset(self) -> Optional[Dataset]:
return self._train_dataset

@property
def val_dataset(self) -> Optional[Dataset]:
return self._val_dataset

@property
def test_dataset(self) -> Optional[Dataset]:
return self._test_dataset

@property
def predict_dataset(self) -> Optional[Dataset]:
return self._predict_dataset

@property
def processed_train_dataset(self) -> Optional[Dataset]:
return self._processed_train_dataset or self.train_dataset

@property
def processed_val_dataset(self) -> Optional[Dataset]:
return self._processed_val_dataset or self.val_dataset

@property
def processed_test_dataset(self) -> Optional[Dataset]:
return self._processed_test_dataset or self.test_dataset

@property
def processed_predict_dataset(self) -> Optional[Dataset]:
return self._processed_predict_dataset or self.predict_dataset

@processed_train_dataset.setter
def processed_train_dataset(self, processed_train_dataset) -> None:
self._processed_train_dataset = processed_train_dataset

@processed_val_dataset.setter
def processed_val_dataset(self, processed_val_dataset) -> None:
self._processed_val_dataset = processed_val_dataset

@processed_predict_dataset.setter
def processed_predict_dataset(self, processed_predict_dataset) -> None:
self._processed_predict_dataset = processed_predict_dataset

@processed_test_dataset.setter
def processed_test_dataset(self, processed_test_dataset) -> None:
self._processed_test_dataset = processed_test_dataset

def train_dataloader(self) -> DataLoader:
return DataLoader(self.processed_train_dataset)

def val_dataloader(self) -> DataLoader:
return DataLoader(self.processed_val_dataset)

def test_dataloader(self) -> DataLoader:
return DataLoader(self.processed_test_dataset)

def predict_dataloader(self) -> DataLoader:
return DataLoader(self.processed_predict_dataset)


class BoringDataModule(BaseDataModule):
def prepare_data(self) -> None:
self.random_full = RandomDataset(32, 64 * 4)

def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
self._train_dataset = Subset(self.random_full, indices=range(64))
self.dims = self._train_dataset[0].shape

if stage in ("fit", "validate") or stage is None:
self._val_dataset = Subset(self.random_full, indices=range(64, 64 * 2))

if stage == "test" or stage is None:
self._test_dataset = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
self.dims = getattr(self, "dims", self._test_dataset[0].shape)

if stage == "predict" or stage is None:
self._predict_dataset = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
self.dims = getattr(self, "dims", self._predict_dataset[0].shape)
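To make the intent of the `processed_*` properties concrete, here is a minimal sketch (not part of the diff) of how an outer loop can swap a per-fold subset into the datamodule without rebuilding it. It reuses `BoringDataModule` and the imports at the top of this file; the fold indices are placeholders:

```python
from torch.utils.data.dataset import Subset

dm = BoringDataModule()
dm.prepare_data()
dm.setup("fit")

# stand-in for the indices a splitter (e.g. KFold) would produce for one fold
fold_indices = list(range(32))

# the raw `train_dataset` stays untouched; only the processed view changes,
# and `train_dataloader()` automatically serves the processed dataset
dm.processed_train_dataset = Subset(dm.train_dataset, fold_indices)
assert len(dm.train_dataloader().dataset) == len(fold_indices)
```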


class KFoldLoop(ExternalLoop):
def __init__(
self,
num_folds: int,
best_model_paths: List[str] = [],
restarting: bool = False,
):
super().__init__()
self.num_folds = num_folds
self.best_model_paths = best_model_paths
self.restarting = restarting

@staticmethod
def loop_base_callback() -> Type[Callback]:
class BaseKFoldCallback(Callback):
@rank_zero_only
def on_fold_start(self, trainer, pl_module, counter):
"""Override with your own logic"""

return BaseKFoldCallback
Review comment (Member), on lines +148 to +154: can't we define this outside this class but in the file namespace?


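Regarding the comment above, a rough sketch of the module-level alternative (illustrative only, not part of this PR) — the base callback lives in the file namespace and `loop_base_callback` simply returns it; the class names are hypothetical:

```python
class BaseKFoldCallback(Callback):
    """Defined once at module level instead of inside `loop_base_callback`."""

    @rank_zero_only
    def on_fold_start(self, trainer, pl_module, counter):
        """Override with your own logic."""


class KFoldLoopAlt(ExternalLoop):
    @staticmethod
    def loop_base_callback() -> Type[Callback]:
        return BaseKFoldCallback
```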
@property
def done(self) -> bool:
return self.current_fold >= self.num_folds

def reset(self) -> None:
if not self.restarting:
self.current_fold = 0

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
# temporary hack
self.trainer.datamodule.setup("fit")

def on_advance_start(self) -> None:
# re-create a fresh trainer for each fold: more reproducible than reusing a single trainer
self.create_trainer(max_epochs=np.random.randint(10))
# reload dataset for the current fold
dm = self.trainer.datamodule
dm.processed_train_dataset = self.process_dataset("train", dm.train_dataset)
dm.processed_val_dataset = self.process_dataset("val", dm.val_dataset)
# call user hook
self.trainer.call_hook("on_fold_start", self.current_fold)
# reset model parameters
self.trainer.lightning_module.reset_parameters()

def advance(self) -> Any:
# dataloaders will be automatically reloaded
return self.trainer.fit(self.trainer.lightning_module, datamodule=self.trainer.datamodule)

def on_advance_end(self) -> None:
self.current_fold += 1
# store the best model path for this fold
self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path)

# utilities for creating a fold
def process_dataset(self, stage: str, dataset: Dataset) -> Subset:
kfold = KFold(self.num_folds, random_state=42, shuffle=True)
Review comment (justusschock, Aug 10, 2021): is a dependency on sklearn worth it for just this? Should we maybe have a more general abstract function `create_splits` that the user has to implement? There are so many different ways to create data splits, and we would then only iterate over the splits here. (A rough sketch of this idea follows the file below.)

train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold]
indices = train_indices if stage == "train" else validation_indices
return Subset(dataset, indices.tolist())

def on_save_checkpoint(self) -> Dict:
return {"current_fold": self.current_fold}

def on_load_checkpoint(self, state_dict) -> None:
self.current_fold = state_dict["current_fold"]


class KFoldCallback(KFoldLoop.loop_base_callback()):

"""This callback demonstrates how to implement your own callback API."""

@rank_zero_only
def on_fold_start(self, trainer, pl_module, counter) -> None:
log.info(f"Starting to train on fold {counter}")


loop = KFoldLoop(5)
model = BoringModel()
datamodule = BoringDataModule()
loop.connect_trainer(max_epochs=10, callbacks=KFoldCallback())
Review comment (Member): alternatively, these could be passed in through `__init__` via a `trainer_kwargs` argument.

loop.run(model, datamodule=datamodule)
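Following up on the sklearn comment above, a rough sketch of what the suggested `create_splits` abstraction could look like (names and API are hypothetical, not part of this PR). The loop only iterates over whatever (train, val) index pairs the user returns, so any splitting strategy works and the sklearn dependency disappears:

```python
import numpy as np
from torch.utils.data.dataset import Dataset, Subset


class UserSplitKFoldLoop(KFoldLoop):
    """Hypothetical variant that asks the user for the splits instead of depending on sklearn."""

    def create_splits(self, dataset: Dataset):
        # naive contiguous K-fold split; override with any strategy (stratified, grouped, ...)
        indices = np.arange(len(dataset))
        folds = np.array_split(indices, self.num_folds)
        return [
            (np.concatenate(folds[:i] + folds[i + 1:]), folds[i])  # (train, val) for fold i
            for i in range(self.num_folds)
        ]

    def process_dataset(self, stage: str, dataset: Dataset) -> Subset:
        train_indices, val_indices = self.create_splits(dataset)[self.current_fold]
        indices = train_indices if stage == "train" else val_indices
        return Subset(dataset, indices.tolist())
```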
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/base.py
@@ -23,6 +23,9 @@
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class Loop(ABC):
104 changes: 104 additions & 0 deletions pytorch_lightning/loops/external_loop.py
@@ -0,0 +1,104 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class ExternalLoop(Loop):
"""This Loop is meant wrap trainer calls"""

def __init__(self):
super().__init__()
warning_cache.warn("The `ExternalLoop` API is in pre-alpha and breaking API changes are expected.")
self.create_trainer = self._wrap_trainer_wrapper(self.create_trainer)
self._has_setup = False
self._restore_external_loop = True

def _wrap_trainer_wrapper(self, create_trainer: Callable) -> Callable:
@functools.wraps(create_trainer)
def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
trainer = create_trainer(*args, trainer_kwargs=self.trainer_kwargs, **kwargs)
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException("The `create_trainer` hook should return a Trainer")
self.trainer = trainer
self.trainer.external_loop = self

self.trainer.accelerator.connect(self.__lightning_module)

# links data to the trainer
self.trainer.data_connector.attach_data(
self.trainer.lightning_module,
train_dataloaders=self.__train_dataloader,
val_dataloaders=self.__val_dataloaders,
test_dataloaders=self.__test_dataloaders,
predict_dataloaders=self.__predict_dataloaders,
datamodule=self.__datamodule,
)

# attach model to the training type plugin
self.trainer.data_connector.prepare_data()

self.trainer.checkpoint_connector.resume_start()
self.trainer.checkpoint_connector.restore_loops(restore_external_loop=self._restore_external_loop)
return trainer

return wrapped_func

def connect_trainer(self, **trainer_kwargs: Dict[str, Any]) -> None:
self.trainer_kwargs = trainer_kwargs

def create_trainer(self, *args, trainer_kwargs: Dict[str, Any] = {}, **kwargs) -> "pl.Trainer":
trainer_kwargs.update(kwargs)
return pl.Trainer(*args, **trainer_kwargs)

def run(
self,
model: "pl.LightningModule",
train_dataloader=None,
val_dataloaders=None,
test_dataloaders=None,
predict_dataloaders=None,
datamodule=None,
):

# if a datamodule comes in as the second arg, then fix it for the user
# (do this before storing the references, otherwise the fix-up has no effect)
if isinstance(train_dataloader, pl.LightningDataModule):
datamodule = train_dataloader
train_dataloader = None

self.__lightning_module = model
self.__train_dataloader = train_dataloader
self.__val_dataloaders = val_dataloaders
self.__test_dataloaders = test_dataloaders
self.__predict_dataloaders = predict_dataloaders
self.__datamodule = datamodule

if train_dataloader is not None and datamodule:
raise MisconfigurationException("You cannot pass both `loop.run(dataloaders=..., datamodule=...)`")

if model is None:
raise MisconfigurationException("`model` must be provided to `loop.run()`")

if self._trainer is None:
self.create_trainer()
self._restore_external_loop = False

return super().run()
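For readers new to the Loop API: `Loop.run()` roughly calls `reset()`, `on_run_start()`, then repeats `on_advance_start()` / `advance()` / `on_advance_end()` until `done` is True, and finishes with `on_run_end()`. A bare-bones, hypothetical subclass showing the minimum an `ExternalLoop` needs to provide (it reuses `BoringModel`/`BoringDataModule` from the example above):

```python
class RepeatLoop(ExternalLoop):
    """Hypothetical loop that simply re-fits the same model `num_repeats` times."""

    def __init__(self, num_repeats: int):
        super().__init__()
        self.num_repeats = num_repeats
        self.current_repeat = 0

    @property
    def done(self) -> bool:
        return self.current_repeat >= self.num_repeats

    def reset(self) -> None:
        self.current_repeat = 0

    def on_advance_start(self) -> None:
        self.create_trainer()  # a fresh trainer per repetition, built from the connected kwargs

    def advance(self) -> None:
        self.trainer.fit(self.trainer.lightning_module, datamodule=self.trainer.datamodule)

    def on_advance_end(self) -> None:
        self.current_repeat += 1


loop = RepeatLoop(num_repeats=3)
loop.connect_trainer(max_epochs=1)  # stored and forwarded to every `create_trainer` call
loop.run(BoringModel(), datamodule=BoringDataModule())
```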
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -237,6 +237,12 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def user_defined_hook(self, hook_name: str, *args, **kwargs):
"""Called when a user calls call_hook directly with its own hook name."""
for callback in self.callbacks:
if hasattr(callback, hook_name):
getattr(callback, hook_name)(self, self.lightning_module, *args, **kwargs)
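A tiny standalone illustration (not from the codebase) of what this dispatch does: any callback that defines the requested hook name receives the call, and callbacks without it are skipped silently:

```python
class _FakeTrainer:
    """Stripped-down stand-in for Trainer, only to show the dispatch behaviour."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.lightning_module = None

    def user_defined_hook(self, hook_name, *args, **kwargs):
        for callback in self.callbacks:
            if hasattr(callback, hook_name):
                getattr(callback, hook_name)(self, self.lightning_module, *args, **kwargs)


class MyKFoldCallback:
    def on_fold_start(self, trainer, pl_module, counter):
        print(f"fold {counter} started")


_FakeTrainer([MyKFoldCallback(), object()]).user_defined_hook("on_fold_start", 3)
# prints "fold 3 started"; the plain `object()` is skipped because it lacks the hook
```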

@staticmethod
def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -186,7 +186,7 @@ def restore_callbacks(self) -> None:
)
self.trainer.on_load_checkpoint(self._loaded_checkpoint)

def restore_loops(self) -> None:
def restore_loops(self, restore_external_loop: bool = False) -> None:
Review comment (Member): it's probably enough if this is controlled by the existence of an external loop, as checked below.

"""
Restores the loop progress from the pre-loaded checkpoint.
Calls hooks on the loops to give it a chance to restore its state from the checkpoint.
@@ -226,6 +226,11 @@ def restore_loops(self) -> None:
self.trainer.test_loop.load_state_dict(state_dict["test_loop"])
self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"])

if restore_external_loop:
external_loop = getattr(self.trainer, "external_loop", None)
if external_loop:
self.trainer.external_loop.load_state_dict(state_dict["external_loop"])

def restore_optimizers_and_schedulers(self) -> None:
"""Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
if (
@@ -471,9 +476,13 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
return state_dict

def _get_loops_state_dict(self) -> Dict[str, Any]:
return {
state_dict = {
"fit_loop": self.trainer.fit_loop.state_dict(),
"validate_loop": self.trainer.validate_loop.state_dict(),
"test_loop": self.trainer.test_loop.state_dict(),
"predict_loop": self.trainer.predict_loop.state_dict(),
}
external_loop = getattr(self.trainer, "external_loop", None)
Review comment (Member): perhaps the trainer can have a property for this, like the ones we have for the other loops?

if external_loop:
state_dict.update({"external_loop": external_loop.state_dict()})
Review comment (Member): can there be more than one external loop? I mean, one external loop nested inside another?

Reply (Contributor, author): It could, and `Loop` will automatically gather its children's states.

return state_dict
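For reference, a hypothetical sketch of the dictionary this method returns once a `KFoldLoop` is attached as `trainer.external_loop` (the per-loop payloads are elided placeholders):

```python
loops_state = {
    "fit_loop": {...},       # existing loop state (progress dataclasses, child loops, ...)
    "validate_loop": {...},
    "test_loop": {...},
    "predict_loop": {...},
    "external_loop": {
        # written by `KFoldLoop.on_save_checkpoint`, read back by `on_load_checkpoint`
        "current_fold": 2,
        # plus any state gathered automatically from child loops
    },
}
```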
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1243,6 +1243,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
if hasattr(self, hook_name):
trainer_hook = getattr(self, hook_name)
trainer_hook(*args, **kwargs)
else:
self.user_defined_hook(hook_name, *args, **kwargs)

# next call hook in lightningModule
output = None