"""
.. module:: dataset
:synopsis: dataset for language modeling
.. moduleauthor:: Liyuan Liu
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import pickle
import random
from tqdm import tqdm
from torch.utils.data import Dataset
class EvalDataset(object):
    """
    Dataset for Language Modeling

    The encoded token sequence is cut into consecutive chunks of at most
    ``sequence_length`` tokens. Each input chunk is paired with a label chunk
    covering the same span shifted one token to the right.

    Parameters
    ----------
    dataset : ``list``, required.
        The encoded dataset (outputs of preprocess scripts).
    sequence_length: ``int``, required.
        Sequence Length.
    """
    def __init__(self, dataset, sequence_length):
        super(EvalDataset, self).__init__()
        self.dataset = dataset
        self.sequence_length = sequence_length

        self.construct_index()

    def get_tqdm(self, device):
        """
        construct dataset reader and the corresponding tqdm.

        Parameters
        ----------
        device: ``torch.device``, required.
            the target device for the dataset loader.

        Returns
        -------
        reader: ``tqdm``.
            A progress bar wrapping ``self.reader(device)``, with its total set
            to the number of indexed chunks.
        """
        return tqdm(self.reader(device), mininterval=2, total=self.index_length, leave=False, file=sys.stdout, ncols=80)

    def construct_index(self):
        """
        construct index for the dataset.

        Fills ``self.x`` / ``self.y`` with 1-D ``LongTensor`` chunks, sets
        ``self.index_length`` to the number of chunks and rewinds
        ``self.cur_idx`` to 0.
        """
        # The last token only ever serves as a label, hence the "- 1".
        tot_num = len(self.dataset) - 1
        # Largest prefix that splits evenly into full-length chunks.
        res_num = tot_num - tot_num % self.sequence_length

        self.x = list(torch.unbind(torch.LongTensor(self.dataset[0:res_num]).view(-1, self.sequence_length), 0))
        self.y = list(torch.unbind(torch.LongTensor(self.dataset[1:res_num+1]).view(-1, self.sequence_length), 0))

        # Keep the shorter trailing chunk only when it actually holds tokens;
        # otherwise an empty batch would be handed to the consumer.
        if res_num < tot_num:
            self.x.append(torch.LongTensor(self.dataset[res_num:tot_num]))
            self.y.append(torch.LongTensor(self.dataset[res_num+1:tot_num+1]))

        self.index_length = len(self.x)
        self.cur_idx = 0

    def reader(self, device):
        """
        construct dataset reader.

        Parameters
        ----------
        device: ``torch.device``, required.
            the target device for the dataset loader.

        Returns
        -------
        reader: ``iterator``.
            A lazy iterable object yielding ``(word_t, label_t)`` pairs, each
            reshaped to a single column (``view(-1, 1)``) on ``device``.
        """
        # Iterate through every remaining chunk. The generator finishes by
        # returning normally: raising StopIteration inside a generator is
        # turned into RuntimeError by PEP 479.
        while self.cur_idx < self.index_length:
            word_t = self.x[self.cur_idx].to(device).view(-1, 1)
            label_t = self.y[self.cur_idx].to(device).view(-1, 1)

            self.cur_idx += 1

            yield word_t, label_t

        # Rewind so that the next reader starts from the first chunk again.
        self.cur_idx = 0
class LargeDataset(object):
    """
    Lazy Dataset for Language Modeling

    Training files named ``train_<idx>.pk`` are loaded one at a time, in a
    shuffled order, and each file is cut into batches.

    Parameters
    ----------
    root : ``str``, required.
        The root folder for dataset files. It is concatenated directly with
        the file name, so it must end with a path separator.
    range_idx : ``int``, required.
        The maximum file index for the input files (train_*.pk).
    batch_size : ``int``, required.
        Batch size.
    sequence_length: ``int``, required.
        Sequence Length.
    """
    def __init__(self, root, range_idx, batch_size, sequence_length):
        super(LargeDataset, self).__init__()
        self.root = root
        self.range_idx = range_idx
        self.shuffle_list = list(range(0, range_idx))
        self.shuffle()

        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.token_per_batch = self.batch_size * self.sequence_length

        # Unknown until one full pass over all files has completed.
        self.total_batch_num = -1

        # Iteration state; get_tqdm() resets it before every pass. Initialised
        # here as well so that reader() can be used directly.
        self.batch_count = 0
        self.cur_idx = 0
        self.file_idx = 0
        self.index_length = 0

    def shuffle(self):
        """
        shuffle dataset
        """
        random.shuffle(self.shuffle_list)

    def get_tqdm(self, device):
        """
        construct dataset reader and the corresponding tqdm.

        Parameters
        ----------
        device: ``torch.device``, required.
            the target device for the dataset loader.

        Returns
        -------
        reader: ``iterator``.
            An iterator over a progress bar wrapping ``self.reader(device)``.
        """
        # Restart from the first file of the current shuffled order.
        self.batch_count = 0
        self.cur_idx = 0
        self.file_idx = 0
        self.index_length = 0

        if self.total_batch_num <= 0:
            # No completed pass yet, so the total number of batches is unknown.
            return iter(tqdm(self.reader(device), mininterval=2, leave=False, file=sys.stdout))
        else:
            return iter(tqdm(self.reader(device), mininterval=2, total=self.total_batch_num, leave=False, file=sys.stdout, ncols=80))

    def reader(self, device):
        """
        construct dataset reader.

        Parameters
        ----------
        device: ``torch.device``, required.
            the target device for the dataset loader.

        Returns
        -------
        reader: ``iterator``.
            A lazy iterable object yielding ``(word_t, label_t)`` pairs, each
            of shape ``(sequence_length, batch_size)`` on ``device``.
        """
        while self.file_idx < self.range_idx:
            self.open_next()
            while self.cur_idx < self.index_length:
                word_t = self.x[self.cur_idx].to(device)
                label_t = self.y[self.cur_idx].to(device)
                self.cur_idx += 1
                yield word_t, label_t

        # A full pass is done: remember its size for the next progress bar and
        # pick a new file order for the next pass.
        self.total_batch_num = self.batch_count
        self.shuffle()

    def open_next(self):
        """
        Open the next file.

        Loads the next file of the shuffled order, rebuilds ``self.x`` /
        ``self.y`` from it and advances ``self.file_idx``.
        """
        file_path = self.root + 'train_' + str(self.shuffle_list[self.file_idx]) + '.pk'
        # NOTE: pickle.load can execute arbitrary code; only point ``root`` at
        # trusted files. The context manager guarantees the handle is closed.
        with open(file_path, 'rb') as fin:
            self.dataset = pickle.load(fin)

        # The last token only ever serves as a label; drop the tail that does
        # not fill a whole batch.
        res_num = len(self.dataset) - 1
        res_num = res_num - res_num % self.token_per_batch

        # (batch, chunk, seq) -> (chunk, seq, batch): one batch per leading index.
        self.x = torch.LongTensor(self.dataset[0:res_num]).view(self.batch_size, -1, self.sequence_length).permute(1, 2, 0).contiguous()
        self.y = torch.LongTensor(self.dataset[1:res_num+1]).view(self.batch_size, -1, self.sequence_length).permute(1, 2, 0).contiguous()

        self.index_length = self.x.size(0)
        self.cur_idx = 0

        self.batch_count += self.index_length
        self.file_idx += 1