Source code for model_seq.crf

"""
.. module:: crf
    :synopsis: conditional random field
 
.. moduleauthor:: Liyuan Liu
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.sparse as sparse
import model_seq.utils as utils

class CRF(nn.Module):
    """
    Conditional Random Field Module

    Parameters
    ----------
    hidden_dim : ``int``, required.
        the dimension of the input features.
    tagset_size : ``int``, required.
        the size of the target labels.
    if_bias: ``bool``, optional, (default=True).
        whether the linear transformation has the bias term.
    """

    def __init__(self, hidden_dim: int, tagset_size: int, if_bias: bool = True):
        super().__init__()
        self.tagset_size = tagset_size
        # Emission layer: projects each input feature vector to one score per tag.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size, bias=if_bias)
        # transitions[from_tag, to_tag]. Start from zeros (the same value
        # ``rand_init`` assigns) instead of uninitialized memory, so the module
        # never exposes garbage / NaN scores if ``rand_init`` is not called.
        self.transitions = nn.Parameter(torch.zeros(self.tagset_size, self.tagset_size))

    def rand_init(self):
        """
        random initialization: re-initialize the emission layer through
        ``utils.init_linear`` and reset the transition scores to zero.
        """
        utils.init_linear(self.hidden2tag)
        self.transitions.data.zero_()

    def forward(self, feats):
        """
        calculate the potential score for the conditional random field.

        Parameters
        ----------
        feats: ``torch.FloatTensor``, required.
            the input features for the conditional random field, of shape (*, hidden_dim).

        Returns
        -------
        output: ``torch.FloatTensor``.
            A float tensor of shape (ins_num, from_tag_size, to_tag_size), where
            ``output[i, f, t] = emission[i, t] + transitions[f, t]`` and ins_num is
            the product of all leading dimensions of ``feats``.
        """
        # (ins_num, 1, tagset_size): the singleton axis is the from-tag axis.
        emissions = self.hidden2tag(feats).view(-1, 1, self.tagset_size)
        # Broadcasting (ins_num, 1, T) + (1, T, T) -> (ins_num, T, T).
        return emissions + self.transitions.unsqueeze(0)
class CRFLoss(nn.Module):
    """
    The negative log likelihood loss for the Conditional Random Field Module

    Parameters
    ----------
    y_map : ``dict``, required.
        a ``dict`` maps from tag string to tag index; must contain the
        ``'<s>'`` (start) and ``'<eof>'`` (end) tags.
    average_batch : ``bool``, optional, (default=True).
        whether the return score would be averaged per batch.
    """

    def __init__(self, y_map: dict, average_batch: bool = True):
        super().__init__()
        self.tagset_size = len(y_map)
        self.start_tag = y_map['<s>']
        self.end_tag = y_map['<eof>']
        self.average_batch = average_batch

    def forward(self, scores, target, mask):
        """
        calculate the negative log likelihood for the conditional random field.

        Parameters
        ----------
        scores: ``torch.FloatTensor``, required.
            the potential score for the conditional random field, of shape
            (seq_len, batch_size, from_tag_size, to_tag_size).
        target: ``torch.LongTensor``, required.
            the positive path for the conditional random field, of shape
            (seq_len, batch_size). Each entry indexes the flattened
            (from_tag, to_tag) grid, i.e. ``from_tag * tagset_size + to_tag``.
        mask: ``torch.ByteTensor`` or ``torch.BoolTensor``, required.
            the mask for the unpadded sentence parts, of shape (seq_len, batch_size).
            Non-zero entries mark real steps.

        Returns
        -------
        loss: ``torch.FloatTensor``.
            The NLL loss (log partition minus gold path score), summed over the
            batch and divided by batch_size when ``average_batch`` is set.
        """
        seq_len = scores.size(0)
        bat_size = scores.size(1)

        # Normalise the mask dtype: the comparison yields whatever mask type the
        # running torch version expects, so byte and bool masks are both accepted.
        mask = mask != 0

        # Score of the gold path: pick the target (from, to) cell at every step
        # and keep only the unpadded steps.
        flat_scores = scores.view(seq_len, bat_size, -1)
        tg_energy = torch.gather(flat_scores, 2, target.unsqueeze(2)).view(seq_len, bat_size)
        tg_energy = tg_energy.masked_select(mask).sum()

        # Forward algorithm. Step 0 always leaves the start tag, so the initial
        # partition is the start-tag row: shape (batch_size, to_tag_size).
        partition = scores[0][:, self.start_tag, :]
        for idx in range(1, seq_len):
            # partition is indexed by the previous step's to-tag, which is this
            # step's from-tag, hence the broadcast along the last axis.
            cur_values = scores[idx] + partition.unsqueeze(2)
            # log-sum-exp over the from-tag axis (dim 1) gives the new
            # partition indexed by to-tag.
            cur_partition = torch.logsumexp(cur_values, dim=1)
            # Padded steps keep the partition from the last real step.
            step_mask = mask[idx].view(bat_size, 1).expand_as(cur_partition)
            partition = torch.where(step_mask, cur_partition, partition)

        # Every path must finish in the end tag.
        partition = partition[:, self.end_tag].sum()

        loss = partition - tg_energy
        if self.average_batch:
            return loss / bat_size
        return loss
class CRFDecode():
    """
    Viterbi decoder for the Conditional Random Field Module

    Parameters
    ----------
    y_map : ``dict``, required.
        a ``dict`` maps from tag string to tag index; must contain the
        ``'<s>'`` (start) and ``'<eof>'`` (end) tags.
    """

    def __init__(self, y_map: dict):
        self.tagset_size = len(y_map)
        self.start_tag = y_map['<s>']
        self.end_tag = y_map['<eof>']
        self.y_map = y_map
        # Reverse lookup: tag index -> tag string.
        self.r_y_map = {v: k for k, v in self.y_map.items()}

    def decode(self, scores, mask):
        """
        find the best path from the potential scores by the viterbi decoding algorithm.

        Parameters
        ----------
        scores: ``torch.FloatTensor``, required.
            the potential score for the conditional random field, of shape
            (seq_len, batch_size, from_tag_size, to_tag_size).
        mask: ``torch.ByteTensor`` or ``torch.BoolTensor``, required.
            the mask for the unpadded sentence parts, of shape (seq_len, batch_size).
            Non-zero entries mark real steps.

        Returns
        -------
        output: ``torch.LongTensor``.
            A LongTensor of shape (seq_len - 1, batch_size). Row ``k`` holds the
            decoded tag of step ``k``; for a sequence whose mask covers its first
            ``L`` steps, rows ``L - 1`` and later hold the end tag. Empty (0 rows)
            when ``seq_len`` is 1.
        """
        seq_len = scores.size(0)
        bat_size = scores.size(1)

        decode_idx = torch.zeros(max(seq_len - 1, 0), bat_size, dtype=torch.long)
        if seq_len < 2:
            # Only the start step exists: there is no back-pointer to follow.
            return decode_idx

        # True where a step is padding. The comparison accepts byte and bool
        # masks alike, unlike ``1 - mask``, which fails on bool tensors.
        pad_mask = mask == 0

        # Step 0 always leaves the start tag: (batch_size, to_tag_size).
        forscores = scores[0][:, self.start_tag, :]
        back_points = []
        for idx in range(1, seq_len):
            # forscores is indexed by the previous to-tag, i.e. this step's from-tag.
            cur_values = scores[idx] + forscores.unsqueeze(2)
            # Best from-tag for every to-tag.
            forscores, cur_bp = torch.max(cur_values, 1)
            # On padded steps force the back-pointer to the end tag, so the
            # backward pass stays on the end tag until the last real step.
            cur_bp.masked_fill_(pad_mask[idx].view(bat_size, 1).expand_as(cur_bp), self.end_tag)
            back_points.append(cur_bp)

        # Every path finishes in the end tag; walk the back-pointers from there.
        pointer = back_points[-1][:, self.end_tag]
        decode_idx[-1] = pointer
        for idx in range(len(back_points) - 2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.unsqueeze(1)).squeeze(1)
            decode_idx[idx] = pointer
        return decode_idx

    def to_spans(self, sequence):
        """
        decode the best path to spans.

        Labels are read with the B-/I-/E-/S- prefix scheme; any other label
        closes the open chunk. An I- or E- label whose type differs from the
        open chunk (or that has no open chunk) starts a new chunk.

        Parameters
        ----------
        sequence: list, required.
            the list of best label indexes paths. Elements may be plain ints
            or integer scalars convertible with ``int()``.

        Returns
        -------
        output: ``set``.
            A set of chunks contains the position and type of the entities,
            each encoded as ``'TYPE@pos@pos...'``.
        """
        chunks = []
        current = None  # open chunk as [type, pos, pos, ...]; None when no chunk is open

        for i, y in enumerate(sequence):
            label = self.r_y_map[int(y)]
            prefix = label[:2]
            # Slice off only the leading two-character prefix; ``str.replace``
            # would also strip occurrences inside the type name.
            base = label[2:]
            pos = '%d' % i

            if prefix in ('I-', 'E-'):
                if current is not None and base == current[0]:
                    current.append(pos)
                else:
                    # Type mismatch or nothing open: close the old chunk (if any)
                    # and start a new one at this position.
                    if current is not None:
                        chunks.append('@'.join(current))
                    current = [base, pos]
                if prefix == 'E-':
                    chunks.append('@'.join(current))
                    current = None
                continue

            # B-, S- and non-entity labels all close whatever chunk is open.
            if current is not None:
                chunks.append('@'.join(current))
                current = None
            if prefix == 'B-':
                current = [base, pos]
            elif prefix == 'S-':
                chunks.append('@'.join([base, pos]))

        if current is not None:
            chunks.append('@'.join(current))
        return set(chunks)