Spaces:
Runtime error
Runtime error
| # Adapted from https://github.com/m3hrdadfi/soxan to implement Wav2Vec2 for speech classification. | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from transformers.models.wav2vec2.modeling_wav2vec2 import ( | |
| Wav2Vec2PreTrainedModel, | |
| Wav2Vec2Model | |
| ) | |
| from src.modeling_outputs import SpeechClassifierOutput | |
class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head applied to a wav2vec feature tensor.

    Pipeline: dropout -> linear (hidden -> hidden) -> tanh -> dropout ->
    linear (hidden -> num_labels).

    The config object must expose ``hidden_size``, ``final_dropout`` and
    ``num_labels``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Submodules are created in the order dense, dropout, out_proj.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to per-label logits.

        Extra keyword arguments are accepted and ignored.
        """
        projected = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder with pooling over dim 1 and a classification head.

    The encoder output is reduced along dimension 1 using the strategy named
    by ``config.pooling_mode`` ("mean", "sum" or "max"), then passed to
    ``Wav2Vec2ClassificationHead`` to produce ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        """Call ``_freeze_parameters()`` on the encoder's feature extractor."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(self, hidden_states, mode="mean"):
        """Reduce ``hidden_states`` along dimension 1.

        Args:
            hidden_states: Tensor to pool; dimension 1 is removed.
                NOTE(review): presumably (batch, frames, hidden) -- confirm.
                Padding is not masked out; every position contributes.
            mode: One of "mean", "sum" or "max".

        Returns:
            The pooled tensor.

        Raises:
            ValueError: If ``mode`` is not a supported pooling strategy.
        """
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            # torch.max along a dim returns (values, indices); keep the values.
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs

    def forward(
            self,
            input_values,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        """Run the encoder, pool its output, classify, and optionally compute a loss.

        Args:
            input_values: Forwarded to the Wav2Vec2 encoder.
            attention_mask: Forwarded to the Wav2Vec2 encoder.
            output_attentions: Forwarded to the Wav2Vec2 encoder.
            output_hidden_states: Forwarded to the Wav2Vec2 encoder.
            return_dict: If ``None``, falls back to ``config.use_return_dict``.
            labels: Optional targets. When given, a loss is computed according
                to ``config.problem_type``. If that is ``None`` it is inferred
                once and stored on the config: "regression" for one label,
                "single_label_classification" for integer labels with more
                than one label, otherwise "multi_label_classification".

        Returns:
            A ``SpeechClassifierOutput`` when ``return_dict`` is truthy;
            otherwise a tuple ``(loss?, logits, *encoder_extras)`` where the
            loss is present only if it was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # NOTE(review): index 0 is taken as the encoder's final hidden states.
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # With one output, logits viewed as (N, 1) against labels
                    # of shape (N,) would broadcast to (N, N) inside MSELoss
                    # and give a wrong loss. Flatten both so they pair up.
                    loss = loss_fct(logits.reshape(-1), labels.reshape(-1))
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # NOTE(review): assumes encoder tuple entries from index 2 onward
            # are the optional hidden states / attentions -- confirm against
            # the installed transformers version.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )