"""Gradio chat demo for the SimpleAI-259M nanochat model.

Loads the tokenizer and model checkpoint at import time, then serves a
streaming chat UI through ``gr.ChatInterface``.
"""

import logging
import os

import torch
import gradio as gr

from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer

logger = logging.getLogger(__name__)

# --- System Initialization ---
TOKENIZER_DIR = "."
CHECKPOINT_PATH = "model_000971.pt"

# Sampling settings used for every reply.
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.75
TOP_K = 40

tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Map Special Tokens
tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
tokenizer.assistant_start_id = tokenizer.enc.encode_single_token("<|assistant_start|>")
tokenizer.assistant_end_id = tokenizer.enc.encode_single_token("<|assistant_end|>")

# Text markers that end the assistant's turn if they appear in decoded output.
# "<|end|>" is only checked as text because it is not resolved to an id above.
STOP_TAGS = ("<|assistant_end|>", "<|end|>", "<|user_start|>")
# Token ids that end the assistant's turn; checked before decoding.
STOP_TOKEN_IDS = frozenset({tokenizer.assistant_end_id, tokenizer.user_start_id})

# Model Setup
config = GPTConfig(
    vocab_size=32768,
    n_layer=12,
    n_head=6,
    n_embd=768,
    sequence_len=2048,
)
model = GPT(config)

print("Loading model weights...")
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict checkpoint should contain.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
# Drop the "_orig_mod." prefix that torch.compile-wrapped modules put on keys.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
load_result = model.load_state_dict(state_dict, strict=False)
# strict=False would otherwise hide a mismatched checkpoint, leaving some
# parameters at their initial values with no indication.
if load_result.missing_keys or load_result.unexpected_keys:
    logger.warning(
        "Checkpoint/model key mismatch: missing=%s unexpected=%s",
        load_result.missing_keys,
        load_result.unexpected_keys,
    )
model.eval()


def _find_stop(text):
    """Return the index of the earliest stop tag in *text*, or None if absent."""
    hits = [pos for pos in (text.find(tag) for tag in STOP_TAGS) if pos != -1]
    return min(hits, default=None)


# The decorator (rather than a ``with`` block) is required because this is a
# generator: torch re-enters no_grad around every resume of the generator, so
# the lazy model.generate(...) forward passes actually run without autograd.
@torch.no_grad()
def predict(message, history):
    """Stream the assistant's reply to *message*.

    Args:
        message: The user's latest message; converted with ``str`` and stripped.
        history: Prior turns supplied by ``gr.ChatInterface``.
            NOTE(review): currently unused, so each message is answered with
            no conversational context — confirm this is intended.

    Yields:
        The reply decoded so far (whitespace-stripped), growing as tokens are
        sampled, followed by the final reply. If anything raises, a single
        "⚠️ System Error: ..." string is yielded instead.
    """
    try:
        user_ids = tokenizer.encode(str(message).strip())
        prompt = [
            tokenizer.bos_token_id,
            tokenizer.user_start_id,
            *user_ids,
            tokenizer.user_end_id,
            tokenizer.assistant_start_id,
        ]

        reply_ids = []
        reply = ""
        stream = model.generate(
            prompt,
            max_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_k=TOP_K,
        )
        for token in stream:
            token_id = token if isinstance(token, int) else token.item()
            if token_id in STOP_TOKEN_IDS:
                break
            reply_ids.append(token_id)
            # Decode the whole reply each step instead of one token at a time:
            # a multi-byte UTF-8 character can be split across several tokens,
            # and decoding the pieces separately yields replacement characters.
            decoded = tokenizer.decode(reply_ids)
            cut = _find_stop(decoded)
            if cut is not None:
                reply = decoded[:cut]
                break
            reply = decoded
            # A trailing U+FFFD is presumably an incomplete character whose
            # remaining bytes arrive with the next token (assumes decode
            # replaces invalid bytes, as tiktoken does); hide it until then.
            yield reply.rstrip("\ufffd").strip()
        # Always emit the final text, even when generation stops immediately.
        yield reply.strip()
    except Exception as e:
        # UI boundary: record the traceback, then show the error in the chat.
        logger.exception("Generation failed")
        yield f"⚠️ System Error: {str(e)}"


# --- UI Customization for Gradio 6.0 ---
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=predict,
        title="⚡ SimpleAI-259M",
        description="**Fast. Focused. Simple.** A lightweight general intelligence model optimized for reasoning and logic.",
        examples=[
            "Explain neural network?",
            "Write a python function to calculate the area of a circle.",
            "Why is the sky blue?",
        ],
    )

if __name__ == "__main__":
    # The theme is passed to launch() per the Gradio 6.0 deprecation warning.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    )