"""
FastAPI application for FunctionGemma with Hugging Face login support.
This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860
Fix: added prompt/completion token counting for the usage field.
"""

import os
import re
import sys
import time
from pathlib import Path

from fastapi import FastAPI
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Local download location for model weights and tokenizer files.
CACHE_DIR = "./my_model_cache"

# Globals populated lazily: the model and tokenizer load on startup,
# or on the first request if the startup hook has not run yet.
model_name = None
pipe = None
tokenizer = None
app = FastAPI(title="FunctionGemma API", version="1.0.0")

def check_and_download_model():
    """Check whether the model exists in the local cache and download it if not."""
    global model_name, tokenizer

    model_name = "unsloth/functiongemma-270m-it"

    # huggingface_hub lays the cache out as models--{org}--{name}/snapshots/...
    model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in cache")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
        return model_name, CACHE_DIR

    print(f"✗ Model {model_name} not found in cache")
    print("Downloading model...")

    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ HuggingFace login successful!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing without login (public models only)")
    else:
        print("ℹ No HUGGINGFACE_TOKEN set - using public models only")

    try:
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
        print("✓ Tokenizer loaded successfully!")

        # Load the model once so its weights land in the local cache;
        # the pipeline reuses the cached files later.
        print("Loading model...")
        AutoModelForCausalLM.from_pretrained(model_name, cache_dir=CACHE_DIR)
        print("✓ Model loaded successfully!")

        print(f"✓ Model and tokenizer downloaded successfully to {CACHE_DIR}")
        return model_name, CACHE_DIR

    except Exception as e:
        print(f"✗ Error downloading model: {e}")
        print("\nPossible reasons:")
        print("1. Model requires authentication - set HUGGINGFACE_TOKEN in .env")
        print("2. Model is gated and you don't have access")
        print("3. Network connection issues")
        sys.exit(1)

def initialize_pipeline():
    """Initialize the text-generation pipeline with the model."""
    global pipe, model_name, tokenizer

    if model_name is None:
        model_name, _ = check_and_download_model()

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)

    print(f"Initializing pipeline with {model_name}...")
    # model_kwargs forwards cache_dir to from_pretrained so the pipeline
    # loads weights from the local cache instead of the default HF cache.
    pipe = pipeline(
        "text-generation",
        model=model_name,
        tokenizer=tokenizer,
        model_kwargs={"cache_dir": CACHE_DIR},
    )
    print("✓ Pipeline initialized successfully!")

@app.get("/")
def greet_json():
    return {
        "message": "FunctionGemma API is running!",
        "model": model_name,
        "status": "ready",
    }


@app.get("/health")
def health_check():
    return {"status": "healthy", "model": model_name}

| @app.get("/generate") |
| def generate_text(prompt: str = "Who are you?"): |
| """Generate text using the model""" |
| if pipe is None: |
| initialize_pipeline() |
| |
| messages = [{"role": "user", "content": prompt}] |
| result = pipe(messages, max_new_tokens=1000) |
| return {"response": result[0]["generated_text"]} |
|
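
# Quick smoke test for the endpoint above (a sketch; assumes the server is
# running locally on port 7860, as in the uvicorn command in the docstring):
#   curl "http://localhost:7860/generate?prompt=Who%20are%20you%3F"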


@app.post("/chat")
def chat_completion(messages: list):
    """Chat completion endpoint: accepts a raw list of chat messages."""
    if pipe is None:
        initialize_pipeline()

    result = pipe(messages, max_new_tokens=200)
    return {"response": result[0]["generated_text"]}
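
# Example call for the /chat endpoint (a sketch under the same local-server
# assumption; the request body is the bare messages list FastAPI parses):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '[{"role": "user", "content": "Hello"}]'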


@app.post("/v1/chat/completions")
def openai_chat_completions(request: dict):
    """
    OpenAI-compatible chat completions endpoint.
    Expected request format:
    {
        "model": "unsloth/functiongemma-270m-it",
        "messages": [
            {"role": "user", "content": "Hello"}
        ],
        "max_tokens": 100,
        "temperature": 0.7
    }
    """
    if pipe is None:
        initialize_pipeline()

    messages = request.get("messages", [])
    model = request.get("model", model_name)
    max_tokens = request.get("max_tokens", 1000)
    temperature = request.get("temperature", 0.7)

    print(f"\nrequest: {request}")
    print(f"messages: {messages}")
    print(f"model: {model}")
    print(f"max_tokens: {max_tokens}")
    print(f"temperature: {temperature}")

    # Only enable sampling for a positive temperature; temperature=0 falls
    # back to greedy decoding.
    generate_kwargs = {"max_new_tokens": max_tokens}
    if temperature and temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)

    result = pipe(messages, **generate_kwargs)

    # Normalize the pipeline output into the nested "generations" structure
    # produced by convert_json_format below.
    result = convert_json_format(result)

    completion_id = f"chatcmpl-{int(time.time())}"
    created = int(time.time())

    return_json = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["generations"][0][0]["text"],
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }

    # Token accounting: encode the concatenated prompt contents and the
    # completion text with the model tokenizer.
    if tokenizer:
        prompt_text = " ".join(message.get("content", "") for message in messages)
        return_json["usage"]["prompt_tokens"] = len(tokenizer.encode(prompt_text.strip()))

    if tokenizer and result["generations"]:
        completion_text = result["generations"][0][0]["text"]
        return_json["usage"]["completion_tokens"] = len(tokenizer.encode(completion_text))

    return_json["usage"]["total_tokens"] = (
        return_json["usage"]["prompt_tokens"]
        + return_json["usage"]["completion_tokens"]
    )

    print(f"\nreturn_json: {return_json}\n")

    return return_json
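
# Because the endpoint above mirrors the OpenAI chat-completions schema, the
# official openai Python client can point at this server. A minimal sketch
# (assumes a local server on port 7860; the api_key is a placeholder, since
# this server never checks it):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="unused")
#   reply = client.chat.completions.create(
#       model="unsloth/functiongemma-270m-it",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(reply.choices[0].message.content)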


@app.on_event("startup")
async def startup_event():
    """Initialize the model when the app starts."""
    print("=" * 60)
    print("FunctionGemma FastAPI Server")
    print("=" * 60)
    print("Initializing model...")
    initialize_pipeline()
    print("\n" + "=" * 60)
    print("Server ready at http://0.0.0.0:7860")
    print("Available endpoints:")
    print("  GET  /                    - Welcome message")
    print("  GET  /health              - Health check")
    print("  GET  /generate?prompt=... - Generate text with prompt")
    print("  POST /chat                - Chat completion")
    print("  POST /v1/chat/completions - OpenAI-compatible endpoint")
    print("=" * 60 + "\n")
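
# Note: newer FastAPI releases deprecate @app.on_event in favor of a lifespan
# handler. A minimal sketch of the equivalent, kept as a comment so the
# startup hook above stays the single source of truth:
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       initialize_pipeline()
#       yield
#
#   app = FastAPI(title="FunctionGemma API", version="1.0.0", lifespan=lifespan)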


def convert_json_format(input_data):
    """Convert pipeline output into a nested 'generations' structure."""
    output_generations = []
    for item in input_data:
        generated_text_list = item.get("generated_text", [])

        # The newly generated reply is the last message in the conversation;
        # scanning from the end avoids picking up earlier assistant turns.
        assistant_content = ""
        for message in reversed(generated_text_list):
            if message.get("role") == "assistant":
                assistant_content = message.get("content", "")
                break

        # Strip any <think>...</think> reasoning block from the reply.
        clean_content = re.sub(r"<think>.*?</think>\s*", "", assistant_content, flags=re.DOTALL).strip()

        output_generations.append([
            {
                "text": clean_content,
                "generationInfo": {
                    "finish_reason": "stop"
                },
            }
        ])

    return {"generations": output_generations}