Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hugging Face integration #25

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ core
build/
dist/
build.sh
*.job
*.job

# Environments
env/
1 change: 1 addition & 0 deletions hiera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
MaskUnitAttention,
Head,
PatchEmbed,
HieraForImageClassification,
)


Expand Down
16 changes: 15 additions & 1 deletion hiera/hiera.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------

import importlib.util
import math
from functools import partial
from typing import List, Tuple, Callable, Optional
Expand All @@ -31,6 +32,13 @@
from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll


def is_huggingface_hub_available():
return importlib.util.find_spec("huggingface_hub") is not None


if is_huggingface_hub_available():
from huggingface_hub import PyTorchModelHubMixin


class MaskUnitAttention(nn.Module):
"""
Expand Down Expand Up @@ -225,13 +233,14 @@ def __init__(
patch_padding: Tuple[int, ...] = (3, 3),
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.0,
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
head_dropout: float = 0.0,
head_init_scale: float = 0.001,
sep_pos_embed: bool = False,
):
super().__init__()

norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6)

depth = sum(stages)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)]
Expand Down Expand Up @@ -533,3 +542,8 @@ def hiera_huge_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
)

# Hugging Face integration
class HieraForImageClassification(Hiera, PyTorchModelHubMixin):
def __init__(self, config: dict):
super().__init__(**config)
39 changes: 39 additions & 0 deletions hiera/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

import torch
from hiera import HieraForImageClassification

config = dict(embed_dim=96,
input_size=(224, 224),
in_chans=3,
num_heads=1, # initial number of heads
num_classes=1000,
stages=(1, 2, 7, 2),
q_pool=3, # number of q_pool stages
q_stride=(2, 2),
mask_unit_size=(8, 8), # must divide q_stride ** (#stages-1)
mask_unit_attn=(True, True, False, False),
dim_mul=2.0,
head_mul=2.0,
patch_kernel=(7, 7),
patch_stride=(4, 4),
patch_padding=(3, 3),
mlp_ratio=4.0,
drop_path_rate=0.0,
head_dropout=0.0,
head_init_scale=0.001,
sep_pos_embed=False,)

model = HieraForImageClassification(config)

# load weights
state_dict = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth", map_location="cpu")
model.load_state_dict(state_dict["model_state"])

# save locally
# model.save_pretrained("hiera-tiny-224", config=config)

# save to huggingface hub
# model.push_to_hub("nielsr/hiera-tiny-224", config=config)

# load from huggingface hub
model = HieraForImageClassification.from_pretrained("nielsr/hiera-tiny-224")