# app.py -- Proto-Cognitive Architecture v5 # Neural Field + Pinned Episodic Memory + Semantic Retrieval + Cognitive Router # + Hebbian Structure Learning + Memory Consolidation + Replay/Dreaming + Competition # # v5 CHANGES (post v4 — structural leap): # # CONTEXT: v4 proved the field is a familiarity detector, not a memory. # The CognitiveRouter uses field resonance to route queries. But W_local # was initialized randomly and NEVER LEARNED — the field had no real # structure. v5 fixes this. # # NEW MECHANISMS: # # 1. HEBBIAN STRUCTURE LEARNING # W_local is now updated via co-activation during think(). # Regions that fire together wire together. This creates learned # association pathways in the field — not random noise. # Includes: sparsification, clamping, diagonal zeroing. # # 2. MEMORY CONSOLIDATION (S → M) # New long-term memory tensor M. Stable, protected regions slowly # consolidate from short-term S into M. M influences field dynamics # via field_step(), pulling state toward consolidated patterns. # This gives the field actual persistence beyond decay. # # 3. REPLAY / DREAMING # Self-training loop that runs without external input. Adds noise, # lets field dynamics + attractors + W_local interact, then reinforces # emergent patterns via Hebbian update. Consolidates after. # Time-bounded for Gradio safety. # # 4. COMPETITION (winner-take-most) # Sharper suppression of weak activations during think(). # Promotes specialization — regions either contribute or get quiet. 
# # UNCHANGED: # - PinnedEpisodicStore (episodic recall mechanism) # - CognitiveRouter (query routing via resonance + retrieval) # - All baselines (RAG, FieldOnly, ContextBaseline) # - CognitiveResponse dataclass # - Token tracking # - Gradio interface compatibility # # BACKWARD COMPATIBLE: # - load_world() handles missing M tensor gracefully # - save_world() includes M, W_local (already saved), hebbian stats # - Benchmark interfaces unchanged import torch import torch.nn as nn import torch.nn.functional as F import os import time from dataclasses import dataclass, field as dc_field from typing import Optional os.environ["TRANSFORMERS_DISABLE_FLASH_ATTN"] = "1" from transformers import AutoModelForCausalLM, AutoTokenizer MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" if DEVICE == "cpu": print("[WARNING] CUDA not available — running on CPU. Install CUDA-enabled PyTorch:") print(" pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 --upgrade") else: n_gpus = torch.cuda.device_count() for _i in range(n_gpus): print(f"[GPU {_i}] {torch.cuda.get_device_name(_i)} " f"({torch.cuda.get_device_properties(_i).total_memory // 1024**3} GB VRAM)") # ========================= # CONFIG # ========================= N_REGIONS = 64 STATE_DIM = 512 LR = 0.01 DECAY = 0.001 PREFIX_LEN = 10 ATTN_TEMP = 0.15 TOPK = 8 FIELD_RATE = 0.05 ATTRACTOR_TH = 0.1 ATTRACTOR_PULL = 0.05 THINK_STEPS = 5 TEACH_THINK_STEPS = 15 MERGE_TH = 0.92 MERGE_EVERY = 10 PREFIX_SCALE = 0.45 INHIBITION_STRENGTH = 0.05 HOMEOSTASIS_TARGET = 0.1 PLASTICITY_SCOPE = 0.01 MAX_EPISODES = 256 SEMANTIC_TOPK = 5 SEMANTIC_THRESHOLD = 0.15 PROTECT_DECAY = 0.9999 # Cognitive Router thresholds (v4) RESONANCE_HIGH = 0.13 RESONANCE_LOW = 0.12 RETRIEVAL_STRONG = 0.40 RETRIEVAL_WEAK = 0.20 # v5: Hebbian structure learning HEBBIAN_LR = 0.002 # learning rate for W_local updates HEBBIAN_PRUNE_TH = 0.0005 # connections below this are zeroed (sparsity) 
W_LOCAL_MAX = 0.5        # absolute clamp for W_local entries
W_LOCAL_DECAY = 0.9995   # slow decay on W_local to prevent saturation

# v5: Memory consolidation
CONSOLIDATION_RATE = 0.01   # how fast S transfers to M (EMA rate per consolidate())
STABILITY_TH = 0.15         # minimum strength for consolidation
MEMORY_INFLUENCE = 0.03     # how much M pulls field dynamics

# v5: Replay / dreaming
REPLAY_STEPS = 5            # steps per replay cycle
REPLAY_NOISE = 0.008        # noise magnitude for exploration
MAX_REPLAY_CYCLES = 3       # max replay() calls per dream session

# v5: Competition
COMPETITION_SUPPRESS = 0.05  # non-winners keep this fraction of activation


# =========================
# RESPONSE METADATA (v4, unchanged)
# =========================

@dataclass
class CognitiveResponse:
    """Structured response from generate() with routing metadata."""
    text: str                            # raw decoded model output
    route: str = "unknown"               # CONFIDENT / CAUTIOUS / UNCERTAIN / DEFER
    resonance: float = 0.0               # field familiarity score (0.0-1.0)
    retrieval_confidence: float = 0.0    # episodic-store confidence for the query
    retrieved_facts: list = dc_field(default_factory=list)    # list of (text, similarity)
    field_diagnostics: dict = dc_field(default_factory=dict)  # WorldModel.diagnostics() snapshot
    latency_s: float = 0.0               # wall-clock seconds for the full generate pass

    @property
    def is_confident(self) -> bool:
        # True only on the highest-trust route (high resonance AND strong retrieval).
        return self.route == "CONFIDENT"

    @property
    def should_verify(self) -> bool:
        # Routes where the caller should double-check the answer externally.
        return self.route in ("UNCERTAIN", "DEFER")


# =========================
# 1. PINNED EPISODIC STORE (unchanged from v4)
# =========================

class PinnedEpisodicStore:
    """
    Non-decaying semantic memory with contradiction/supersession handling.
    Stores (text, embedding, importance, superseded) tuples.

    Embeddings are unit-normalized on entry; similarity is cosine via dot
    product. A newly pinned fact that is highly similar to an existing one
    "supersedes" it (the old entry is down-weighted, not deleted).
    """

    # Cosine similarity at/above which a new pin supersedes an existing entry.
    SUPERSEDE_TH = 0.82
    # Multiplier applied to the importance/score of superseded entries.
    SUPERSEDE_W = 0.05

    def __init__(self, max_episodes: int = MAX_EPISODES):
        self.max_episodes = max_episodes
        self.texts: list[str] = []
        self.embeds: list[torch.Tensor] = []   # unit-normalized, parallel to texts
        self.importance: list[float] = []
        self.superseded: list[bool] = []

    def __len__(self):
        return len(self.texts)

    def pin(self, text: str, embed: torch.Tensor):
        """Store a fact, reinforcing an exact duplicate or superseding near-duplicates.

        Exact text match: bump importance, un-supersede, and EMA-blend the
        embedding (0.9 old / 0.1 new, re-normalized). Otherwise, any stored
        entry whose cosine similarity >= SUPERSEDE_TH is marked superseded,
        then the new fact is appended (evicting a superseded or least-important
        entry if at capacity).
        """
        embed_norm = F.normalize(embed.float().unsqueeze(0), dim=1).squeeze(0)
        # Case 1: exact text already stored -> reinforce in place.
        for i, t in enumerate(self.texts):
            if t == text:
                self.importance[i] += 1.0
                self.superseded[i] = False
                self.embeds[i] = F.normalize(
                    (0.9 * self.embeds[i] + 0.1 * embed_norm).unsqueeze(0), dim=1
                ).squeeze(0)
                return
        # Case 2: mark semantically conflicting entries as superseded.
        if self.embeds:
            stack = torch.stack(self.embeds, dim=0)
            sims = torch.mv(stack, embed_norm)
            for i, sim in enumerate(sims.tolist()):
                if sim >= self.SUPERSEDE_TH and not self.superseded[i]:
                    self.superseded[i] = True
                    self.importance[i] = self.importance[i] * self.SUPERSEDE_W
        # At capacity: evict the first superseded entry, else the least important.
        if len(self.texts) >= self.max_episodes:
            sup_indices = [i for i, s in enumerate(self.superseded) if s]
            if sup_indices:
                evict = sup_indices[0]
            else:
                evict = self.importance.index(min(self.importance))
            self.texts.pop(evict)
            self.embeds.pop(evict)
            self.importance.pop(evict)
            self.superseded.pop(evict)
        # New facts start at importance 2.0 (above the +1.0 reinforcement step).
        self.texts.append(text)
        self.embeds.append(embed_norm)
        self.importance.append(2.0)
        self.superseded.append(False)

    def retrieve(self, query_embed: torch.Tensor,
                 topk: int = SEMANTIC_TOPK,
                 threshold: float = SEMANTIC_THRESHOLD) -> list[tuple[str, float]]:
        """Return up to `topk` (text, cosine_sim) pairs for the query.

        Ranking score = cosine_sim * log1p(importance), down-weighted for
        superseded entries; the returned similarity is the raw cosine.
        If nothing clears `threshold`, falls back to the single best-scoring
        entry when its similarity is at least threshold/2.
        """
        if not self.texts:
            return []
        q = F.normalize(query_embed.float().unsqueeze(0), dim=1).squeeze(0)
        stack = torch.stack(self.embeds, dim=0)
        sims = torch.mv(stack, q)
        imp = torch.tensor(self.importance, dtype=torch.float32, device=sims.device)
        scores = sims * torch.log1p(imp)
        for i, sup in enumerate(self.superseded):
            if sup:
                scores[i] = scores[i] * self.SUPERSEDE_W
        mask = sims >= threshold
        if not mask.any():
            # Fallback: best single entry, accepted at a relaxed threshold.
            best = scores.argmax().item()
            if sims[best].item() >= threshold * 0.5:
                return [(self.texts[best], sims[best].item())]
            return []
        # NOTE(review): masked-out entries get score 0 here, which can outrank
        # negative-scored masked-in entries in topk; the sims>=threshold filter
        # below keeps the result correct but may return fewer than k hits.
        masked_scores = scores * mask.float()
        k = min(topk, int(mask.sum().item()))
        topk_vals, topk_idx = torch.topk(masked_scores, k)
        results, seen = [], set()
        for idx, score in zip(topk_idx.tolist(), topk_vals.tolist()):
            t = self.texts[idx]
            if t not in seen and sims[idx].item() >= threshold:
                results.append((t, sims[idx].item()))
                seen.add(t)
        return results

    def retrieval_confidence(self, query_embed: torch.Tensor) -> float:
        """Best importance-weighted similarity of any stored fact to the query."""
        if not self.texts:
            return 0.0
        q = F.normalize(query_embed.float().unsqueeze(0), dim=1).squeeze(0)
        stack = torch.stack(self.embeds, dim=0)
        sims = torch.mv(stack, q)
        imp = torch.tensor(self.importance, dtype=torch.float32, device=sims.device)
        weighted = sims * torch.log1p(imp)
        for i, sup in enumerate(self.superseded):
            if sup:
                weighted[i] *= self.SUPERSEDE_W
        return float(weighted.max().item()) if weighted.numel() > 0 else 0.0

    def all_texts(self) -> list[str]:
        """Copy of every stored text (including superseded)."""
        return list(self.texts)

    def clear(self):
        """Drop all stored episodes."""
        self.texts.clear()
        self.embeds.clear()
        self.importance.clear()
        self.superseded.clear()

    def diagnostics(self) -> dict:
        """Counts of total / active / superseded entries for UI display."""
        return {
            "total": len(self.texts),
            "active": sum(1 for s in self.superseded if not s),
            "superseded": sum(self.superseded),
        }

    def state_dict(self) -> dict:
        """Serializable snapshot; embeddings are moved to CPU."""
        return {
            "texts": self.texts,
            "embeds": [e.cpu() for e in self.embeds],
            "importance": self.importance,
            "superseded": self.superseded,
        }

    def load_state_dict(self, d: dict):
        """Restore from state_dict(); tolerates pre-supersession saves.

        Embeddings are left on whatever device they were saved on — the
        caller (HybridLLM.load_world) moves them to the right device.
        """
        self.texts = d["texts"]
        self.embeds = [e for e in d["embeds"]]
        self.importance = d["importance"]
        self.superseded = d.get("superseded", [False] * len(d["texts"]))


# =========================
# 2.
# NEURAL FIELD WORLD MODEL (v5 — enhanced)
# =========================

class WorldModel(nn.Module):
    """Neural field over N_REGIONS regions, each a STATE_DIM vector.

    Holds short-term state S, learned inter-region wiring W_local (Hebbian,
    not gradient-trained), and long-term consolidated memory M. nn.Module is
    used for input_proj / to_prefix; S, M, strength, etc. are plain tensors.
    """

    def __init__(self, n_regions, state_dim, embed_dim, device):
        super().__init__()
        self.n = n_regions
        self.d = state_dim
        self.embed_dim = embed_dim
        self.device = device
        # Projects LLM token embeddings into field space.
        self.input_proj = nn.Linear(embed_dim, state_dim, bias=False)
        # Maps a focus vector back to PREFIX_LEN soft-prompt embeddings.
        self.to_prefix = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, PREFIX_LEN * embed_dim),
        )
        # Inter-region coupling; random init, learned via Hebbian updates.
        self.W_local = nn.Parameter(
            torch.randn(n_regions, n_regions, device=device) * 0.01
        )
        self.to(device)
        # Short-term field state
        self.S = torch.zeros(n_regions, state_dim, device=device)
        self.strength = torch.zeros(n_regions, device=device)   # per-region attractor strength
        self.memories = [""] * n_regions                        # text tag per region
        self.thresholds = torch.ones(n_regions, device=device) * 0.5  # homeostatic attention thresholds
        self.protected = torch.zeros(n_regions, dtype=torch.bool, device=device)  # taught regions
        self.step_count = 0
        # v5: Long-term consolidated memory
        self.M = torch.zeros(n_regions, state_dim, device=device)
        # v5: Hebbian learning stats
        self.hebbian_updates = 0
        self.consolidation_count = 0
        self.replay_count = 0

    # ----- v5: Hebbian structure learning -----
    def hebbian_structure_update(self, attn: torch.Tensor):
        """
        Update W_local based on co-activation patterns.
        Regions that fire together wire together.

        Stability controls:
        - Normalize outer product to prevent explosion
        - Sparsify: prune weak connections
        - Clamp: absolute limit on connection strength
        - Decay: slow decay prevents saturation
        - Zero diagonal: no self-connections

        W_local is mutated through `.data` under no_grad — this is a
        deliberate non-gradient update, not an autograd step.
        """
        with torch.no_grad():
            # Co-activation: outer product of attention vector
            outer = torch.outer(attn, attn)
            # Normalize to unit energy per update
            outer_norm = outer.norm()
            if outer_norm > 1e-8:
                outer = outer / outer_norm
            # Apply learning rate and update
            self.W_local.data += HEBBIAN_LR * outer
            # Slow decay to prevent saturation over many updates
            self.W_local.data *= W_LOCAL_DECAY
            # Sparsify: zero out weak connections
            self.W_local.data = torch.where(
                self.W_local.data.abs() > HEBBIAN_PRUNE_TH,
                self.W_local.data,
                torch.zeros_like(self.W_local.data)
            )
            # Clamp to prevent runaway
            self.W_local.data.clamp_(-W_LOCAL_MAX, W_LOCAL_MAX)
            # Zero diagonal — no self-connections
            self.W_local.data.fill_diagonal_(0.0)
            self.hebbian_updates += 1

    # ----- v5: Memory consolidation -----
    def consolidate(self):
        """
        Transfer stable short-term patterns (S) into long-term memory (M).

        Only regions with sufficient strength AND protection status
        consolidate. This means only taught content enters long-term
        memory — noise and queries don't. M is normalized to prevent
        unbounded growth.
        """
        with torch.no_grad():
            # Stability signal: strength weighted by protection
            stability = self.strength.clone()
            stability[self.protected] *= 2.0  # protected regions consolidate faster
            # Only consolidate stable regions
            mask = (stability > STABILITY_TH).float().unsqueeze(1)
            # Exponential moving average: M slowly tracks S for stable regions
            self.M = self.M + CONSOLIDATION_RATE * mask * (self.S - self.M)
            # Normalize M to prevent unbounded growth (tanh-soft row norms)
            norm = self.M.norm(dim=1, keepdim=True).clamp(min=1e-8)
            self.M = self.M / norm * torch.tanh(norm)
            self.consolidation_count += 1

    # ----- v5: Replay / dreaming -----
    def replay(self, steps: int = REPLAY_STEPS) -> dict:
        """
        Self-training loop without external input.

        Process:
        1. Add small noise for exploration
        2. Run field dynamics (W_local interaction + M influence)
        3. Apply attractors
        4. Hebbian update on emergent activation pattern
        5. Consolidate after replay

        Note: the injected noise is NOT undone afterwards — replay
        intentionally perturbs S and lets dynamics settle it.

        Returns diagnostics dict for UI display.
        """
        with torch.no_grad():
            initial_norm = self.S.norm().item()
            initial_w_density = (self.W_local.data.abs() > HEBBIAN_PRUNE_TH).float().mean().item()
            for _ in range(steps):
                # Exploration noise
                noise = torch.randn_like(self.S) * REPLAY_NOISE
                self.S = self.S + noise
                # Let field dynamics run
                self.field_step()
                # Competition sharpens activation
                self.competition()
                # Attractors pull toward stored patterns
                self.apply_attractors()
                # Hebbian update on emergent patterns (softmax over region norms)
                norms = self.S.norm(dim=1)
                if norms.max() > 1e-8:
                    attn = F.softmax(norms / ATTN_TEMP, dim=0)
                    self.hebbian_structure_update(attn)
            # Consolidate after dreaming
            self.consolidate()
            self.replay_count += 1
            final_norm = self.S.norm().item()
            final_w_density = (self.W_local.data.abs() > HEBBIAN_PRUNE_TH).float().mean().item()
            return {
                "steps": steps,
                "norm_delta": round(final_norm - initial_norm, 4),
                "w_density_delta": round(final_w_density - initial_w_density, 6),
                "total_replays": self.replay_count,
            }

    # ----- v5: Competition -----
    def competition(self):
        """
        Winner-take-most: top-K regions keep full activation, others are
        suppressed to a fraction. Promotes specialization.
        """
        with torch.no_grad():
            scores = self.S.norm(dim=1)
            if scores.max() < 1e-8:
                return  # nothing to compete
            k = min(TOPK, self.n)
            topk = torch.topk(scores, k)
            # Build suppression mask: winners keep 1.0, losers COMPETITION_SUPPRESS.
            mask = torch.full((self.n,), COMPETITION_SUPPRESS, device=self.device)
            mask[topk.indices] = 1.0
            self.S = self.S * mask.unsqueeze(1)

    def field_step(self):
        """v5: Now includes long-term memory influence.

        One step of field dynamics: W_local coupling minus global inhibition,
        plus a pull toward consolidated memory M, squashed by tanh.
        """
        interaction = torch.matmul(self.W_local, self.S)
        # Global inhibition: mean activity scaled by INHIBITION_STRENGTH.
        gi = torch.mean(self.S, dim=0, keepdim=True) * INHIBITION_STRENGTH
        # v5: Long-term memory pulls field toward consolidated patterns
        memory_pull = MEMORY_INFLUENCE * self.M
        self.S = torch.tanh(self.S + FIELD_RATE * (interaction - gi) + memory_pull)

    def apply_attractors(self):
        """Reinforce regions whose strength exceeds ATTRACTOR_TH."""
        mask = self.strength > ATTRACTOR_TH
        if mask.any():
            self.S[mask] = self.S[mask] + ATTRACTOR_PULL * torch.tanh(self.S[mask])

    def attend(self, x_embed):
        """Compute sparse top-K attention of the field over a projected input.

        Returns (attn, x_proj): attn is zero except at the top-K regions
        (softmax over their scores); thresholds are homeostatically nudged
        toward HOMEOSTASIS_TARGET activity as a side effect.
        """
        x_embed = x_embed.to(self.S.dtype)
        x_proj = self.input_proj(x_embed)
        scores = torch.matmul(self.S, x_proj) - self.thresholds
        scores = scores + 0.2 * self.strength            # bias toward strong regions
        scores = scores + 0.3 * self.protected.float()   # bias toward taught regions
        topk = torch.topk(scores, min(TOPK, self.n))
        attn = torch.zeros_like(scores)
        attn[topk.indices] = F.softmax(topk.values / ATTN_TEMP, dim=0)
        # Homeostasis: raise thresholds of busy regions, lower idle ones.
        self.thresholds += PLASTICITY_SCOPE * (attn - HOMEOSTASIS_TARGET)
        self.thresholds = torch.clamp(self.thresholds, 0.1, 2.0)
        return attn, x_proj

    def update(self, attn, x_proj, protect: bool = False):
        """Write the attended input into S and apply strength/decay bookkeeping.

        protect=True permanently flags the top-K attended regions, which
        slows their decay and gates consolidation.
        """
        if protect:
            top_idx = attn.topk(min(TOPK, self.n)).indices
            self.protected[top_idx] = True
        # Hebbian-style write: attention-weighted outer product.
        delta = LR * torch.outer(attn, x_proj)
        self.S = self.S + delta
        self.strength = (self.strength + attn * 0.05) * 0.99
        # Weak regions (below ATTRACTOR_TH) decay up to 6x faster.
        decay_mask = torch.sigmoid(10 * (ATTRACTOR_TH - self.strength))
        variable_decay = DECAY * (1.0 + decay_mask * 5.0)
        protected_decay = 1.0 - DECAY * 0.01   # near-zero decay for protected regions
        base_decay = 1.0 - variable_decay
        final_decay = torch.where(self.protected,
                                  torch.full_like(base_decay, protected_decay),
                                  base_decay)
        self.S = self.S * final_decay.unsqueeze(1)
        # Soft re-normalization: row norms squashed through tanh.
        norm = self.S.norm(dim=1, keepdim=True).clamp(min=1e-8)
        self.S = self.S / norm * torch.tanh(norm)
        self.step_count += 1

    def think(self, x_embed, steps: int = THINK_STEPS, protect: bool = False):
        """
        v5: Enhanced think loop with competition + Hebbian learning.

        Per step:
        1. field_step() — W_local interaction + M influence
        2. competition() — suppress weak activations
        3. attend() — compute attention to input
        4. update() — update state + decay
        5. hebbian_structure_update() — learn W_local from co-activation
        6. apply_attractors()

        After all steps (if protect=True): consolidate S → M.
        Returns the (attn, x_proj) pair from the final step.
        """
        for _ in range(steps):
            self.field_step()
            self.competition()
            attn, x_proj = self.attend(x_embed)
            self.update(attn, x_proj, protect=protect)
            self.hebbian_structure_update(attn)
            self.apply_attractors()
        # After teach: consolidate stable patterns into long-term memory
        if protect:
            self.consolidate()
        return attn, x_proj

    def store_memory(self, text: str, attn: torch.Tensor):
        """Tag an attended region with `text`.

        Prefers an empty (or same-text) region among the top-K attended;
        otherwise overwrites the weakest of them.
        """
        top_indices = attn.topk(min(TOPK, self.n)).indices
        for idx in top_indices:
            i = idx.item()
            if not self.memories[i] or self.memories[i] == text:
                self.memories[i] = text
                return
        strengths = self.strength[top_indices]
        weakest = top_indices[strengths.argmin()].item()
        self.memories[weakest] = text

    def merge_similar(self):
        """Merge pairs of unprotected regions with cosine similarity > MERGE_TH.

        The survivor (lower index) averages the two states and keeps the
        stronger region's memory text; the other region is zeroed.
        Returns the number of regions merged away.
        """
        with torch.no_grad():
            norms = self.S.norm(dim=1, keepdim=True).clamp(min=1e-8)
            S_hat = self.S / norms
            sim = torch.matmul(S_hat, S_hat.T)
            merged = set()
            for i in range(self.n):
                if i in merged:
                    continue
                for j in range(i + 1, self.n):
                    if j in merged:
                        continue
                    if sim[i, j] > MERGE_TH:
                        # Never merge taught content.
                        if self.protected[i] or self.protected[j]:
                            continue
                        self.S[i] = (self.S[i] + self.S[j]) / 2.0
                        if self.strength[j] > self.strength[i]:
                            self.memories[i] = self.memories[j]
                        self.strength[i] = max(
                            self.strength[i].item(), self.strength[j].item())
                        self.S[j] = 0.0
                        self.strength[j] = 0.0
                        self.memories[j] = ""
                        merged.add(j)
            return len(merged) if merged else 0

    def get_focus(self, attn):
        """Attention-weighted sum of region states (a single STATE_DIM vector)."""
        return torch.sum(attn.unsqueeze(1) * self.S, dim=0)

    def get_prefix(self, S_focus, embed_dim, ref_embeds=None):
        """Map a focus vector to PREFIX_LEN soft-prompt embeddings.

        If ref_embeds (real token embeddings) are given, the prefix is
        whitened and re-scaled to match their mean/std so it lives in the
        same distribution as genuine embeddings.
        """
        S_focus = torch.tanh(S_focus)
        prefix = self.to_prefix(S_focus).view(PREFIX_LEN, embed_dim)
        if ref_embeds is not None and ref_embeds.numel() > 0:
            p_std = prefix.std().clamp(min=1e-8)
            r_mean = ref_embeds.mean()
            r_std = ref_embeds.std().clamp(min=1e-8)
            prefix = (prefix - prefix.mean()) / p_std * r_std + r_mean
        return prefix * PREFIX_SCALE

    def recall(self, max_items: int = 3) -> list[str]:
        """Soft memory recall from strongest attractor regions."""
        # Scan 2x max_items candidates since some regions have no/duplicate text.
        topk = torch.topk(self.strength, min(max_items * 2, self.n))
        recalled, seen = [], set()
        for idx, s in zip(topk.indices, topk.values):
            i = idx.item()
            if self.memories[i] and s.item() > 0.01 and self.memories[i] not in seen:
                recalled.append(self.memories[i])
                seen.add(self.memories[i])
                if len(recalled) >= max_items:
                    break
        return recalled

    def get_cognitive_resonance(self, x_embed: torch.Tensor) -> float:
        """Measures field familiarity with the input. Returns 0.0–1.0."""
        x_embed = x_embed.to(self.S.dtype)
        x_proj = self.input_proj(x_embed)
        raw_scores = torch.matmul(self.S, x_proj)
        # Strong and protected regions count more toward familiarity.
        strength_weight = torch.clamp(self.strength, min=0.0)
        weighted = raw_scores * (1.0 + strength_weight * 5.0)
        protected_bonus = self.protected.float() * 2.0
        weighted = weighted + protected_bonus * raw_scores.clamp(min=0.0)
        k = min(TOPK, self.n)
        topk_vals, _ = torch.topk(weighted, k)
        energy = topk_vals.sum()
        # Normalize by a rough upper bound on attainable energy, then squash.
        max_energy = k * (1.0 + self.strength.max().item() * 5.0 + 2.0)
        max_energy = max(max_energy, 1e-8)
        resonance = float(torch.sigmoid(energy / max_energy * 4.0 - 2.0).item())
        return round(resonance, 4)

    def diagnostics(self):
        """Snapshot of field / W_local / M / replay statistics for display."""
        with torch.no_grad():
            w_abs = self.W_local.data.abs()
            w_nonzero = (w_abs > HEBBIAN_PRUNE_TH).float()
            n_possible = self.n * (self.n - 1)  # exclude diagonal
            return {
                "active_regions": (self.strength > 0.01).sum().item(),
                "attractors": (self.strength > ATTRACTOR_TH).sum().item(),
                "protected": self.protected.sum().item(),
                "field_norm": self.S.norm().item(),
                "strength_max": self.strength.max().item(),
                "strongest_idx": self.strength.argmax().item(),
                "total_steps": self.step_count,
                "n_memories": sum(1 for m in self.memories if m),
                # v5: Hebbian / consolidation / replay stats
                "w_local_density": round(w_nonzero.sum().item() / max(n_possible, 1), 4),
                "w_local_max": round(w_abs.max().item(), 4),
                "w_local_mean": round(w_abs[w_abs > HEBBIAN_PRUNE_TH].mean().item(), 4)
                    if (w_abs > HEBBIAN_PRUNE_TH).any() else 0.0,
                "w_local_connections": int(w_nonzero.sum().item()),
                "m_norm": round(self.M.norm().item(), 4),
                "m_active_regions": int((self.M.norm(dim=1) > 0.01).sum().item()),
                "hebbian_updates": self.hebbian_updates,
                "consolidations": self.consolidation_count,
                "replays": self.replay_count,
            }


# =========================
# 3. COGNITIVE ROUTER (v4, unchanged)
# =========================

class CognitiveRouter:
    """
    Routes queries based on field resonance + retrieval confidence.

    Route matrix:
                          Retrieval HIGH   Retrieval LOW
        Resonance HIGH -> CONFIDENT        CAUTIOUS
        Resonance LOW  -> UNCERTAIN        DEFER
    """

    def __init__(self,
                 resonance_high: float = RESONANCE_HIGH,
                 resonance_low: float = RESONANCE_LOW,
                 retrieval_strong: float = RETRIEVAL_STRONG,
                 retrieval_weak: float = RETRIEVAL_WEAK):
        self.resonance_high = resonance_high
        # NOTE(review): resonance_low is stored but never consulted by route();
        # routing uses only the resonance_high threshold.
        self.resonance_low = resonance_low
        self.retrieval_strong = retrieval_strong
        self.retrieval_weak = retrieval_weak
        self.history: list[dict] = []   # one record per routed query

    def route(self, resonance: float, retrieval_conf: float) -> str:
        """Map (resonance, retrieval confidence) to a route label and log it.

        UNCERTAIN requires retrieval to be at least "medium"
        (>= retrieval_weak); low resonance with weak retrieval is DEFER.
        """
        high_res = resonance >= self.resonance_high
        high_ret = retrieval_conf >= self.retrieval_strong
        low_ret = retrieval_conf < self.retrieval_weak
        if high_res and high_ret:
            route = "CONFIDENT"
        elif high_res and not high_ret:
            route = "CAUTIOUS"
        elif not high_res and not low_ret:
            route = "UNCERTAIN"
        else:
            route = "DEFER"
        self.history.append({
            "resonance": resonance,
            "retrieval": retrieval_conf,
            "route": route,
        })
        return route

    def build_system_prompt(self, route: str, semantic_hits: list,
                            attractor_memories: list) -> str:
        """Compose the system prompt for the LLM given the route.

        semantic_hits are (text, score) pairs from the episodic store;
        attractor_memories are plain strings from field recall. Returns ""
        when there is nothing to say (empty messages list upstream).
        """
        parts = []
        if route == "CONFIDENT":
            if semantic_hits:
                # dict.fromkeys de-duplicates while preserving order.
                facts = list(dict.fromkeys(t for t, _ in semantic_hits))
                parts.append("Known facts (high confidence): " + " | ".join(facts))
            if attractor_memories:
                # Only add field recalls not already covered by retrieval.
                soft = [m for m in attractor_memories
                        if m not in (t for t, _ in semantic_hits)]
                if soft:
                    parts.append("Additional context: " + " ".join(soft[:2]))
        elif route == "CAUTIOUS":
            if semantic_hits:
                facts = list(dict.fromkeys(t for t, _ in semantic_hits))
                parts.append("Possibly relevant facts (verify before stating): "
                             + " | ".join(facts))
            parts.append(
                "Note: Your confidence in this domain is moderate. "
                "If you are not sure, say so explicitly."
            )
        elif route == "UNCERTAIN":
            if semantic_hits:
                facts = list(dict.fromkeys(t for t, _ in semantic_hits))
                parts.append("Retrieved facts (low domain familiarity — use cautiously): "
                             + " | ".join(facts))
            parts.append(
                "Note: This topic may be outside your trained knowledge. "
                "Verify any specific claims before presenting them as fact."
            )
        else:  # DEFER
            parts.append(
                "You have no stored knowledge about this specific topic. "
                "Answer only from general knowledge and clearly state that "
                "you have no specific information stored about this."
            )
        return "\n".join(parts) if parts else ""

    def get_routing_stats(self) -> dict:
        """Aggregate counts and average scores over all routed queries."""
        if not self.history:
            return {"total": 0}
        from collections import Counter
        routes = Counter(h["route"] for h in self.history)
        return {
            "total": len(self.history),
            "confident": routes.get("CONFIDENT", 0),
            "cautious": routes.get("CAUTIOUS", 0),
            "uncertain": routes.get("UNCERTAIN", 0),
            "defer": routes.get("DEFER", 0),
            "avg_resonance": round(
                sum(h["resonance"] for h in self.history) / len(self.history), 4),
            "avg_retrieval": round(
                sum(h["retrieval"] for h in self.history) / len(self.history), 4),
        }


# =========================
# 4. SHARED UTILITIES
# =========================

@torch.no_grad()
def mean_pool_text(tokenizer, model, text: str, device) -> torch.Tensor:
    """Mean of the model's input embeddings over the tokenized text (<=256 tokens)."""
    ids = tokenizer(text, return_tensors="pt", truncation=True,
                    max_length=256)["input_ids"].to(device)
    embeds = model.get_input_embeddings()(ids).squeeze(0)
    return embeds.mean(dim=0)


def extract_answer(raw_text: str) -> str:
    """Return the text after the last 'assistant\\n' marker, or the whole text."""
    marker = "assistant\n"
    idx = raw_text.rfind(marker)
    return raw_text[idx + len(marker):].strip() if idx != -1 else raw_text.strip()


# =========================
# 5.
# HybridLLM — MAIN AGENT (v5)
# =========================

class HybridLLM:
    """Full agent: LLM + WorldModel field + episodic store + cognitive router."""

    def __init__(self):
        print(f"Loading model on {'CUDA' if torch.cuda.is_available() else 'CPU'}...")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype="auto",
                device_map="auto",
                attn_implementation="eager",
            )
        except ValueError as ve:
            print("[ERROR] Failed to use `device_map` when loading model:", ve)
            print("Attempting CPU-only fallback load (this may be slow).")
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    MODEL_NAME,
                    torch_dtype="auto",
                    device_map=None,
                    low_cpu_mem_usage=True,
                )
                # NOTE(review): eval/embed_dim/device are set again below after
                # the try/except, so the fallback path initializes them twice
                # (harmless duplication).
                self.model.eval()
                self.embed_dim = self.model.get_input_embeddings().weight.shape[1]
                self.device = next(self.model.parameters()).device
                print(f" Fallback model loaded -> device={self.device} dtype={next(self.model.parameters()).dtype}")
            except Exception as e2:
                print("[ERROR] CPU fallback also failed:", e2)
                raise
        self.model.eval()
        self.embed_dim = self.model.get_input_embeddings().weight.shape[1]
        self.device = next(self.model.parameters()).device
        print(f" Model loaded -> device={self.device} dtype={next(self.model.parameters()).dtype}")
        self.world = WorldModel(N_REGIONS, STATE_DIM, self.embed_dim, self.device)
        self.episodes = PinnedEpisodicStore(max_episodes=MAX_EPISODES)
        self.router = CognitiveRouter()
        self.call_count = 0   # _process_input invocations; drives merge cadence
        self.token_stats = {
            "teach_input_tokens": 0,
            "generate_input_tokens": 0,
            "generate_output_tokens": 0,
            "generate_calls": 0,
            "teach_calls": 0,
        }

    def embed_token(self, token_id: int) -> torch.Tensor:
        """Embedding vector for a single token id."""
        t = torch.tensor([token_id], device=self.device)
        return self.model.get_input_embeddings()(t).squeeze(0)

    @torch.no_grad()
    def mean_pool_text(self, text: str) -> torch.Tensor:
        """Mean-pooled embedding of `text` (delegates to the module-level helper)."""
        return mean_pool_text(self.tokenizer, self.model, text, self.device)

    def _process_input(self, text: str, protect: bool = False,
                       think_steps: int = THINK_STEPS):
        """Stream the text's tokens through the field, then run the think loop.

        Each token is attended/written individually; think() then runs on the
        LAST token's embedding. Returns (attn, S_focus, x_embed) where x_embed
        is that last-token embedding.

        NOTE(review): if `text` tokenizes to zero tokens, x_embed stays None
        and world.think(None, ...) will fail — presumably inputs are always
        non-empty; confirm at the callers.
        """
        raw_ids = self.tokenizer(text, return_tensors="pt", truncation=True,
                                 max_length=512)["input_ids"].to(self.device)[0]
        x_embed = None
        for token_id in raw_ids:
            x_embed = self.embed_token(int(token_id))
            attn, x_proj = self.world.attend(x_embed)
            self.world.update(attn, x_proj, protect=protect)
        attn, x_proj = self.world.think(x_embed, steps=think_steps, protect=protect)
        self.world.store_memory(text, attn)
        S_focus = self.world.get_focus(attn)
        self.call_count += 1
        # Periodically merge near-duplicate regions.
        if self.call_count % MERGE_EVERY == 0:
            self.world.merge_similar()
        return attn, S_focus, x_embed

    def teach(self, text: str, verbose: bool = False, auto_dream: bool = True):
        """Teach a fact: protected field write + episodic pin (+ optional dream).

        Facts superseded by this pin are also scrubbed from the field's
        per-region memory tags so stale text cannot be recalled.
        Returns the replay diagnostics dict, or None if auto_dream is False.
        """
        n_tok = len(self.tokenizer(text)["input_ids"])
        self.token_stats["teach_input_tokens"] += n_tok
        self.token_stats["teach_calls"] += 1
        # Snapshot superseded texts before/after the pin to find new ones.
        before_sup = set(self.episodes.texts[i]
                         for i, s in enumerate(self.episodes.superseded) if s)
        attn, S_focus, _ = self._process_input(
            text, protect=True, think_steps=TEACH_THINK_STEPS)
        mean_embed = self.mean_pool_text(text)
        self.episodes.pin(text, mean_embed)
        after_sup = set(self.episodes.texts[i]
                        for i, s in enumerate(self.episodes.superseded) if s)
        newly_superseded = after_sup - before_sup
        if newly_superseded:
            # Clear superseded texts from the field's region tags too.
            for i, mem in enumerate(self.world.memories):
                if mem in newly_superseded:
                    self.world.memories[i] = ""
        # v5: Auto-dream after teaching to reinforce structure
        dream_result = None
        if auto_dream:
            dream_result = self.world.replay(steps=REPLAY_STEPS)
        if verbose:
            d = self.world.diagnostics()
            ed = self.episodes.diagnostics()
            print(f" [Teach] norm={d['field_norm']:.2f} "
                  f"protected={d['protected']}/{N_REGIONS} "
                  f"episodes={ed['total']} (active={ed['active']} "
                  f"superseded={ed['superseded']})"
                  + (f" cleared={newly_superseded}" if newly_superseded else ""))
            if dream_result:
                print(f" [Dream] steps={dream_result['steps']} "
                      f"norm_delta={dream_result['norm_delta']:.4f} "
                      f"w_density_delta={dream_result['w_density_delta']:.6f}")
        return dream_result

    def dream(self, cycles: int = 1, steps_per_cycle: int = REPLAY_STEPS,
              verbose: bool = False) -> list[dict]:
        """
        v5: Explicit dream session. Can be called from Gradio UI.
        Cycles are capped at MAX_REPLAY_CYCLES for UI responsiveness.
        Returns list of per-cycle diagnostics.
        """
        cycles = min(cycles, MAX_REPLAY_CYCLES)
        results = []
        for i in range(cycles):
            result = self.world.replay(steps=steps_per_cycle)
            results.append(result)
            if verbose:
                print(f" [Dream cycle {i+1}/{cycles}] "
                      f"norm_delta={result['norm_delta']:.4f} "
                      f"w_density_delta={result['w_density_delta']:.6f}")
        return results

    def generate(self, text: str, max_new_tokens: int = 80,
                 verbose: bool = False) -> str:
        """Back-compat wrapper: plain-text result of generate_cognitive()."""
        resp = self.generate_cognitive(text, max_new_tokens=max_new_tokens,
                                       verbose=verbose)
        return resp.text

    def generate_cognitive(self, text: str, max_new_tokens: int = 80,
                           verbose: bool = False) -> CognitiveResponse:
        """Route the query, build a route-specific prompt, and generate.

        CONFIDENT/CAUTIOUS routes additionally prepend the field's soft
        prefix (PREFIX_LEN pseudo-token embeddings) to the prompt embeddings.
        Sampling (do_sample=True) makes output non-deterministic.
        """
        t0 = time.perf_counter()
        attn, S_focus, x_embed = self._process_input(text, protect=False)
        query_embed = self.mean_pool_text(text)
        resonance = self.world.get_cognitive_resonance(x_embed)
        retrieval_conf = self.episodes.retrieval_confidence(query_embed)
        route = self.router.route(resonance, retrieval_conf)
        semantic_hits = self.episodes.retrieve(
            query_embed, topk=SEMANTIC_TOPK, threshold=SEMANTIC_THRESHOLD)
        # Never surface superseded facts via field recall.
        superseded_texts = {
            self.episodes.texts[i]
            for i, sup in enumerate(self.episodes.superseded) if sup
        }
        attractor_memories = [m for m in self.world.recall()
                              if m not in superseded_texts]
        system_content = self.router.build_system_prompt(
            route, semantic_hits, attractor_memories)
        messages = []
        if system_content:
            messages.append({"role": "system", "content": system_content})
        messages.append({"role": "user", "content": text})
        if verbose:
            d = self.world.diagnostics()
            print(f" [Route] {route} resonance={resonance:.4f} "
                  f"retrieval={retrieval_conf:.4f}")
            print(f" [Field] norm={d['field_norm']:.2f} "
                  f"active={d['active_regions']}/{N_REGIONS} "
                  f"attractors={d['attractors']} "
                  f"protected={d['protected']}")
            print(f" [v5] W_local density={d['w_local_density']:.4f} "
                  f"M_norm={d['m_norm']:.4f} "
                  f"hebbian_updates={d['hebbian_updates']}")
            print(f" [Episodes] total={len(self.episodes)}")
            if semantic_hits:
                for t, s in semantic_hits:
                    print(f" hit(score={s:.3f}): {t[:70]}")
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
        input_embeds = self.model.get_input_embeddings()(prompt_ids)
        if route in ("CONFIDENT", "CAUTIOUS"):
            # Soft prefix is distribution-matched against the real prompt embeddings.
            prefix = self.world.get_prefix(S_focus, self.embed_dim,
                                           ref_embeds=input_embeds.squeeze(0))
            prefix_emb = prefix.unsqueeze(0).to(input_embeds.dtype)
            full_embeds = torch.cat([prefix_emb, input_embeds], dim=1)
        else:
            full_embeds = input_embeds
        attn_mask = torch.ones(full_embeds.shape[:2], device=self.device,
                               dtype=torch.long)
        output = self.model.generate(
            inputs_embeds=full_embeds,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # NOTE(review): with inputs_embeds-only generation, recent transformers
        # versions return ONLY the newly generated token ids, so
        # output.shape[1] - n_in would be negative and the max(0, ...) below
        # records 0 output tokens. Verify against the installed transformers
        # version before trusting token_report() output numbers.
        n_in = prompt_ids.shape[1] + (PREFIX_LEN if route in ("CONFIDENT", "CAUTIOUS") else 0)
        n_out = output.shape[1] - n_in
        self.token_stats["generate_input_tokens"] += n_in
        self.token_stats["generate_output_tokens"] += max(0, n_out)
        self.token_stats["generate_calls"] += 1
        raw_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        elapsed = time.perf_counter() - t0
        return CognitiveResponse(
            text=raw_text,
            route=route,
            resonance=resonance,
            retrieval_confidence=retrieval_conf,
            retrieved_facts=semantic_hits,
            field_diagnostics=self.world.diagnostics(),
            latency_s=round(elapsed, 3),
        )

    def token_report(self) -> dict:
        """Token accounting summary across teach() and generate() calls."""
        s = self.token_stats
        total_in = s["teach_input_tokens"] + s["generate_input_tokens"]
        total_out = s["generate_output_tokens"]
        return {
            **s,
            "total_input_tokens": total_in,
            "total_output_tokens": total_out,
            "avg_input_per_gen": round(
                s["generate_input_tokens"] / max(1, s["generate_calls"]), 1),
            "avg_output_per_gen": round(
                s["generate_output_tokens"] / max(1, s["generate_calls"]), 1),
        }

    def reset_world(self):
        """Reset all field, episodic, routing, and token-tracking state."""
        self.world.S.zero_()
        self.world.strength.zero_()
        self.world.step_count = 0
        self.world.memories = [""] * self.world.n
        self.world.thresholds.fill_(0.5)
        self.world.protected.fill_(False)
        # v5: reset new state (W_local back to small random init, no self-loops)
        self.world.M.zero_()
        self.world.W_local.data = torch.randn_like(self.world.W_local.data) * 0.01
        self.world.W_local.data.fill_diagonal_(0.0)
        self.world.hebbian_updates = 0
        self.world.consolidation_count = 0
        self.world.replay_count = 0
        self.episodes.clear()
        self.router.history.clear()
        self.call_count = 0
        for k in self.token_stats:
            self.token_stats[k] = 0

    def save_world(self, path: str = "world.pt"):
        """Persist field state, W_local, M, episodic store, and v5 counters."""
        torch.save({
            "S": self.world.S,
            "strength": self.world.strength,
            "W_local": self.world.W_local,
            "step_count": self.world.step_count,
            "memories": self.world.memories,
            "protected": self.world.protected,
            "thresholds": self.world.thresholds,
            "episodes": self.episodes.state_dict(),
            # v5 additions
            "M": self.world.M,
            "hebbian_updates": self.world.hebbian_updates,
            "consolidation_count": self.world.consolidation_count,
            "replay_count": self.world.replay_count,
            "version": 5,
        }, path)
        print(f" [Saved world -> {path}] episodes={len(self.episodes)} version=5")

    def load_world(self, path: str = "world.pt"):
        """Restore state saved by save_world(); tolerates pre-v5 checkpoints.

        SECURITY: weights_only=False uses pickle under the hood — only load
        checkpoint files from trusted sources.
        """
        if os.path.exists(path):
            data = torch.load(path, map_location=self.device, weights_only=False)
            self.world.S = data["S"].to(self.device)
            self.world.strength = data["strength"].to(self.device)
            if "W_local" in data:
                self.world.W_local = nn.Parameter(data["W_local"].to(self.device))
            if "step_count" in data:
                self.world.step_count = data["step_count"]
            if "memories" in data:
                self.world.memories = data["memories"]
            if "protected" in data:
                self.world.protected = data["protected"].to(self.device)
            if "thresholds" in data:
                self.world.thresholds = data["thresholds"].to(self.device)
            if "episodes" in data:
                self.episodes.load_state_dict(data["episodes"])
                self.episodes.embeds = [
                    e.to(self.device) for e in self.episodes.embeds]
            # v5: load new state (backward compatible)
            if "M" in data:
                self.world.M = data["M"].to(self.device)
            else:
                print(" [v5 upgrade: M initialized to zero]")
            if "hebbian_updates" in data:
                self.world.hebbian_updates = data["hebbian_updates"]
            if "consolidation_count" in data:
                self.world.consolidation_count = data["consolidation_count"]
            if "replay_count" in data:
                self.world.replay_count = data["replay_count"]
            ver = data.get("version", 4)
            print(f" [Loaded world <- {path}] episodes={len(self.episodes)} version={ver}")
        else:
            print(" [No saved world found]")


# =========================
# 6. RAG BASELINE (unchanged)
# =========================

class RAGBaseline:
    """Pure semantic retrieval baseline. No Hebbian field, no soft prefix."""

    def __init__(self, tokenizer, model):
        # Shares the tokenizer/model with the main agent; no state of its own
        # beyond the episodic store.
        self.tokenizer = tokenizer
        self.model = model
        self.device = next(model.parameters()).device
        self.episodes = PinnedEpisodicStore(max_episodes=MAX_EPISODES)

    @torch.no_grad()
    def _mean_pool(self, text: str) -> torch.Tensor:
        """Mean-pooled embedding of `text` via the shared helper."""
        return mean_pool_text(self.tokenizer, self.model, text, self.device)

    def reset_world(self):
        """Drop all stored episodes (name kept for benchmark interface parity)."""
        self.episodes.clear()

    def teach(self, text: str, verbose: bool = False):
        """Pin a fact into the episodic store."""
        embed = self._mean_pool(text)
        self.episodes.pin(text, embed)
        if verbose:
            print(f" [RAG teach] pinned: '{text[:60]}' total={len(self.episodes)}")

    def generate(self, text: str, max_new_tokens: int = 80,
                 verbose: bool = False) -> str:
        """Retrieve facts, stuff them into a system prompt, and sample a reply."""
        query_embed = self._mean_pool(text)
        hits = self.episodes.retrieve(
            query_embed, topk=SEMANTIC_TOPK, threshold=SEMANTIC_THRESHOLD)
        messages = []
        if hits:
            facts = list(dict.fromkeys(t for t, _ in hits))
            messages.append({"role": "system",
                             "content": "Known facts: " + " | ".join(facts)})
        messages.append({"role": "user", "content": text})
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
        attn_mask = torch.ones_like(ids)
        with torch.no_grad():
            out = self.model.generate(
                ids, attention_mask=attn_mask, max_new_tokens=max_new_tokens,
                do_sample=True, temperature=0.7, top_p=0.9)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)


#
========================= # 7. FIELD ONLY — ABLATION (unchanged) # ========================= class FieldOnly: """Ablation: Hebbian Neural Field + Soft Prefix ONLY. No episodic store.""" def __init__(self, tokenizer, model): self.tokenizer = tokenizer self.model = model self.device = next(model.parameters()).device self.embed_dim = model.get_input_embeddings().weight.shape[1] self.world = WorldModel(N_REGIONS, STATE_DIM, self.embed_dim, self.device) self.call_count = 0 def reset_world(self): self.world.S.zero_() self.world.strength.zero_() self.world.step_count = 0 self.world.memories = [""] * self.world.n self.world.thresholds.fill_(0.5) self.world.protected.fill_(False) self.world.M.zero_() self.world.W_local.data = torch.randn_like(self.world.W_local.data) * 0.01 self.world.W_local.data.fill_diagonal_(0.0) self.call_count = 0 def _embed_token(self, token_id: int) -> torch.Tensor: t = torch.tensor([token_id], device=self.device) return self.model.get_input_embeddings()(t).squeeze(0) def _process_input(self, text: str, protect: bool = False, think_steps: int = THINK_STEPS): raw_ids = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=512 )["input_ids"].to(self.device)[0] x_embed = None for token_id in raw_ids: x_embed = self._embed_token(int(token_id)) attn, x_proj = self.world.attend(x_embed) self.world.update(attn, x_proj, protect=protect) attn, x_proj = self.world.think(x_embed, steps=think_steps, protect=protect) self.world.store_memory(text, attn) S_focus = self.world.get_focus(attn) self.call_count += 1 if self.call_count % MERGE_EVERY == 0: self.world.merge_similar() return attn, S_focus def teach(self, text: str, verbose: bool = False): self._process_input(text, protect=True, think_steps=TEACH_THINK_STEPS) if verbose: d = self.world.diagnostics() print(f" [FieldOnly teach] norm={d['field_norm']:.2f} " f"protected={d['protected']}/{N_REGIONS}") def generate(self, text: str, max_new_tokens: int = 80, verbose: bool = False) -> str: attn, 
S_focus = self._process_input(text, protect=False) messages = [{"role": "user", "content": text}] prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device) input_embeds = self.model.get_input_embeddings()(prompt_ids) prefix = self.world.get_prefix( S_focus, self.embed_dim, ref_embeds=input_embeds.squeeze(0)) prefix_emb = prefix.unsqueeze(0).to(input_embeds.dtype) full_embeds = torch.cat([prefix_emb, input_embeds], dim=1) attn_mask = torch.ones( full_embeds.shape[:2], device=self.device, dtype=torch.long) with torch.no_grad(): output = self.model.generate( inputs_embeds=full_embeds, attention_mask=attn_mask, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_p=0.9, ) return self.tokenizer.decode(output[0], skip_special_tokens=True) # ========================= # 8. CONTEXT WINDOW BASELINE (unchanged) # ========================= class ContextBaseline: """Naive context-stuffing baseline.""" def __init__(self, tokenizer, model): self.tokenizer = tokenizer self.model = model self.device = next(model.parameters()).device self._history: list[dict] = [] def reset_world(self): self._history.clear() def teach(self, text: str, verbose: bool = False): self._history.append({"role": "user", "content": text}) self._history.append({"role": "assistant", "content": "(acknowledged)"}) def generate(self, text: str, max_new_tokens: int = 80, verbose: bool = False) -> str: messages = list(self._history) + [{"role": "user", "content": text}] while len(messages) > 2: prompt_check = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) if len(self.tokenizer(prompt_check)["input_ids"]) < 1800: break messages = messages[2:] prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device) attn_mask = 
torch.ones_like(ids) with torch.no_grad(): out = self.model.generate( ids, attention_mask=attn_mask, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_p=0.9) decoded = self.tokenizer.decode(out[0], skip_special_tokens=True) self._history.append({"role": "user", "content": text}) self._history.append({"role": "assistant", "content": decoded[-200:]}) if len(self._history) > 20: self._history = self._history[-20:] return decoded # ========================= # MAIN LOOP (v5) # ========================= def main(): agent = HybridLLM() print() print("=" * 60) print(" Proto-Cognitive Hybrid LLM v5") print(" Neural Field | Episodic Memory | Cognitive Router") print(" + Hebbian Learning | Consolidation | Replay") print("=" * 60) print(" Commands: save | load | clear | stats | episodes | route") print(" dream | dream:N | exit") print(" Prefix teach: to encode a fact: teach: ") print() while True: try: raw = input("YOU: ").strip() except EOFError: break if not raw: continue text = raw if text.lower() == "exit": break elif text.lower() == "save": agent.save_world() elif text.lower() == "load": agent.load_world() elif text.lower() == "clear": agent.reset_world() print(" [World cleared]") elif text.lower() == "route": rs = agent.router.get_routing_stats() print(f" Routing stats: {rs}") elif text.lower().startswith("dream"): parts = text.split(":") cycles = int(parts[1]) if len(parts) > 1 and parts[1].strip().isdigit() else 1 results = agent.dream(cycles=cycles, verbose=True) d = agent.world.diagnostics() print(f" [Post-dream] W_local density={d['w_local_density']:.4f} " f"M_norm={d['m_norm']:.4f} " f"connections={d['w_local_connections']}") elif text.lower() == "stats": d = agent.world.diagnostics() print(f" Field norm : {d['field_norm']:.4f}") print(f" Active regions : {d['active_regions']}/{N_REGIONS}") print(f" Attractors : {d['attractors']}") print(f" Protected : {d['protected']}") print(f" Memories : {d['n_memories']}") print(f" Strongest : 
idx={d['strongest_idx']} ({d['strength_max']:.4f})") print(f" Total steps : {d['total_steps']}") print(f" --- v5 ---") print(f" W_local density: {d['w_local_density']:.4f} ({d['w_local_connections']} connections)") print(f" W_local max : {d['w_local_max']:.4f} mean={d['w_local_mean']:.4f}") print(f" M norm : {d['m_norm']:.4f} ({d['m_active_regions']} active regions)") print(f" Hebbian updates: {d['hebbian_updates']}") print(f" Consolidations : {d['consolidations']}") print(f" Replay cycles : {d['replays']}") ed = agent.episodes.diagnostics() print(f" Pinned episodes: {ed['total']} (active={ed['active']} superseded={ed['superseded']})") rs = agent.router.get_routing_stats() print(f" Router stats : {rs}") elif text.lower() == "episodes": print(f" Pinned episodes ({len(agent.episodes)}):") for i, (t, imp) in enumerate( zip(agent.episodes.texts, agent.episodes.importance)): print(f" [{i:3d}] imp={imp:.1f} {t[:80]}") elif text.lower().startswith("teach:"): fact = text[6:].strip() if fact: agent.teach(fact, verbose=True) print(f" [Encoded + pinned: '{fact[:60]}']") else: resp = agent.generate_cognitive(text, verbose=True) answer = extract_answer(resp.text) print(f"\n [{resp.route}] resonance={resp.resonance:.3f} " f"retrieval={resp.retrieval_confidence:.3f}") print(f"\nAI: {answer}") print("-" * 60) if __name__ == "__main__": main()