# rubert-toxicity-detector / app_fixed.py
# Author: schernykh
# fix: Add API auth, rate limiting, and input size limits (commit 6ee6655)
"""
RuBERT Toxicity Detection API
FastAPI + Gradio mounted on single port for HuggingFace Spaces compatibility
"""
import hmac
import json
import os
import time
from collections import defaultdict
from typing import List, Optional

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- Security configuration ---
# When API_KEY is unset/empty, /api/predict accepts unauthenticated requests.
API_KEY = os.getenv("API_KEY", "") # Set via HuggingFace Spaces secrets
MAX_TEXT_LENGTH = 2000 # Max characters per request
RATE_LIMIT_WINDOW = 60 # Seconds
RATE_LIMIT_MAX_REQUESTS = 30 # Max requests per window per IP
# Simple in-memory rate limiter: client id -> list of request timestamps.
# NOTE(review): per-process only; not shared across workers or restarts.
rate_limit_store = defaultdict(list)
# Load model and tokenizer at import time (downloads weights on first run).
model_name = "sismetanin/rubert-toxic-pikabu-2ch"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Set model to evaluation mode (disables dropout; inference only)
model.eval()
# Create FastAPI app; Gradio UI is mounted onto it at the bottom of the file.
app = FastAPI(
    title="RuBERT Toxicity Detection API",
    description="Detects toxicity in Russian text using RuBERT model",
    version="1.0.0"
)
class PredictionRequest(BaseModel):
    """Request body for /api/predict: a list of texts (only the first is classified)."""
    data: List[str]
def check_rate_limit(client_id: str, store=None, window: float = None,
                     max_requests: int = None) -> bool:
    """Sliding-window rate-limit check; records the request if allowed.

    Args:
        client_id: Opaque client identifier (here: the X-Forwarded-For value).
        store: Mapping of client id -> list of request timestamps; defaults
            to the module-level ``rate_limit_store``. Parameterized so the
            limiter is testable and reusable without touching module state.
        window: Window length in seconds; defaults to RATE_LIMIT_WINDOW.
        max_requests: Allowed requests per window; defaults to
            RATE_LIMIT_MAX_REQUESTS.

    Returns:
        True if the request is within the limit (and has been recorded),
        False if the client must be rejected. A rejected attempt is NOT
        recorded, so it does not extend the client's penalty.
    """
    if store is None:
        store = rate_limit_store
    if window is None:
        window = RATE_LIMIT_WINDOW
    if max_requests is None:
        max_requests = RATE_LIMIT_MAX_REQUESTS
    now = time.time()
    # Drop timestamps that have aged out of the window.
    recent = [t for t in store[client_id] if now - t < window]
    if len(recent) >= max_requests:
        # Write back the pruned list so stale entries don't accumulate.
        store[client_id] = recent
        return False
    recent.append(now)
    store[client_id] = recent
    return True
def verify_api_key(authorization: Optional[str],
                   expected_key: Optional[str] = None) -> bool:
    """Validate the Authorization header against the configured API key.

    Args:
        authorization: Raw ``Authorization`` header value, or None if absent.
        expected_key: Key to compare against; defaults to the module-level
            API_KEY (from the API_KEY env var / Spaces secret).

    Returns:
        True if the request is authorized. An empty/unset key means auth is
        disabled and every request passes.
    """
    if expected_key is None:
        expected_key = API_KEY
    if not expected_key:  # No key configured: auth disabled, allow all.
        return True
    if not authorization:
        return False
    # Support both "Bearer <key>" and a plain "<key>". Strip only a LEADING
    # "Bearer " prefix: the previous str.replace() removed every occurrence
    # of the substring, mangling any key that itself contained "Bearer ".
    candidate = authorization.strip()
    if candidate.startswith("Bearer "):
        candidate = candidate[len("Bearer "):].strip()
    # Constant-time comparison so the key can't be probed via timing.
    return hmac.compare_digest(candidate, expected_key)
def classify_text(text):
    """Classify a single string for toxicity with the module-level RuBERT model.

    Returns a dict with keys ``text``, ``toxic``, ``confidence``,
    ``toxic_probability``, ``non_toxic_probability`` and ``label`` on
    success; on any failure, a best-effort dict carrying an ``error`` key
    plus safe defaults instead of raising.
    """
    try:
        # Tokenize; long inputs are truncated to the model's 512-token limit.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            logits = model(**encoded).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        non_toxic_prob = probs[0].item()
        toxic_prob = probs[1].item()
        is_toxic = toxic_prob > 0.5  # fixed decision threshold
        return {
            "text": text,
            "toxic": is_toxic,
            "confidence": toxic_prob if is_toxic else non_toxic_prob,
            "toxic_probability": toxic_prob,
            "non_toxic_probability": non_toxic_prob,
            "label": "TOXIC" if is_toxic else "NON_TOXIC",
        }
    except Exception as exc:
        # Deliberate best-effort: report the failure to the caller.
        return {
            "error": str(exc),
            "text": text,
            "toxic": False,
            "confidence": 0.0,
        }
# API endpoint for bot compatibility
@app.post("/api/predict")
async def predict(
    request: PredictionRequest,
    authorization: Optional[str] = Header(None),
    x_forwarded_for: Optional[str] = Header(None, alias="X-Forwarded-For")
):
    """
    API endpoint that matches the bot's expected interface
    Accepts: {"data": ["text to analyze"]}
    Returns: {"data": [json_string_with_results]}

    Raises:
        HTTPException: 401 on invalid/missing API key, 429 on rate limit.
    """
    # Verify API key if configured (no-op when API_KEY is unset).
    if not verify_api_key(authorization):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    # Rate limiting keyed by X-Forwarded-For; clients without the header
    # all share the "default" bucket.
    client_id = x_forwarded_for or "default"
    if not check_rate_limit(client_id):
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    # An empty list is already falsy, so the former extra
    # `len(request.data) == 0` test was redundant. Kept as a 200-with-error
    # payload (not a 4xx) to preserve the bot-facing contract.
    if not request.data:
        return {"error": "No text provided"}
    text = request.data[0]
    # Input size limit: truncate rather than reject oversized texts.
    if len(text) > MAX_TEXT_LENGTH:
        text = text[:MAX_TEXT_LENGTH]
    result = classify_text(text)
    # Return in the JSON-string-inside-"data" format expected by the bot.
    return {
        "data": [json.dumps(result, ensure_ascii=False)]
    }
# Health check endpoint
@app.get("/health")
async def health():
    """Liveness probe: reports service status and the loaded model name."""
    return {
        "status": "healthy",
        "model": model_name,
    }
# Gradio interface for web UI
def gradio_classify(text):
    """Gradio-facing wrapper: classify ``text`` and format a markdown summary."""
    result = classify_text(text)
    if "error" in result:
        return f"Error: {result['error']}"
    verdict = "🚫 TOXIC" if result["toxic"] else "✅ NON-TOXIC"
    # Assemble the same markdown the UI has always shown, line by line.
    report_lines = [
        "",
        f"**Result**: {verdict}",
        f"**Confidence**: {result['confidence']:.2%}",
        "**Details**:",
        f"- Toxic probability: {result['toxic_probability']:.4f}",
        f"- Non-toxic probability: {result['non_toxic_probability']:.4f}",
        "",
    ]
    return "\n".join(report_lines)
# Create Gradio interface (web UI); shares classify_text with the REST API.
iface = gr.Interface(
    fn=gradio_classify,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter Russian text to check for toxicity...",
        label="Input Text"
    ),
    outputs=gr.Markdown(label="Analysis Result"),
    title="🛡️ RuBERT Russian Toxicity Detector",
    description="""
This model detects toxic content in Russian text using a fine-tuned RuBERT model.
The model was trained on data from Pikabu and 2ch forums to recognize:
- Offensive language and profanity
- Hate speech and discrimination
- Personal attacks and insults
- Threats and aggressive behavior
**API Access**: POST to `/api/predict` with `{"data": ["your text here"]}`
""",
    examples=[
        ["Это отличная идея! Спасибо за помощь."],
        ["Ты просто молодец, продолжай в том же духе!"],
        ["Какая интересная статья, много полезной информации."]
    ],
    theme=gr.themes.Soft()
)
# Mount Gradio at the web root so ONE port serves both the UI and the API —
# required for HuggingFace Spaces, which exposes a single port per Space.
app = gr.mount_gradio_app(app, iface, path="/")
# Run the app when executed directly.
if __name__ == "__main__":
    # Use port 7860 — the port HuggingFace Spaces expects the app to bind.
    uvicorn.run(app, host="0.0.0.0", port=7860)