"""
.. module:: elmo
:synopsis: deep contextualized representation
.. moduleauthor:: Liyuan Liu
"""
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import model_seq.utils as utils
class EBUnit(nn.Module):
    """
    The basic recurrent unit for the ELMo RNNs wrapper.

    Parameters
    ----------
    ori_unit : ``torch.nn.Module``, required.
        The original module of the rnn unit.
    droprate : ``float``, required.
        The dropout ratio.
    fix_rate : ``bool``, required.
        Whether to fix (freeze) the layer-mixture weights.
    """
    def __init__(self, ori_unit, droprate, fix_rate):
        super(EBUnit, self).__init__()

        self.layer = ori_unit.layer
        self.droprate = droprate
        self.output_dim = ori_unit.output_dim
    def forward(self, x):
        """
        Calculate the output.

        Parameters
        ----------
        x : ``torch.FloatTensor``, required.
            The input tensor, of shape (seq_len, batch_size, input_dim).

        Returns
        ----------
        output: ``torch.FloatTensor``.
            The output of the RNNs.
        """
        out, _ = self.layer(x)

        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)

        return out
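
# A minimal usage sketch for EBUnit, assuming the wrapped unit exposes the
# ``layer`` and ``output_dim`` attributes that EBUnit reads; the ``_Unit``
# stub below is hypothetical and only mirrors that interface:
#
#     class _Unit(nn.Module):
#         def __init__(self):
#             super(_Unit, self).__init__()
#             self.layer = nn.LSTM(100, 50)
#             self.output_dim = 50
#
#     unit = EBUnit(_Unit(), droprate=0.5, fix_rate=True)
#     out = unit(torch.randn(35, 16, 100))  # -> (35, 16, 50)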
class ERNN(nn.Module):
    """
    The multi-layer recurrent networks for the ELMo RNNs wrapper.

    Parameters
    ----------
    ori_drnn : ``torch.nn.Module``, required.
        The original module of the rnn networks.
    droprate : ``float``, required.
        The dropout ratio.
    fix_rate : ``bool``, required.
        Whether to fix (freeze) the layer-mixture weights ``gamma`` and ``weight_list``.
    """
    def __init__(self, ori_drnn, droprate, fix_rate):
        super(ERNN, self).__init__()

        self.layer_list = [EBUnit(ori_unit, droprate, fix_rate) for ori_unit in ori_drnn.layer._modules.values()]

        self.gamma = nn.Parameter(torch.FloatTensor([1.0]))
        self.weight_list = nn.Parameter(torch.FloatTensor([0.0] * len(self.layer_list)))

        self.layer = nn.ModuleList(self.layer_list)

        # The wrapped pretrained RNN layers stay frozen; only the mixture
        # parameters (and gamma, unless fix_rate is set) may be tuned.
        for param in self.layer.parameters():
            param.requires_grad = False

        if fix_rate:
            self.gamma.requires_grad = False
            self.weight_list.requires_grad = False

        self.output_dim = self.layer_list[-1].output_dim
    def regularizer(self):
        """
        Calculate the regularization term.

        Returns
        ----------
        The regularization term.
        """
        srd_weight = self.weight_list - (1.0 / len(self.layer_list))
        return (srd_weight ** 2).sum()
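
    # The term above is sum_i (w_i - 1/L)^2 over the L raw (pre-softmax)
    # mixture weights; e.g., for three layers initialized at w = [0, 0, 0]
    # it equals 3 * (1/3)^2 = 1/3, and it shrinks as the raw weights
    # approach the uniform value 1/L. Note it penalizes the raw logits,
    # not the normalized softmax weights themselves.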
    def forward(self, x):
        """
        Calculate the output.

        Parameters
        ----------
        x : ``torch.FloatTensor``, required.
            The input tensor, of shape (seq_len, batch_size, input_dim).

        Returns
        ----------
        output: ``torch.FloatTensor``.
            The ELMo outputs.
        """
        out = 0
        nw = self.gamma * F.softmax(self.weight_list, dim=0)

        for ind in range(len(self.layer_list)):
            x = self.layer[ind](x)
            out += x * nw[ind]

        return out
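
    # The loop above computes the ELMo scalar mix,
    #     out_t = gamma * sum_j softmax(w)_j * h_{t, j},
    # where h_{t, j} is the representation produced after stacked layer j at
    # step t (Peters et al., 2018); the layers are applied sequentially, so
    # nw[ind] weights the output of layer ind.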
class ElmoLM(nn.Module):
    """
    The language model for the ELMo RNNs wrapper.

    Parameters
    ----------
    ori_lm : ``torch.nn.Module``, required.
        The original module of the language model.
    backward : ``bool``, required.
        Whether the language model is backward.
    droprate : ``float``, required.
        The dropout ratio.
    fix_rate : ``bool``, required.
        Whether to fix (freeze) the layer-mixture weights.
    """
    def __init__(self, ori_lm, backward, droprate, fix_rate):
        super(ElmoLM, self).__init__()

        self.rnn = ERNN(ori_lm.rnn, droprate, fix_rate)

        self.w_num = ori_lm.w_num
        self.w_dim = ori_lm.w_dim
        self.word_embed = ori_lm.word_embed
        self.word_embed.weight.requires_grad = False

        self.output_dim = ori_lm.rnn_output

        self.backward = backward
    def init_hidden(self):
        """
        Initialize hidden states (a no-op for this wrapper).
        """
        return
    def regularizer(self):
        """
        Calculate the regularization term.

        Returns
        ----------
        The regularization term of the underlying ERNN.
        """
        return self.rnn.regularizer()
    def prox(self, lambda0):
        """
        The proximal calculator (a no-op for this wrapper; returns 0.0).
        """
        return 0.0
    def forward(self, w_in, ind=None):
        """
        Calculate the output.

        Parameters
        ----------
        w_in : ``torch.LongTensor``, required.
            The input tensor, of shape (seq_len, batch_size).
        ind : ``torch.LongTensor``, optional, (default=None).
            The index tensor for the backward language model, of shape (seq_len, batch_size).

        Returns
        ----------
        output: ``torch.FloatTensor``.
            The ELMo outputs.
        """
        w_emb = self.word_embed(w_in)

        out = self.rnn(w_emb)

        if self.backward:
            out_size = out.size()
            # Re-order the flattened outputs with ``ind`` so the backward
            # LM's representations align with the forward sequence.
            out = out.view(out_size[0] * out_size[1], out_size[2]).index_select(0, ind).contiguous().view(out_size)

        return out
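
# A minimal end-to-end smoke test; this is a sketch, not part of the original
# pipeline. It assumes only the attributes that ElmoLM and ERNN actually read
# (``rnn.layer`` holding units with ``layer``/``output_dim``, plus ``w_num``,
# ``w_dim``, ``word_embed``, and ``rnn_output``); the ``_Unit``/``_RNN``/``_LM``
# stubs are hypothetical stand-ins for a pretrained language model.
if __name__ == '__main__':

    class _Unit(nn.Module):
        def __init__(self, input_dim, hidden_dim):
            super(_Unit, self).__init__()
            self.layer = nn.LSTM(input_dim, hidden_dim)
            self.output_dim = hidden_dim

    class _RNN(nn.Module):
        def __init__(self):
            super(_RNN, self).__init__()
            self.layer = nn.ModuleList([_Unit(50, 50), _Unit(50, 50)])

    class _LM(nn.Module):
        def __init__(self):
            super(_LM, self).__init__()
            self.w_num, self.w_dim = 100, 50
            self.word_embed = nn.Embedding(100, 50)
            self.rnn = _RNN()
            self.rnn_output = 50

    elmo = ElmoLM(_LM(), backward=False, droprate=0.5, fix_rate=False)
    tokens = torch.randint(0, 100, (35, 16))  # (seq_len, batch_size)
    print(elmo(tokens).size())  # expected: torch.Size([35, 16, 50])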