# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
from collections.abc import Iterable

import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqEncoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LinearizedConvolution,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
    VGGBlock,
)


def _eval_config(expr):
    """Turn a configuration option into a Python object.

    The ``*-config`` / ``transformer-context`` / ``transformer-sampling``
    options are Python expressions passed as strings (e.g.
    ``"((256, 3, True),) * 4"``).  Values that are not strings -- such as the
    tuple defaults installed by ``base_architecture`` -- are returned
    unchanged, because ``eval`` only accepts strings, bytes or code objects.
    """
    if isinstance(expr, str):
        # NOTE(security): eval() executes arbitrary code. The expression comes
        # from the command line / a stored config; never feed it untrusted
        # input.
        return eval(expr)
    return expr


@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
    """
    Transformers with convolutional context for ASR
    https://arxiv.org/abs/1904.11660
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock:
    [(out_channels,
      conv_kernel_size,
      pooling_kernel_size,
      num_conv_layers,
      use_layer_norm), ...]
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the encoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
    encoder output dimension, can be None. If specified, projecting the
    transformer output to the specified dimension""",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--tgt-embed-dim",
            type=int,
            metavar="N",
            help="embedding dimension of the decoder target tokens",
        )
        parser.add_argument(
            "--transformer-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the decoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--conv-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples for the decoder 1-D convolution config
        [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
        )

    @classmethod
    def build_encoder(cls, args, task):
        """Build the VGG + Transformer encoder from parsed ``args``."""
        return VGGTransformerEncoder(
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=_eval_config(args.vggblock_enc_config),
            transformer_config=_eval_config(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
        )

    @classmethod
    def build_decoder(cls, args, task):
        """Build the conv + Transformer decoder from parsed ``args``."""
        return TransformerDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.tgt_embed_dim,
            transformer_config=_eval_config(args.transformer_dec_config),
            conv_config=_eval_config(args.conv_dec_config),
            encoder_output_dim=args.enc_output_dim,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return normalized (log-)probabilities tagged as batch-first."""
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs


DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
# 256: embedding dimension
# 4: number of heads
# 1024: FFN
# True: apply layerNorm before (dropout + residual) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and second FC
# 0.2 (attention_dropout): dropout in MultiheadAttention
# 0.2 (relu_dropout): dropout after ReLu
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2


# TODO: replace transformer encoder config from one liner
# to explicit args to get rid of this transformation
def prepare_transformer_encoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Pack one encoder-layer config tuple into the ``argparse.Namespace``
    that is handed to ``TransformerEncoderLayer``."""
    args = argparse.Namespace()
    args.encoder_embed_dim = input_dim
    args.encoder_attention_heads = num_heads
    args.attention_dropout = attention_dropout
    args.dropout = dropout
    args.activation_dropout = relu_dropout
    args.encoder_normalize_before = normalize_before
    args.encoder_ffn_embed_dim = ffn_dim
    return args


def prepare_transformer_decoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Pack one decoder-layer config tuple into the ``argparse.Namespace``
    that is handed to ``TransformerDecoderLayer``."""
    args = argparse.Namespace()
    args.encoder_embed_dim = None
    args.decoder_embed_dim = input_dim
    args.decoder_attention_heads = num_heads
    args.attention_dropout = attention_dropout
    args.dropout = dropout
    args.activation_dropout = relu_dropout
    args.decoder_normalize_before = normalize_before
    args.decoder_ffn_embed_dim = ffn_dim
    return args


class VGGTransformerEncoder(FairseqEncoder):
    """VGG + Transformer encoder"""

    def __init__(
        self,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        """constructor for VGGTransformerEncoder

        Args:
            - input_feat_per_channel: feature dim (not including stacked,
                just base feature)
            - in_channels: # input channels (e.g., if stack 8 feature vector
                together, this is 8)
            - vggblock_config: configuration of vggblock, see comments on
                DEFAULT_ENC_VGGBLOCK_CONFIG
            - transformer_config: configuration of transformer layer, see
                comments on DEFAULT_ENC_TRANSFORMER_CONFIG
            - encoder_output_dim: final transformer output embedding dimension
            - transformer_context: (left, right) if set, self-attention will
                be focused on (t-left, t+right)
            - transformer_sampling: an iterable of int, must match with
                len(transformer_config), transformer_sampling[i] indicates
                sampling factor for i-th transformer layer, after multihead
                att and feedforward part
        """
        super().__init__(None)

        self.num_vggblocks = 0
        if vggblock_config is not None:
            if not isinstance(vggblock_config, Iterable):
                raise ValueError("vggblock_config is not iterable")
            self.num_vggblocks = len(vggblock_config)

        self.conv_layers = nn.ModuleList()
        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel
        self.pooling_kernel_sizes = []

        if vggblock_config is not None:
            for config in vggblock_config:
                (
                    out_channels,
                    conv_kernel_size,
                    pooling_kernel_size,
                    num_conv_layers,
                    layer_norm,
                ) = config
                self.conv_layers.append(
                    VGGBlock(
                        in_channels,
                        out_channels,
                        conv_kernel_size,
                        pooling_kernel_size,
                        num_conv_layers,
                        input_dim=input_feat_per_channel,
                        layer_norm=layer_norm,
                    )
                )
                self.pooling_kernel_sizes.append(pooling_kernel_size)
                # the next block consumes what this block produced
                in_channels = out_channels
                input_feat_per_channel = self.conv_layers[-1].output_dim

        # transformer_input_dim is the output dimension of VGG part
        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim
        )

        self.validate_transformer_config(transformer_config)
        self.transformer_context = self.parse_transformer_context(transformer_context)
        self.transformer_sampling = self.parse_transformer_sampling(
            transformer_sampling, len(transformer_config)
        )

        self.transformer_layers = nn.ModuleList()

        # A Linear adapter is inserted wherever consecutive dimensions differ
        # (VGG output -> first layer, layer i-1 -> layer i).
        prev_dim = transformer_input_dim
        for layer_config in transformer_config:
            layer_dim = layer_config[0]
            if prev_dim != layer_dim:
                self.transformer_layers.append(Linear(prev_dim, layer_dim))
            self.transformer_layers.append(
                TransformerEncoderLayer(
                    prepare_transformer_encoder_params(*layer_config)
                )
            )
            prev_dim = layer_dim

        self.encoder_output_dim = encoder_output_dim
        self.transformer_layers.extend(
            [
                Linear(transformer_config[-1][0], encoder_output_dim),
                LayerNorm(encoder_output_dim),
            ]
        )

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)

        Returns a dict with "encoder_out" (T, B, C) and
        "encoder_padding_mask" ((T, B) tensor, or None when nothing is padded).
        """
        bsz, max_seq_len, _ = src_tokens.size()
        x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
        x = x.transpose(1, 2).contiguous()
        # (B, C, T, feat)

        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1)
        x = x.contiguous().view(output_seq_len, bsz, -1)

        # every pooling stage shrinks the valid length by its kernel size
        input_lengths = src_lengths.clone()
        for s in self.pooling_kernel_sizes:
            input_lengths = (input_lengths.float() / s).ceil().long()

        encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
            input_lengths, batch_first=True
        )
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)

        # transformer_sampling is indexed by transformer layer only, so the
        # interleaved Linear / LayerNorm modules must not advance the counter
        transformer_layer_idx = 0

        for layer in self.transformer_layers:
            if isinstance(layer, TransformerEncoderLayer):
                x = layer(x, encoder_padding_mask, attn_mask)

                sampling_factor = self.transformer_sampling[transformer_layer_idx]
                if sampling_factor != 1:
                    x, encoder_padding_mask, attn_mask = self.slice(
                        x, encoder_padding_mask, attn_mask, sampling_factor
                    )

                transformer_layer_idx += 1
            else:
                x = layer(x)

        # encoder_padding_mask is a (T x B) tensor, its [t, b] elements indicate
        # whether encoder_output[t, b] is valid or not (valid=0, invalid=1)

        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask.t()
            if encoder_padding_mask is not None
            else None,
            # (B, T) --> (T, B)
        }

    def infer_conv_output_dim(self, in_channels, input_dim):
        """Return the flattened (channels * feat) size produced by the conv
        stack, found by pushing a random sample batch through it."""
        sample_seq_len = 200
        sample_bsz = 10
        x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def validate_transformer_config(self, transformer_config):
        """Raise ValueError if a layer's input_dim is not a multiple of its
        number of heads."""
        for config in transformer_config:
            input_dim, num_heads = config[:2]
            if input_dim % num_heads != 0:
                msg = (
                    "ERROR in transformer config {}: ".format(config)
                    + "input dimension {} ".format(input_dim)
                    + "not dividable by number of heads {}".format(num_heads)
                )
                raise ValueError(msg)

    def parse_transformer_context(self, transformer_context):
        """
        transformer_context can be the following:
        -   None; indicates no context is used, i.e.,
            transformer can access full context
        -   a tuple/list of two int; indicates left and right context,
            any number <0 indicates infinite context
                * e.g., (5, 6) indicates that for query at x_t, transformer can
                access [t-5, t+6] (inclusive)
                * e.g., (-1, 6) indicates that for query at x_t, transformer can
                access [0, t+6] (inclusive)

        Returns None when both sides are unbounded, otherwise a
        (left_context, right_context) tuple in which an unbounded side is None.
        """
        if transformer_context is None:
            return None

        if not isinstance(transformer_context, Iterable):
            raise ValueError("transformer context must be Iterable if it is not None")

        if len(transformer_context) != 2:
            raise ValueError("transformer context must have length 2")

        left_context = transformer_context[0]
        if left_context < 0:
            left_context = None

        right_context = transformer_context[1]
        if right_context < 0:
            right_context = None

        if left_context is None and right_context is None:
            return None

        return (left_context, right_context)

    def parse_transformer_sampling(self, transformer_sampling, num_layers):
        """
        parsing transformer sampling configuration

        Args:
            - transformer_sampling, accepted input:
                * None, indicating no sampling
                * an Iterable with int (>0) as element
            - num_layers, expected number of transformer layers, must match with
              the length of transformer_sampling if it is not None

        Returns:
            - A tuple with length num_layers when transformer_sampling is None,
              otherwise the validated transformer_sampling itself

        Raises:
            - ValueError if transformer_sampling is not iterable, has the wrong
              length, or contains a non-int or non-positive element
        """
        if transformer_sampling is None:
            return (1,) * num_layers

        if not isinstance(transformer_sampling, Iterable):
            raise ValueError(
                "transformer_sampling must be an iterable if it is not None"
            )

        if len(transformer_sampling) != num_layers:
            raise ValueError(
                "transformer_sampling {} does not match with the number "
                "of layers {}".format(transformer_sampling, num_layers)
            )

        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                # report the offending element; the message used to end at ": "
                raise ValueError(
                    "Invalid value in transformer_sampling: {!r}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "{} layer's subsampling is {}.".format(layer, value)
                    + " This is not allowed! "
                )
        return transformer_sampling

    def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
        """
        Keep every ``sampling_factor``-th time step.

        embedding is a (T, B, D) tensor
        padding_mask is a (B, T) tensor or None
        attn_mask is a (T, T) tensor or None
        """
        embedding = embedding[::sampling_factor, :, :]
        if padding_mask is not None:
            padding_mask = padding_mask[:, ::sampling_factor]
        if attn_mask is not None:
            attn_mask = attn_mask[::sampling_factor, ::sampling_factor]

        return embedding, padding_mask, attn_mask

    def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
        """
        create attention mask according to sequence lengths and transformer
        context

        Args:
            - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
              the length of b-th sequence
            - subsampling_factor: int
                * Note that the left_context and right_context is specified in
                  the input frame-level while input to transformer may already
                  go through subsampling (e.g., the use of striding in vggblock)
                  we use subsampling_factor to scale the left/right context

        Return:
            - a (T, T) binary tensor or None, where T is max(input_lengths)
                * if self.transformer_context is None, None
                * if left_context is None,
                    * attn_mask[t, t + right_context + 1:] = 1
                    * others = 0
                * if right_context is None,
                    * attn_mask[t, 0:t - left_context] = 1
                    * others = 0
                * else
                    * attn_mask[t, t - left_context: t + right_context + 1] = 0
                    * others = 1
        """
        if self.transformer_context is None:
            return None

        max_len = torch.max(input_lengths).item()
        device = input_lengths.device

        left_context, right_context = self.transformer_context
        if left_context is not None:
            left_context = math.ceil(left_context / subsampling_factor)
        if right_context is not None:
            right_context = math.ceil(right_context / subsampling_factor)

        positions = torch.arange(max_len, device=device)
        # offset[t, s] = s - t: signed distance from query t to key s
        offset = positions.unsqueeze(0) - positions.unsqueeze(1)

        attn_mask = torch.zeros(max_len, max_len, device=device)
        if left_context is not None:
            # keys more than left_context frames before the query
            attn_mask.masked_fill_(offset < -left_context, 1)
        if right_context is not None:
            # keys more than right_context frames after the query. The start
            # index used to be clamped to max_len - 1, which always masked the
            # last column for rows near the end (including the last frame's
            # attention to itself).
            attn_mask.masked_fill_(offset > right_context, 1)

        return attn_mask

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 1) of the encoder output in place
        according to ``new_order`` and return it."""
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Decoder made of a token embedding, a stack of 1-D linearized convolutions
    and a stack of :class:`TransformerDecoderLayer`, followed by a linear
    projection to the vocabulary.

    Args:
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_dim (int): target token embedding dimension
        transformer_config: one tuple per transformer layer, in the argument
            order of ``prepare_transformer_decoder_params``
        conv_config: one (out_channels, kernel_size, use_layer_norm) tuple per
            convolution layer
        encoder_output_dim (int): accepted but not used by this constructor
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        # NOTE(review): the default is the *encoder* config while
        # DEFAULT_DEC_TRANSFORMER_CONFIG is unused -- looks like a copy-paste
        # slip; kept as is so existing callers see the same default.
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        conv_config=DEFAULT_DEC_CONV_CONFIG,
        encoder_output_dim=512,
    ):
        super().__init__(dictionary)
        vocab_size = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)

        self.conv_layers = nn.ModuleList()
        conv_in_channels = embed_dim
        for out_channels, kernel_size, layer_norm in conv_config:
            self.conv_layers.append(
                LinearizedConv1d(
                    conv_in_channels,
                    out_channels,
                    kernel_size,
                    padding=kernel_size - 1,
                )
            )
            if layer_norm:
                self.conv_layers.append(nn.LayerNorm(out_channels))
            self.conv_layers.append(nn.ReLU())
            conv_in_channels = out_channels

        self.layers = nn.ModuleList()
        # A Linear adapter is inserted wherever consecutive dimensions differ
        # (last conv -> first layer, layer i-1 -> layer i).
        prev_dim = conv_config[-1][0]
        for layer_config in transformer_config:
            layer_dim = layer_config[0]
            if prev_dim != layer_dim:
                self.layers.append(Linear(prev_dim, layer_dim))
            self.layers.append(
                TransformerDecoderLayer(
                    prepare_transformer_decoder_params(*layer_config)
                )
            )
            prev_dim = layer_dim

        self.fc_out = Linear(transformer_config[-1][0], vocab_size)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention; read through its "encoder_out" and
                "encoder_padding_mask" entries
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
        Returns:
            tuple:
                - the output of ``fc_out`` in batch-first layout
                - ``None`` (no attention weights are returned)
        """
        target_padding_mask = (
            (prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
            if incremental_state is None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)

        # B x T x C -> T x B x C
        x = self._transpose_if_training(x, incremental_state)

        for layer in self.conv_layers:
            if isinstance(layer, LinearizedConvolution):
                x = layer(x, incremental_state)
            else:
                x = layer(x)

        # B x T x C -> T x B x C
        x = self._transpose_if_inference(x, incremental_state)

        # Resolve the encoder inputs once. Both stay None when no encoder
        # output is given; the padding mask used to be dereferenced
        # unconditionally, which failed for encoder_out=None.
        encoder_states = None
        encoder_padding_mask = None
        if encoder_out is not None:
            encoder_states = encoder_out["encoder_out"]
            if encoder_out["encoder_padding_mask"] is not None:
                encoder_padding_mask = encoder_out["encoder_padding_mask"].t()

        # decoder layers
        for layer in self.layers:
            if isinstance(layer, TransformerDecoderLayer):
                x, *_ = layer(
                    x,
                    encoder_states,
                    encoder_padding_mask,
                    incremental_state,
                    self_attn_mask=(
                        self.buffered_future_mask(x)
                        if incremental_state is None
                        else None
                    ),
                    self_attn_padding_mask=(
                        target_padding_mask if incremental_state is None else None
                    ),
                )
            else:
                x = layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        x = self.fc_out(x)

        return x, None

    def buffered_future_mask(self, tensor):
        """Return a (dim, dim) upper-triangular -inf mask, dim = tensor.size(0),
        cached on the module and rebuilt when the device changes or a larger
        mask is needed."""
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def _transpose_if_training(self, x, incremental_state):
        """Swap dims 0 and 1 when no incremental state is given."""
        if incremental_state is None:
            x = x.transpose(0, 1)
        return x

    def _transpose_if_inference(self, x, incremental_state):
        """Swap dims 0 and 1 when an incremental state is given."""
        # Compare against None rather than truthiness: an empty dict is still
        # an incremental state, and every other check in this class uses
        # ``is None`` / ``is not None``.
        if incremental_state is not None:
            x = x.transpose(0, 1)
        return x


@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
    """Encoder-only VGG-Transformer model (see the "CTC models" presets)."""

    def __init__(self, encoder):
        super().__init__(encoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock
    [(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...]
    """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the Transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ]""",
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="encoder output dimension, projecting the transformer output",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--transformer-context",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of two ints, indicating left/right context a
    transformer can have access to""",
        )
        parser.add_argument(
            "--transformer-sampling",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of ints, indicating sampling factor in each layer""",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        base_architecture_enconly(args)
        encoder = VGGTransformerEncoderOnly(
            vocab_size=len(task.target_dictionary),
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=_eval_config(args.vggblock_enc_config),
            transformer_config=_eval_config(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
            transformer_context=_eval_config(args.transformer_context),
            transformer_sampling=_eval_config(args.transformer_sampling),
        )
        return cls(encoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return normalized (log-)probabilities transposed to batch-first."""
        # net_output['encoder_out'] is a (T, B, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        # lprobs is a (T, B, D) tensor
        # we need to transpose to get (B, T, D) tensor
        lprobs = lprobs.transpose(0, 1).contiguous()
        lprobs.batch_first = True
        return lprobs


class VGGTransformerEncoderOnly(VGGTransformerEncoder):
    """VGGTransformerEncoder with a final linear projection to ``vocab_size``."""

    def __init__(
        self,
        vocab_size,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        super().__init__(
            input_feat_per_channel=input_feat_per_channel,
            vggblock_config=vggblock_config,
            transformer_config=transformer_config,
            encoder_output_dim=encoder_output_dim,
            in_channels=in_channels,
            transformer_context=transformer_context,
            transformer_sampling=transformer_sampling,
        )
        self.fc_out = Linear(self.encoder_output_dim, vocab_size)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)
        """
        enc_out = super().forward(src_tokens, src_lengths)
        x = self.fc_out(enc_out["encoder_out"])
        # x = F.log_softmax(x, dim=-1)
        # Note: no need this line, because model.get_normalized_prob will call
        # log_softmax
        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": enc_out["encoder_padding_mask"],  # (T, B)
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (1e6, 1e6)  # an arbitrary large number


def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with PyTorch's default initialization."""
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    # nn.init.uniform_(m.weight, -0.1, 0.1)
    # nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C)

    ``dropout`` is accepted but not used.
    """
    m = nn.Linear(in_features, out_features, bias=bias)
    # m.weight.data.uniform_(-0.1, 0.1)
    # if bias:
    #     m.bias.data.uniform_(-0.1, 0.1)
    return m


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    nn.init.normal_(m.weight, mean=0, std=std)
    nn.init.constant_(m.bias, 0)
    return nn.utils.weight_norm(m, dim=2)


def LayerNorm(embedding_dim):
    """Create an ``nn.LayerNorm`` over the last ``embedding_dim`` features."""
    m = nn.LayerNorm(embedding_dim)
    return m


# seq2seq models
def base_architecture(args):
    """Fill in every option of the seq2seq model that ``args`` lacks.

    The config defaults installed here are tuples rather than expression
    strings; ``_eval_config`` accepts both forms.
    """
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.in_channels = getattr(args, "in_channels", 1)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
    # NOTE(review): the decoder default is the *encoder* config while
    # DEFAULT_DEC_TRANSFORMER_CONFIG is unused -- confirm which is intended.
    args.transformer_dec_config = getattr(
        args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
    )
    args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
    args.transformer_context = getattr(args, "transformer_context", "None")


@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    )
    args.transformer_enc_config = getattr(
        args,
        "transformer_enc_config",
        "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
    args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
    args.transformer_dec_config = getattr(
        args,
        "transformer_dec_config",
        "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
    )


@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    )
    args.transformer_enc_config = getattr(
        args,
        "transformer_enc_config",
        "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
    args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
    args.transformer_dec_config = getattr(
        args,
        "transformer_dec_config",
        "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
    )


@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
    )

    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
    args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
    args.transformer_dec_config = getattr(
        args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
    )
    # Size estimations:
    # Encoder:
    #   - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3  + 128*128*3 = 258K
    #   Transformer:
    #   - input dimension adapter: 2560 x 512 -> 1.31M
    #   - transformer_layers (x12) --> 37.74M
    #       * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
    #       * FFN weight: 512*2048*2 = 2.097M
    #   - output dimension adapter: 512 x 512 -> 0.26 M
    # Decoder:
    #   - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
    #   - transformer_layer: (x6) --> 25.16M
    #        * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
    #        * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
    #        * FFN: 512*2048*2 = 2.097M
    # Final FC:
    #   - FC: 512*5000 = 256K (assuming vocab size 5K)
    # In total:
    #       ~65 M


# CTC models
def base_architecture_enconly(args):
    """Fill in every option of the encoder-only model that ``args`` lacks."""
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.in_channels = getattr(args, "in_channels", 1)
    args.transformer_context = getattr(args, "transformer_context", "None")
    args.transformer_sampling = getattr(args, "transformer_sampling", "None")


@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
    # vggtransformer_1 is the same as vggtransformer_enc_big, except the number
    # of layers is increased to 16
    # keep it here for backward compatibility purpose
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    )
    args.transformer_enc_config = getattr(
        args,
        "transformer_enc_config",
        "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 1024)