Skip to content

Commit

Permalink
replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (hu…
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 authored and guyrosin committed Jan 15, 2021
1 parent ab91a5e commit d360a48
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 48 deletions.
30 changes: 10 additions & 20 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
Expand Down Expand Up @@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)


def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
try:
from apex.normalization import FusedLayerNorm

return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class BartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
Expand Down Expand Up @@ -321,13 +311,13 @@ def __init__(self, config: BartConfig):
dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = BartLayerNorm(self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
"""
Expand Down Expand Up @@ -380,17 +370,17 @@ def __init__(self, config: BartConfig):
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before

self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = BartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = BartLayerNorm(self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(
self,
Expand Down Expand Up @@ -672,9 +662,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm
self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

self.init_weights()

Expand Down Expand Up @@ -812,8 +802,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

self.init_weights()

Expand Down
12 changes: 1 addition & 11 deletions src/transformers/models/fsmt/modeling_fsmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
Expand Down Expand Up @@ -264,16 +264,6 @@
"""


have_fused_layer_norm = False
try:
from apex.normalization import FusedLayerNorm

have_fused_layer_norm = True
except ImportError:
pass
LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm


def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
Expand Down
25 changes: 8 additions & 17 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
Expand Down Expand Up @@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
try:
from apex.normalization import FusedLayerNorm

return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class ProphetNetPreTrainedModel(PreTrainedModel):
config_class = ProphetNetConfig
base_model_prefix = "prophetnet"
Expand Down Expand Up @@ -1044,11 +1035,11 @@ def __init__(self, config: ProphetNetConfig):
super().__init__()
# 1st residual block
self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.self_attn_layer_norm = LayerNorm(config.hidden_size)

# 2nd residual block
self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(self, hidden_states, attention_mask):
# 1st residual block
Expand All @@ -1073,16 +1064,16 @@ def __init__(self, config: ProphetNetConfig):
super().__init__()
# 1st residual block
self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.self_attn_layer_norm = LayerNorm(config.hidden_size)

# 2nd residual block
if config.add_cross_attention:
self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

# 3rd residual block
self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(
self,
Expand Down Expand Up @@ -1154,7 +1145,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
)
self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.embeddings_layer_norm = LayerNorm(config.hidden_size)

self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

Expand Down Expand Up @@ -1274,7 +1265,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non

self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
self.embeddings_layer_norm = LayerNorm(config.hidden_size)

self.init_weights()

Expand Down

0 comments on commit d360a48

Please sign in to comment.