| | """ |
| | RuBERT Toxicity Detection API |
| | FastAPI + Gradio mounted on single port for HuggingFace Spaces compatibility |
| | """ |
| |
|
| | import os |
| | import json |
| | import time |
| | import torch |
| | import gradio as gr |
| | from fastapi import FastAPI, Header, HTTPException |
| | from pydantic import BaseModel |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import uvicorn |
| | from typing import List, Optional |
| | from collections import defaultdict |

# Configuration: set API_KEY in the environment to require bearer-token
# auth; leave it unset to keep the API open.
API_KEY = os.getenv("API_KEY", "")
MAX_TEXT_LENGTH = 2000  # characters; longer input is truncated
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX_REQUESTS = 30  # requests per client per window

# In-memory sliding-window rate limiter: client id -> request timestamps.
rate_limit_store = defaultdict(list)

# Load the tokenizer and model once at startup and switch to inference mode.
model_name = "sismetanin/rubert-toxic-pikabu-2ch"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

app = FastAPI(
    title="RuBERT Toxicity Detection API",
    description="Detects toxicity in Russian text using a RuBERT model",
    version="1.0.0",
)


class PredictionRequest(BaseModel):
    data: List[str]
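
# A request body matching this model looks like:
#   {"data": ["text to analyze"]}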
|
| | def check_rate_limit(client_id: str) -> bool: |
| | """Check if client has exceeded rate limit""" |
| | now = time.time() |
| | |
| | rate_limit_store[client_id] = [ |
| | t for t in rate_limit_store[client_id] |
| | if now - t < RATE_LIMIT_WINDOW |
| | ] |
| | |
| | if len(rate_limit_store[client_id]) >= RATE_LIMIT_MAX_REQUESTS: |
| | return False |
| | rate_limit_store[client_id].append(now) |
| | return True |
| |
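
# Illustrative behaviour (hypothetical limit of 2 requests per window):
#   check_rate_limit("a")  # True
#   check_rate_limit("a")  # True
#   check_rate_limit("a")  # False: third call inside the same 60 s window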


def verify_api_key(authorization: Optional[str]) -> bool:
    """Verify API key if configured"""
    if not API_KEY:
        return True
    if not authorization:
        return False

    key = authorization.replace("Bearer ", "").strip()
    return key == API_KEY
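
# Clients authenticate with a standard bearer header, e.g.:
#   Authorization: Bearer <API_KEY>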


def classify_text(text: str) -> dict:
    """Classify text for toxicity"""
    try:
        # Tokenize, truncating to the model's 512-token limit.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        # Run inference without tracking gradients.
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Index 1 holds the toxic class for this model; index 0 is non-toxic.
        toxic_prob = predictions[0][1].item()
        non_toxic_prob = predictions[0][0].item()

        is_toxic = toxic_prob > 0.5

        result = {
            "text": text,
            "toxic": is_toxic,
            "confidence": toxic_prob if is_toxic else non_toxic_prob,
            "toxic_probability": toxic_prob,
            "non_toxic_probability": non_toxic_prob,
            "label": "TOXIC" if is_toxic else "NON_TOXIC"
        }

        return result

    except Exception as e:
        return {
            "error": str(e),
            "text": text,
            "toxic": False,
            "confidence": 0.0
        }
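
# A successful result looks like (illustrative values):
#   {"text": "...", "toxic": False, "confidence": 0.97,
#    "toxic_probability": 0.03, "non_toxic_probability": 0.97,
#    "label": "NON_TOXIC"}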


@app.post("/api/predict")
async def predict(
    request: PredictionRequest,
    authorization: Optional[str] = Header(None),
    x_forwarded_for: Optional[str] = Header(None, alias="X-Forwarded-For")
):
    """
    API endpoint that matches the bot's expected interface.
    Accepts: {"data": ["text to analyze"]}
    Returns: {"data": [json_string_with_results]}
    """
    if not verify_api_key(authorization):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")

    # X-Forwarded-For may be a comma-separated chain of proxies; the first
    # entry is the originating client.
    client_id = (x_forwarded_for or "default").split(",")[0].strip()
    if not check_rate_limit(client_id):
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    if not request.data:
        return {"error": "No text provided"}

    text = request.data[0]

    # Truncate overly long input rather than rejecting it.
    if len(text) > MAX_TEXT_LENGTH:
        text = text[:MAX_TEXT_LENGTH]

    result = classify_text(text)

    # Serialize the result to a JSON string to mirror the Gradio-style
    # {"data": [...]} response shape the bot expects.
    return {
        "data": [json.dumps(result, ensure_ascii=False)]
    }
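
# Example call (assuming the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/api/predict \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer $API_KEY" \
#     -d '{"data": ["Это отличная идея! Спасибо за помощь."]}'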


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {"status": "healthy", "model": model_name}
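
# Quick liveness probe (assuming a local server):
#   curl http://localhost:7860/health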


def gradio_classify(text: str) -> str:
    """Wrapper for Gradio interface"""
    result = classify_text(text)

    if "error" in result:
        return f"Error: {result['error']}"

    label = "🚫 TOXIC" if result["toxic"] else "✅ NON-TOXIC"
    confidence = result["confidence"]

    return f"""
**Result**: {label}
**Confidence**: {confidence:.2%}

**Details**:
- Toxic probability: {result['toxic_probability']:.4f}
- Non-toxic probability: {result['non_toxic_probability']:.4f}
"""


iface = gr.Interface(
    fn=gradio_classify,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter Russian text to check for toxicity...",
        label="Input Text"
    ),
    outputs=gr.Markdown(label="Analysis Result"),
    title="🛡️ RuBERT Russian Toxicity Detector",
    description="""
This model detects toxic content in Russian text using a fine-tuned RuBERT model.

The model was trained on data from Pikabu and 2ch forums to recognize:
- Offensive language and profanity
- Hate speech and discrimination
- Personal attacks and insults
- Threats and aggressive behavior

**API Access**: POST to `/api/predict` with `{"data": ["your text here"]}`
""",
    examples=[
        ["Это отличная идея! Спасибо за помощь."],
        ["Ты просто молодец, продолжай в том же духе!"],
        ["Какая интересная статья, много полезной информации."]
    ],
    theme=gr.themes.Soft()
)
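
# Route-ordering note: the JSON API routes above (/api/predict, /health)
# are registered before Gradio is mounted at "/", so they take precedence
# over the UI's catch-all routes.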

app = gr.mount_gradio_app(app, iface, path="/")


if __name__ == "__main__":
    # Hugging Face Spaces expects the app to listen on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)