"""WangchanBERTa for Aspect-Based Sentiment Analysis."""

import torch
import torch.nn as nn
from transformers import AutoModel
from typing import Dict, Optional, Sequence


class WangchanBERTaForABSA(nn.Module):
    """WangchanBERTa encoder with one sentiment-classification head per aspect.

    With the default configuration the model predicts a sentiment
    (none / positive / neutral / negative) for each of 8 aspects:

    - equipment
    - staff (coaches / employees)
    - cleanliness
    - atmosphere
    - price
    - location
    - programs (classes / programs)
    - amenities

    The [CLS] (first-token) representation is passed through a shared
    dropout -> Linear -> GELU layer and then through an independent MLP head
    per aspect.

    Args:
        model_name: Model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_aspects: Number of aspect heads.
        num_sentiments: Number of sentiment classes predicted per aspect.
        dropout_rate: Dropout probability used before the shared layer and
            inside every aspect head.
        class_weights: Optional per-class loss weights of length
            ``num_sentiments``. When omitted, class 0 gets weight 0.5 and
            every other class gets 1.0 (i.e. ``[0.5, 1.0, 1.0, 1.0]`` for the
            default 4 classes, where class 0 is "none" in the order above).

    Raises:
        ValueError: If ``num_aspects`` or ``num_sentiments`` is not positive,
            or if ``class_weights`` does not have ``num_sentiments`` entries.
    """

    def __init__(
        self,
        model_name: str = "airesearch/wangchanberta-base-att-spm-uncased",
        num_aspects: int = 8,
        num_sentiments: int = 4,
        dropout_rate: float = 0.1,
        class_weights: Optional[Sequence[float]] = None,
    ):
        super().__init__()
        # Validate cheaply before downloading/loading the pretrained encoder.
        if num_aspects < 1:
            raise ValueError(f"num_aspects must be positive, got {num_aspects}")
        if num_sentiments < 1:
            raise ValueError(
                f"num_sentiments must be positive, got {num_sentiments}"
            )
        if class_weights is None:
            # Down-weight class 0 relative to the others; for 4 classes this
            # is exactly the previous hard-coded [0.5, 1.0, 1.0, 1.0].
            class_weights = [0.5] + [1.0] * (num_sentiments - 1)
        if len(class_weights) != num_sentiments:
            raise ValueError(
                f"class_weights has {len(class_weights)} entries, "
                f"expected num_sentiments={num_sentiments}"
            )

        self.num_aspects = num_aspects
        self.num_sentiments = num_sentiments

        self.bert = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(dropout_rate)
        self.shared_layer = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation = nn.GELU()

        # One independent two-layer MLP head per aspect.
        self.aspect_classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(self.hidden_size // 2, num_sentiments),
            )
            for _ in range(num_aspects)
        ])

        # Registered as a buffer so .to()/.cuda()/.half() move it with the
        # model; non-persistent so state_dict keys stay the same as before
        # and existing checkpoints keep loading with strict=True.
        self.register_buffer(
            "class_weights",
            torch.tensor(class_weights, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Score every aspect for a batch of token sequences.

        Args:
            input_ids: Token ids for the encoder (presumably
                ``(batch, seq_len)`` — TODO confirm against the tokenizer
                pipeline).
            attention_mask: Attention mask matching ``input_ids``.
            labels: Optional integer class indices of shape
                ``(batch, num_aspects)``; column ``i`` is the target for
                aspect head ``i``.

        Returns:
            A dict with ``"logits"`` of shape
            ``(batch, num_aspects, num_sentiments)`` and ``"loss"``: the
            class-weighted cross-entropy averaged over aspects when
            ``labels`` is given, otherwise ``None``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First-token ([CLS]) hidden state as the sequence representation.
        cls_output = outputs.last_hidden_state[:, 0, :]

        shared_repr = self.activation(self.shared_layer(self.dropout(cls_output)))

        # Stack per-aspect head outputs along a new aspect dimension.
        logits = torch.stack(
            [clf(shared_repr) for clf in self.aspect_classifiers], dim=1
        )

        loss = None
        if labels is not None:
            # .to() is a no-op when the model was moved as a whole; kept for
            # safety if logits end up on a different device.
            loss_fct = nn.CrossEntropyLoss(
                weight=self.class_weights.to(logits.device)
            )
            # Each aspect's loss is a weighted mean on its own, then the
            # aspects are averaged with equal weight.
            per_aspect = [
                loss_fct(logits[:, i, :], labels[:, i])
                for i in range(self.num_aspects)
            ]
            loss = sum(per_aspect) / self.num_aspects

        return {"loss": loss, "logits": logits}