torch-tangent
torch-tangent provides composable, differentiable, and simple finite-width and infinite-width neural tangent kernels (NTKs) for Gaussian processes, and interoperates with GPyTorch.
To install the dependencies, run the following from the repository root:
pip install -r requirements.txt
Once the requirements are installed, you can build a kernel from a PyTorch model:
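from torch import nn

from torch_tangent.src.grad_ops import *
from torch_tangent.src.ntk_struct import NTK, InfNTK

nn_model = nn.Sequential(
    nn.Linear(1, 25),
    nn.Linear(25, 1),
)

# Finite-width NTK of nn_model (bound to a new name so the NTK class is not shadowed)
kernel = NTK(model=nn_model)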
# Or, use InfNTK() for the infinite-width kernel.
# NTK subclasses GPyTorch's Kernel base class, so it can be used with most GPyTorch models.
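For intuition about what a finite-width NTK computes, the following is a minimal sketch of the empirical NTK, K(x, x') = <df(x)/dθ, df(x')/dθ>, written directly with `torch.func` rather than with torch-tangent. `empirical_ntk` is a hypothetical helper for illustration only, not part of torch-tangent's API, and it assumes a scalar-output model such as the one above; torch-tangent's own implementation may differ.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap


def empirical_ntk(model: nn.Module, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: empirical NTK K[i, j] = <df(x1[i])/dθ, df(x2[j])/dθ>.

    Assumes `model` maps a (1, d) input to a single scalar output.
    """
    # Snapshot the current parameters; torch.func differentiates w.r.t. this dict.
    params = {name: p.detach() for name, p in model.named_parameters()}

    def f(p, x):
        # Evaluate the model on one example with parameters p, returning a scalar.
        return functional_call(model, p, (x.unsqueeze(0),)).squeeze()

    # Per-example Jacobians: dict name -> tensor of shape (N, *param.shape).
    jac1 = vmap(jacrev(f), in_dims=(None, 0))(params, x1)
    jac2 = vmap(jacrev(f), in_dims=(None, 0))(params, x2)

    # Sum the Gram matrices of the flattened Jacobians over all parameter tensors.
    return sum(jac1[k].flatten(1) @ jac2[k].flatten(1).T for k in params)


torch.manual_seed(0)
nn_model = nn.Sequential(nn.Linear(1, 25), nn.Linear(25, 1))
x = torch.linspace(-1, 1, 5).unsqueeze(-1)  # 5 one-dimensional inputs
K = empirical_ntk(nn_model, x, x)
print(K.shape)  # torch.Size([5, 5])
```

Because the kernel is a Gram matrix of parameter gradients, it is symmetric and positive semi-definite, which is what allows it to serve as a Gaussian-process covariance.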