# Source code for model_seq.dataset

"""
.. module:: dataset
    :synopsis: dataset for sequence labeling
 
.. moduleauthor:: Liyuan Liu
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
import pickle
import random
import functools
import itertools
from tqdm import tqdm

class SeqDataset(object):
    """
    Dataset for Sequence Labeling

    Parameters
    ----------
    dataset : ``list``, required.
        The encoded dataset (outputs of preprocess scripts).
    flm_pad : ``int``, required.
        The pad index for the forward language model.
    blm_pad : ``int``, required.
        The pad index for the backward language model.
    w_pad : ``int``, required.
        The pad index for the word-level inputs.
    c_con : ``int``, required.
        The index of the connect character token for character-level inputs.
    c_pad : ``int``, required.
        The pad index for the character-level inputs.
    y_start : ``int``, required.
        The index of the start label token.
    y_pad : ``int``, required.
        The index of the pad label token.
    y_size : ``int``, required.
        The size of the tag set.
    batch_size : ``int``, required.
        Batch size.
    """
    def __init__(self,
                 dataset: list,
                 flm_pad: int,
                 blm_pad: int,
                 w_pad: int,
                 c_con: int,
                 c_pad: int,
                 y_start: int,
                 y_pad: int,
                 y_size: int,
                 batch_size: int):
        super(SeqDataset, self).__init__()

        self.flm_pad = flm_pad
        self.blm_pad = blm_pad
        self.w_pad = w_pad
        self.c_con = c_con
        self.c_pad = c_pad
        self.y_pad = y_pad
        self.y_size = y_size
        self.y_start = y_start
        self.batch_size = batch_size

        self.construct_index(dataset)
        self.shuffle()
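    # NOTE (illustrative, inferred from construct_index() and batchify()
    # below, not part of the original docstring): each encoded instance is
    # expected to be a list of the form
    #   [forward-LM word ids, backward-LM word ids, word ids,
    #    per-word character id lists, label ids]
    # construct_index() flattens the character lists in-place and appends
    # the per-word character span lengths as instance[5].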
    def shuffle(self):
        """
        shuffle dataset
        """
        random.shuffle(self.shuffle_list)
    def get_tqdm(self, device):
        """
        construct dataset reader and the corresponding tqdm.

        Parameters
        ----------
        device : ``torch.device``, required.
            the target device for the dataset loader.
        """
        return tqdm(self.reader(device), mininterval=2, total=self.index_length // self.batch_size, leave=False, file=sys.stdout, ncols=80)
    def construct_index(self, dataset):
        """
        construct index for the dataset.

        Parameters
        ----------
        dataset : ``list``, required.
            the encoded dataset (outputs of preprocess scripts).
        """
        for instance in dataset:
            # flatten the per-word character lists into a single sequence,
            # appending the connect token after each word, and record each
            # word's character span length (its characters plus the token)
            c_len = [len(tup) + 1 for tup in instance[3]]
            c_ins = [tup for ins in instance[3] for tup in (ins + [self.c_con])]
            instance[3] = c_ins
            instance.append(c_len)

        self.dataset = dataset
        self.index_length = len(dataset)
        self.shuffle_list = list(range(0, self.index_length))
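    # Worked example (illustrative, not part of the original module):
    # with c_con = 1, an instance whose instance[3] is [[4, 5], [6]]
    # ends up with instance[3] = [4, 5, 1, 6, 1] and gains
    # instance[5] = c_len = [3, 2].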
    def reader(self, device):
        """
        construct dataset reader.

        Parameters
        ----------
        device : ``torch.device``, required.
            the target device for the dataset loader.

        Returns
        -------
        reader: ``iterator``.
            A lazy iterable object
        """
        cur_idx = 0
        while cur_idx < self.index_length:
            end_index = min(cur_idx + self.batch_size, self.index_length)
            batch = [self.dataset[self.shuffle_list[index]] for index in range(cur_idx, end_index)]
            cur_idx = end_index
            yield self.batchify(batch, device)
        # one pass over shuffle_list is one epoch; reshuffle for the next
        self.shuffle()
    def batchify(self, batch, device):
        """
        batchify a batch of data and move to a device.

        Parameters
        ----------
        batch : ``list``, required.
            a sample from the encoded dataset (outputs of preprocess scripts).
        device : ``torch.device``, required.
            the target device for the dataset loader.
        """
        cur_batch_size = len(batch)
        char_padded_len = max([len(tup[3]) for tup in batch])
        word_padded_len = max([len(tup[0]) for tup in batch])

        tmp_batch = [list() for ind in range(11)]

        for instance_ind in range(cur_batch_size):
            instance = batch[instance_ind]
            char_padded_len_ins = char_padded_len - len(instance[3])
            word_padded_len_ins = word_padded_len - len(instance[0])

            # forward / backward character-level inputs
            tmp_batch[0].append(instance[3] + [self.c_pad] + [self.c_pad] * char_padded_len_ins)
            tmp_batch[2].append([self.c_pad] + instance[3][::-1] + [self.c_pad] * char_padded_len_ins)

            # flat indices of the word-boundary positions in the forward and
            # backward character sequences (scaled for the seq_len-first,
            # flattened seq_len x batch layout produced below)
            tmp_p = list(itertools.accumulate(instance[5] + [1] + [0] * word_padded_len_ins))
            tmp_batch[1].append([(x - 1) * cur_batch_size + instance_ind for x in tmp_p])
            tmp_p = list(itertools.accumulate([1] + instance[5][::-1]))[::-1] + [1] * word_padded_len_ins
            tmp_batch[3].append([(x - 1) * cur_batch_size + instance_ind for x in tmp_p])

            # forward / backward language-model word inputs
            tmp_batch[4].append(instance[0] + [self.flm_pad] + [self.flm_pad] * word_padded_len_ins)
            tmp_batch[5].append([self.blm_pad] + instance[1][::-1] + [self.blm_pad] * word_padded_len_ins)

            # flat indices that map the reversed backward-LM states back to
            # forward word order
            tmp_p = list(range(len(instance[1]), -1, -1)) + list(range(len(instance[1]) + 1, word_padded_len + 1))
            tmp_batch[6].append([x * cur_batch_size + instance_ind for x in tmp_p])

            # word-level inputs
            tmp_batch[7].append(instance[2] + [self.w_pad] + [self.w_pad] * word_padded_len_ins)

            # label transitions, encoded as previous_label * y_size +
            # current_label, opened by y_start and padded with pad-to-pad
            tmp_batch[8].append([self.y_start * self.y_size + instance[4][0]]
                                + [instance[4][ind] * self.y_size + instance[4][ind + 1] for ind in range(len(instance[4]) - 1)]
                                + [instance[4][-1] * self.y_size + self.y_pad]
                                + [self.y_pad * self.y_size + self.y_pad] * word_padded_len_ins)

            # mask (covering the sentence plus the closing transition) and
            # the raw gold labels
            tmp_batch[9].append([1] * len(instance[4]) + [1] + [0] * word_padded_len_ins)
            tmp_batch[10].append(instance[4])

        tbt = [torch.LongTensor(v).transpose(0, 1).contiguous() for v in tmp_batch[0:9]] + [torch.ByteTensor(tmp_batch[9]).transpose(0, 1).contiguous()]

        # the index tensors are flattened so they can be used with
        # index_select on flattened (seq_len * batch) outputs
        tbt[1] = tbt[1].view(-1)
        tbt[3] = tbt[3].view(-1)
        tbt[6] = tbt[6].view(-1)

        return [ten.to(device) for ten in tbt] + [tmp_batch[10]]
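
# A minimal usage sketch (illustrative, not part of the original module).
# The toy instance and all index arguments below are hypothetical; in
# practice they come from the vocabularies saved by the preprocess scripts.
if __name__ == "__main__":
    toy_dataset = [
        # [forward-LM ids, backward-LM ids, word ids, char lists, label ids]
        [[2, 3], [3, 2], [2, 3], [[4, 5], [6]], [1, 2]],
    ]
    loader = SeqDataset(toy_dataset, flm_pad=0, blm_pad=0, w_pad=0,
                        c_con=1, c_pad=0, y_start=0, y_pad=3, y_size=4,
                        batch_size=1)
    for batch in loader.get_tqdm(torch.device('cpu')):
        # nine LongTensors, a ByteTensor mask, and the raw gold labels
        tensors, golds = batch[:-1], batch[-1]
        print([tuple(t.size()) for t in tensors], golds)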