| from typing import Dict, Any |
| from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection |
| from PIL import Image |
| import requests |
| from io import BytesIO |
| import torch |
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for zero-shot object detection.

    Loads a Grounding-DINO-style model (``AutoModelForZeroShotObjectDetection``)
    and detects objects in a URL-referenced image described by free-text queries.
    """

    def __init__(self, path=""):
        """Load model and processor from ``path`` onto GPU when available.

        Args:
            path: Local checkpoint directory or Hub model id.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _format_text_queries(text):
        """Normalize queries into the prompt format Grounding DINO expects.

        The model requires lowercase phrases, each terminated by exactly one
        period, joined into a single string (e.g. ``"a cat. a dog."``).
        Accepts either a single string or a list of strings.

        Note: the previous inline version joined with ``". "`` AND appended a
        period per phrase, yielding doubled periods (``"cat.. dog."``).
        """
        if isinstance(text, str):
            text = [text]
        return " ".join(t.lower().strip().rstrip(".") + "." for t in text)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on the payload.

        Expected payload shape::

            {"inputs": {"image": "<http(s) URL>",
                        "text": "a cat" | ["a cat", "a dog"]}}

        Returns:
            ``{"detections": [...]}`` with JSON-serializable scores/labels/boxes
            per image, or ``{"error": "..."}`` for malformed input or a failed
            image download.
        """
        if "inputs" not in data:
            return {"error": "Payload must contain 'inputs' key with 'image' and 'text'."}

        inputs = data["inputs"]
        if "image" not in inputs or "text" not in inputs:
            return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}

        image_data = inputs["image"]
        # Guard the type too: a non-string payload used to crash with
        # AttributeError on .startswith instead of returning the error dict.
        if not isinstance(image_data, str) or not image_data.startswith("http"):
            return {"error": "Handler currently supports only URL-based images."}

        # Fail fast on unreachable hosts / HTTP errors / non-image bodies
        # instead of hanging (no timeout) or crashing later inside PIL.
        try:
            response = requests.get(image_data, timeout=30)
            response.raise_for_status()
            # Force 3-channel RGB: URLs may serve grayscale or RGBA images,
            # which the processor does not expect.
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError) as exc:
            return {"error": f"Failed to fetch or decode image: {exc}"}

        text_queries = self._format_text_queries(inputs["text"])

        processed_inputs = self.processor(
            images=image, text=text_queries, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processed_inputs)

        # target_sizes takes (height, width); PIL's .size is (width, height).
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            processed_inputs.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]],
        )

        # The post-processor returns torch tensors, which are not JSON
        # serializable; convert to plain Python lists so the endpoint can
        # return the response without a serialization error.
        detections = []
        for result in results:
            labels = result["labels"]
            if hasattr(labels, "tolist"):  # tensor in some transformers versions
                labels = labels.tolist()
            detections.append({
                "scores": result["scores"].tolist(),
                "labels": labels,
                "boxes": result["boxes"].tolist(),
            })

        return {"detections": detections}
|
|