Source code for model_seq.evaluator

"""
.. module:: evaluator
    :synopsis: evaluator for sequence labeling
 
.. moduleauthor:: Liyuan Liu
"""
import torch
import numpy as np
import itertools

import model_seq.utils as utils
from torch.autograd import Variable

class eval_batch:
    """
    Base class for evaluation, provide method to calculate f1 score and accuracy.

    Parameters
    ----------
    decoder : ``torch.nn.Module``, required.
        the decoder module, which needs to contain the ``to_spans()`` method.
    """

    def __init__(self, decoder):
        self.decoder = decoder
        # Initialise the counters right away so the scoring methods are usable
        # even if the caller never invoked ``reset()`` explicitly.
        self.reset()

    def reset(self):
        """
        Reset all inner counters to zero.
        """
        self.correct_labels = 0
        self.total_labels = 0
        self.gold_count = 0
        self.guess_count = 0
        self.overlap_count = 0

    def calc_f1_batch(self, decoded_data, target_data):
        """
        Update statistics for the f1 score with one batch.

        Parameters
        ----------
        decoded_data: ``torch.LongTensor``, required.
            the decoded best label index paths; it is split along dimension 1,
            one slice per instance.
        target_data: required.
            the golden label index paths, one sequence per instance.
        """
        batch_decoded = torch.unbind(decoded_data, 1)

        for decoded, target in zip(batch_decoded, target_data):
            # The decoded path is truncated to the length of the gold sequence.
            length = len(target)
            # ``.cpu()`` is a no-op for CPU tensors and keeps ``.numpy()`` valid otherwise.
            best_path = decoded[:length].cpu().numpy()

            (correct_labels_i, total_labels_i, gold_count_i,
             guess_count_i, overlap_count_i) = self.eval_instance(best_path, target)

            self.correct_labels += correct_labels_i
            self.total_labels += total_labels_i
            self.gold_count += gold_count_i
            self.guess_count += guess_count_i
            self.overlap_count += overlap_count_i

    def calc_acc_batch(self, decoded_data, target_data):
        """
        Update statistics for the accuracy score with one batch.

        Parameters
        ----------
        decoded_data: ``torch.LongTensor``, required.
            the decoded best label index paths; it is split along dimension 1,
            one slice per instance.
        target_data: required.
            the golden label index paths, one sequence per instance.
        """
        batch_decoded = torch.unbind(decoded_data, 1)

        for decoded, target in zip(batch_decoded, target_data):
            # The decoded path is truncated to the length of the gold sequence.
            length = len(target)
            best_path = decoded[:length].cpu().numpy()

            self.total_labels += length
            # Compare against ``target``; the previous code used an undefined name here.
            self.correct_labels += int(np.sum(np.equal(best_path, target)))

    def f1_score(self):
        """
        Calculate the f1 score based on the inner counters.

        Returns
        -------
        scores: ``tuple`` of four ``float``.
            ``(f1, precision, recall, accuracy)``; all four are ``0.0`` when no
            span was guessed, no gold span exists, or no span overlaps.
        """
        # Guard both denominators: without the gold_count check a batch with
        # guesses but no gold spans would raise ZeroDivisionError.
        if self.guess_count == 0 or self.gold_count == 0:
            return 0.0, 0.0, 0.0, 0.0

        precision = self.overlap_count / float(self.guess_count)
        recall = self.overlap_count / float(self.gold_count)
        # NOTE: accuracy is also reported as 0.0 in this case; kept unchanged
        # for backward compatibility with existing callers.
        if precision == 0.0 or recall == 0.0:
            return 0.0, 0.0, 0.0, 0.0

        f = 2 * (precision * recall) / (precision + recall)
        accuracy = self.acc_score()
        return f, precision, recall, accuracy

    def acc_score(self):
        """
        Calculate the accuracy score based on the inner counters.

        Returns
        -------
        accuracy: ``float``.
            ``correct_labels / total_labels``, or ``0.0`` when no label was counted.
        """
        if self.total_labels == 0:
            return 0.0
        return float(self.correct_labels) / self.total_labels

    def eval_instance(self, best_path, gold):
        """
        Calculate statistics to update inner counters for one instance.

        Parameters
        ----------
        best_path: required.
            the decoded best label index path.
        gold: required.
            the golden label index path.

        Returns
        -------
        stats: ``tuple``.
            ``(correct_labels, total_labels, gold_count, guess_count, overlap_count)``.
        """
        total_labels = len(best_path)
        correct_labels = np.sum(np.equal(best_path, gold))

        # The results of ``to_spans`` are intersected with ``&`` below, so the
        # decoder is expected to return set-like collections of spans.
        gold_chunks = self.decoder.to_spans(gold)
        gold_count = len(gold_chunks)

        guess_chunks = self.decoder.to_spans(best_path)
        guess_count = len(guess_chunks)

        overlap_chunks = gold_chunks & guess_chunks
        overlap_count = len(overlap_chunks)

        return correct_labels, total_labels, gold_count, guess_count, overlap_count
class eval_wc(eval_batch):
    """
    Evaluation class for LD-Net.

    Parameters
    ----------
    decoder : ``torch.nn.Module``, required.
        the decoder module, which needs to contain the ``to_spans()`` and ``decode()`` method.
    score_type : ``str``, required.
        selects the metric: the f1 score is used when the string contains ``'f'``,
        the accuracy otherwise.
    """

    def __init__(self, decoder, score_type):
        eval_batch.__init__(self, decoder)

        # Bind the per-batch updater and the final scorer once, up front.
        use_f1 = 'f' in score_type
        self.eval_b = self.calc_f1_batch if use_f1 else self.calc_acc_batch
        self.calc_s = self.f1_score if use_f1 else self.acc_score

    def calc_score(self, seq_model, dataset_loader):
        """
        Run the model over a dataset and compute the selected score.

        Parameters
        ----------
        seq_model: required.
            sequence labeling model; it is switched to evaluation mode first.
        dataset_loader: required.
            the dataset loader, yielding 11-element batches.

        Returns
        -------
        score:
            the value returned by the selected scoring method.
        """
        seq_model.eval()
        self.reset()

        for batch in dataset_loader:
            # The ninth element of every batch is not used during evaluation.
            f_c, f_p, b_c, b_p, flm_w, blm_w, blm_ind, f_w, _, f_y_m, g_y = batch

            scores = seq_model(f_c, f_p, b_c, b_p, flm_w, blm_w, blm_ind, f_w)
            decoded = self.decoder.decode(scores.data, f_y_m)
            self.eval_b(decoded, g_y)

        return self.calc_s()