import warnings
from abc import abstractmethod
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch.nn
from torch.utils.data.dataset import Dataset

import flair
from flair import file_utils
from flair.data import DataPoint, Sentence
from flair.datasets import DataLoader
from flair.training_utils import Result


class Model(torch.nn.Module):
    """Abstract base class for all downstream task models in Flair, such as
    SequenceTagger and TextClassifier.

    Every new type of model must implement these methods.

    NOTE: ``torch.nn.Module`` does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers below document the required contract but are
    not enforced when a subclass is instantiated.
    """

    @abstractmethod
    def forward_loss(
        self, data_points: Union[List[DataPoint], DataPoint]
    ) -> torch.Tensor:
        """Performs a forward pass and returns a loss tensor for backpropagation.

        Implement this to enable training.

        :param data_points: a single data point or a list of data points
        :return: a loss tensor
        """
        pass

    @abstractmethod
    def evaluate(
        self,
        sentences: Union[List[DataPoint], Dataset],
        out_path: Optional[Path] = None,
        embedding_storage_mode: str = "none",
    ) -> Tuple[Result, float]:
        """Evaluates the model. Returns a Result object containing evaluation
        results and a loss value. Implement this to enable evaluation.

        :param sentences: list of data points or Dataset to be evaluated
        :param out_path: Optional output path to store predictions
        :param embedding_storage_mode: 'none' means all embeddings are deleted
            and freshly recomputed, 'cpu' means all embeddings are stored on
            CPU, or 'gpu' means all embeddings are stored on GPU
        :return: Returns a Tuple consisting of a Result object and a loss float
            value
        """
        pass

    @abstractmethod
    def _get_state_dict(self):
        """Returns the state dictionary for this model. Implementing this
        enables the save() and save_checkpoint() functionality."""
        pass

    @staticmethod
    @abstractmethod
    def _init_model_with_state_dict(state):
        """Initialize the model from a state dictionary. Implementing this
        enables the load() and load_checkpoint() functionality."""
        pass

    @staticmethod
    @abstractmethod
    def _fetch_model(model_name) -> str:
        """Maps the name given to load() to the file path that is actually
        read. The default implementation returns ``model_name`` unchanged."""
        return model_name

    def save(self, model_file: Union[str, Path]):
        """
        Saves the current model to the provided file.
        :param model_file: the model file
        """
        model_state = self._get_state_dict()
        # pickle protocol 4 supports serializing very large (> 4 GB) objects
        torch.save(model_state, str(model_file), pickle_protocol=4)

    @classmethod
    def load(cls, model: Union[str, Path]):
        """
        Loads the model from the given file.
        :param model: the model file (passed through _fetch_model first)
        :return: the loaded model, switched to eval mode and moved to
            flair.device
        """
        model_file = cls._fetch_model(str(model))

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            # load_big_file is a workaround by https://github.com/highway11git
            # to load models on some Mac/Windows setups
            # see https://github.com/zalandoresearch/flair/issues/351
            f = file_utils.load_big_file(str(model_file))
            # Tensors are first mapped to CPU; the model is moved to
            # flair.device below. NOTE(security): torch.load unpickles the
            # file, so only load model files from trusted sources.
            state = torch.load(f, map_location="cpu")

        loaded_model = cls._init_model_with_state_dict(state)
        loaded_model.eval()
        loaded_model.to(flair.device)
        return loaded_model


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Randomly drops out
    entire parameters in embedding space.

    A single mask (size 1 along the sequence axis) is sampled and broadcast
    over that axis; kept values are rescaled by 1 / (1 - dropout_rate).
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super(LockedDropout, self).__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        # NOTE: stored and reported by extra_repr(), but forward() never
        # modifies its input in place.
        self.inplace = inplace

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        # Everything is dropped: return zeros instead of dividing by
        # (1 - 1) == 0, which would turn the output into NaN.
        if self.dropout_rate >= 1:
            return torch.zeros_like(x)

        keep_prob = 1 - self.dropout_rate
        # Assumes x is 3-D: (batch, seq, features) when batch_first, else
        # (seq, batch, features) -- TODO confirm with callers. The mask has
        # size 1 on the sequence axis so every step shares the same mask.
        if not self.batch_first:
            mask = x.new_empty((1, x.size(1), x.size(2))).bernoulli_(keep_prob)
        else:
            mask = x.new_empty((x.size(0), 1, x.size(2))).bernoulli_(keep_prob)

        # Inverted dropout: rescale kept units so the expected value is unchanged.
        mask = mask / keep_prob
        mask = mask.expand_as(x)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or
    characters) in embedding space.

    Unlike LockedDropout, the kept values are not rescaled.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super(WordDropout, self).__init__()
        self.dropout_rate = dropout_rate
        # NOTE: stored and reported by extra_repr(), but forward() never
        # modifies its input in place.
        self.inplace = inplace

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        # Assumes x is 3-D. One Bernoulli draw per (dim0, dim1) position,
        # broadcast over the last dimension so whole vectors are zeroed.
        mask = x.new_empty((x.size(0), x.size(1), 1)).bernoulli_(
            1 - self.dropout_rate
        )
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
# Comments