Source code for model_seq.seqlm

"""
.. module:: seqlm
    :synopsis: language model for sequence labeling
 
.. moduleauthor:: Liyuan Liu
"""
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import model_seq.utils as utils

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicSeqLM(nn.Module):
    """
    The language model for the dense rnns.

    Reuses the rnn and the word embedding of an existing language model and
    freezes both (``requires_grad`` is set to ``False`` on the rnn parameters
    and on the embedding weight), so the wrapped model is not updated by
    gradient descent.

    Parameters
    ----------
    ori_lm : ``torch.nn.Module``, required.
        the original module of language model; it must expose ``rnn``,
        ``word_embed``, ``w_num``, ``w_dim`` and ``rnn_output``.
    backward : ``bool``, required.
        whether the language model is backward.
    droprate : ``float``, required.
        the dropout ratio. Not used by this class; kept so the constructor
        signature stays unchanged.
    fix_rate: ``bool``, required.
        whether to fix the ratio. Not used by this class; kept so the
        constructor signature stays unchanged.
    """

    def __init__(self, ori_lm, backward, droprate, fix_rate):
        super(BasicSeqLM, self).__init__()

        # Share the rnn with the original model and freeze its weights.
        self.rnn = ori_lm.rnn
        for param in self.rnn.parameters():
            param.requires_grad = False

        self.w_num = ori_lm.w_num
        self.w_dim = ori_lm.w_dim

        # Share the word embedding as well, and freeze it too.
        self.word_embed = ori_lm.word_embed
        self.word_embed.weight.requires_grad = False

        self.output_dim = ori_lm.rnn_output
        self.backward = backward

    def to_params(self):
        """
        To parameters.

        Returns
        ----------
        params: ``dict``.
            The rnn parameters (as reported by ``self.rnn.to_params()``)
            together with the number and the dimension of the word
            embeddings.
        """
        return {
            "rnn_params": self.rnn.to_params(),
            "word_embed_num": self.word_embed.num_embeddings,
            "word_embed_dim": self.word_embed.embedding_dim
        }

    def init_hidden(self):
        """
        Initialize hidden states (delegates to ``self.rnn.init_hidden()``).
        """
        self.rnn.init_hidden()

    def regularizer(self):
        """
        Calculate the regularization term.

        Returns
        ----------
        reg: ``list``.
            The list of regularization terms, as returned by
            ``self.rnn.regularizer()``.
        """
        return self.rnn.regularizer()

    def forward(self, w_in, ind=None):
        """
        Calculate the output.

        Parameters
        ----------
        w_in : ``torch.LongTensor``, required.
            the input tensor, of shape (seq_len, batch_size).
        ind : ``torch.LongTensor``, optional, (default=None).
            the index tensor for the backward language model. It selects
            rows of the rnn output flattened over its first two dimensions,
            and may be given either already flattened or with shape
            (seq_len, batch_size). Required when the model is backward,
            ignored otherwise.

        Returns
        ----------
        output: ``torch.FloatTensor``.
            The language model outputs; for a backward model the rows are
            re-ordered according to ``ind`` and the rnn output shape is kept.

        Raises
        ----------
        TypeError
            If the model is backward and ``ind`` is not provided.
        """
        if self.backward and ind is None:
            # Fail early with a clear message instead of an opaque error
            # raised from inside ``index_select``.
            raise TypeError(
                "ind is required when the language model is backward")

        w_emb = self.word_embed(w_in)
        out = self.rnn(w_emb)

        if self.backward:
            out_size = out.size()
            # ``reshape`` also copes with a non-contiguous rnn output.
            flat_out = out.reshape(out_size[0] * out_size[1], out_size[2])
            # ``index_select`` only accepts a 1-D index, so flatten ``ind``;
            # this is a no-op for an index that is already flat.
            flat_ind = ind.reshape(-1)
            # ``index_select`` returns a fresh tensor, so ``view`` is safe.
            out = flat_out.index_select(0, flat_ind).view(out_size)

        return out