| | from fastapi import FastAPI, Request |
| | from requests import Response |
| | from langchain_community.embeddings.ollama import OllamaEmbeddings |
| | import numpy as np |
| | import json |
| |
|
# Application instance that the endpoint decorators in this module register
# their routes on.
app = FastAPI(
    description="Game of matching fetish of users",
    title="FetishTest",
)
| |
|
| | |
# Embedding client used by the /fetish endpoint to embed the user's labels.
embed_model = OllamaEmbeddings(model="bge-m3")

# Reference data, loaded once at import time. Paths are resolved against the
# current working directory. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open("standard_character.json", "r", encoding="utf-8") as f:
    # Presumably a list of records carrying "key", "name", "label" and
    # "embedding" fields (the fields the endpoint reads) -- verify against
    # the data file.
    standard_character = json.load(f)
with open("Q&A.json", "r", encoding="utf-8") as f:
    # Indexed as answer2label[question_index][answer] -> label string by the
    # endpoint; the exact key types depend on the data file -- TODO confirm.
    answer2label = json.load(f)
| |
|
| | |
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Args:
        vec1: First vector (any array-like accepted by ``np.asarray``).
        vec2: Second vector, with the same length as ``vec1``.

    Returns:
        The cosine of the angle between the vectors as a ``float`` (nominally
        in ``[-1, 1]``). Returns ``0.0`` when either vector has zero norm
        (including empty vectors): the cosine is undefined there, and the
        unguarded formula would divide by zero and produce ``nan``, which
        breaks any ranking built on these scores.

    Raises:
        ValueError: If the vectors have different lengths.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(a.dot(b) / denom)
| |
|
@app.post("/fetish")
async def matching(request: Request):
    """Match a user's answers to the most similar standard character.

    The JSON request body must contain an ``"answer"`` sequence. Each
    ``answer[i]`` is used as a key into ``answer2label[i]`` to obtain a label.
    The distinct labels are sorted, joined with spaces, and embedded with
    ``embed_model``. The result is compared with every entry's ``"embedding"``
    in ``standard_character`` using ``cosine_similarity``.

    Args:
        request: The incoming HTTP request.

    Returns:
        ``{"result": key}``, where ``key`` is the ``"key"`` field of the
        best-scoring entry of ``standard_character``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, lacks ``"answer"``,
            or contains an answer that does not map to a label.
    """
    import asyncio

    from fastapi import HTTPException

    try:
        payload = await request.json()
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err
    try:
        answer = payload["answer"]
    except (KeyError, TypeError) as err:
        raise HTTPException(
            status_code=400, detail='Request body must contain "answer".'
        ) from err
    print(f"user_input: {answer}")

    try:
        # Deduplicate the labels, then sort them so the embedded text does
        # not depend on the order in which the labels were produced.
        user_labels = sorted(
            {answer2label[idx][ans] for idx, ans in enumerate(answer)}
        )
    except (KeyError, IndexError, TypeError) as err:
        raise HTTPException(
            status_code=400, detail="Answer does not map to a known label."
        ) from err
    print(f"user_labels: {user_labels}")

    # embed_query is a synchronous call (it is not awaited) and presumably
    # performs a blocking request to the embedding server. Run it in the
    # default thread-pool executor so it does not stall the event loop for
    # every other in-flight request.
    loop = asyncio.get_running_loop()
    user_embedding = await loop.run_in_executor(
        None, embed_model.embed_query, " ".join(user_labels)
    )

    # Keep the best-scoring record itself. Looking it up again with
    # standard_character[key] would index the list by the record's "key"
    # value, which only works if every key happens to equal its list position.
    sim, best = max(
        (
            (cosine_similarity(user_embedding, character["embedding"]), character)
            for character in standard_character
        ),
        key=lambda pair: pair[0],
    )
    matched = best["key"]
    print(f"matched: {matched}")
    print(f"{best['name']}: {best['label']} -- score: {sim}")
    return {"result": matched}
| | |
if __name__ == "__main__":
    from pathlib import Path

    import uvicorn

    # With workers > 1, uvicorn must receive the app as a "module:attribute"
    # import string (each worker process imports it). Derive the module name
    # from this file's name rather than hard-coding "api", so the launcher
    # keeps working if the file is renamed. For api.py this yields "api:app",
    # exactly as before.
    uvicorn.run(
        f"{Path(__file__).stem}:app",
        host="0.0.0.0",
        port=8002,
        loop="asyncio",
        workers=8,
        limit_concurrency=10,
        timeout_keep_alive=60,
        access_log=True,
    )