# (Hugging Face upload metadata, preserved as comments so the file parses)
# Chris4K's picture
# Upload 4 files
# c7ac419 verified
# app.py -- Proto-Cognitive Architecture v5
# Neural Field + Pinned Episodic Memory + Semantic Retrieval + Cognitive Router
# + Hebbian Structure Learning + Memory Consolidation + Replay/Dreaming + Competition
#
# v5 CHANGES (post v4 — structural leap):
#
# CONTEXT: v4 proved the field is a familiarity detector, not a memory.
# The CognitiveRouter uses field resonance to route queries. But W_local
# was initialized randomly and NEVER LEARNED — the field had no real
# structure. v5 fixes this.
#
# NEW MECHANISMS:
#
# 1. HEBBIAN STRUCTURE LEARNING
# W_local is now updated via co-activation during think().
# Regions that fire together wire together. This creates learned
# association pathways in the field — not random noise.
# Includes: sparsification, clamping, diagonal zeroing.
#
# 2. MEMORY CONSOLIDATION (S → M)
# New long-term memory tensor M. Stable, protected regions slowly
# consolidate from short-term S into M. M influences field dynamics
# via field_step(), pulling state toward consolidated patterns.
# This gives the field actual persistence beyond decay.
#
# 3. REPLAY / DREAMING
# Self-training loop that runs without external input. Adds noise,
# lets field dynamics + attractors + W_local interact, then reinforces
# emergent patterns via Hebbian update. Consolidates after.
# Time-bounded for Gradio safety.
#
# 4. COMPETITION (winner-take-most)
# Sharper suppression of weak activations during think().
# Promotes specialization — regions either contribute or get quiet.
#
# UNCHANGED:
# - PinnedEpisodicStore (episodic recall mechanism)
# - CognitiveRouter (query routing via resonance + retrieval)
# - All baselines (RAG, FieldOnly, ContextBaseline)
# - CognitiveResponse dataclass
# - Token tracking
# - Gradio interface compatibility
#
# BACKWARD COMPATIBLE:
# - load_world() handles missing M tensor gracefully
# - save_world() includes M, W_local (already saved), hebbian stats
# - Benchmark interfaces unchanged
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
from dataclasses import dataclass, field as dc_field
from typing import Optional
# Must be set BEFORE importing transformers — the flag is read at import time.
os.environ["TRANSFORMERS_DISABLE_FLASH_ATTN"] = "1"
from transformers import AutoModelForCausalLM, AutoTokenizer
# Small instruct model so the demo stays runnable on modest hardware.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    print("[WARNING] CUDA not available — running on CPU. Install CUDA-enabled PyTorch:")
    print(" pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 --upgrade")
else:
    # Report every visible GPU (name + VRAM) so multi-GPU hosts can sanity-check placement.
    n_gpus = torch.cuda.device_count()
    for _i in range(n_gpus):
        print(f"[GPU {_i}] {torch.cuda.get_device_name(_i)} "
              f"({torch.cuda.get_device_properties(_i).total_memory // 1024**3} GB VRAM)")
# =========================
# CONFIG
# =========================
N_REGIONS = 64          # number of field regions
STATE_DIM = 512         # per-region state vector dimensionality
LR = 0.01               # write rate for the attn x input outer product in update()
DECAY = 0.001           # baseline per-step decay applied to S
PREFIX_LEN = 10         # soft-prompt length prepended in generate_cognitive()
ATTN_TEMP = 0.15        # softmax temperature for region attention
TOPK = 8                # regions kept by attend() / competition()
FIELD_RATE = 0.05       # weight of W_local interaction in field_step()
ATTRACTOR_TH = 0.1      # strength threshold for attractor behavior
ATTRACTOR_PULL = 0.05   # pull magnitude in apply_attractors()
THINK_STEPS = 5         # think() iterations for ordinary queries
TEACH_THINK_STEPS = 15  # deeper think() during teach()
MERGE_TH = 0.92         # cosine similarity above which regions merge
MERGE_EVERY = 10        # run merge_similar() every N processed inputs
PREFIX_SCALE = 0.45     # scale applied to the generated prefix embeddings
INHIBITION_STRENGTH = 0.05  # global inhibition weight in field_step()
HOMEOSTASIS_TARGET = 0.1    # target activation for threshold adaptation in attend()
PLASTICITY_SCOPE = 0.01     # threshold adaptation rate in attend()
MAX_EPISODES = 256      # capacity of PinnedEpisodicStore
SEMANTIC_TOPK = 5       # facts retrieved per query
SEMANTIC_THRESHOLD = 0.15   # min cosine similarity for retrieval
PROTECT_DECAY = 0.9999  # NOTE(review): appears unused in this chunk — confirm before removing
# Cognitive Router thresholds (v4)
RESONANCE_HIGH = 0.13   # resonance at/above this counts as "familiar domain"
RESONANCE_LOW = 0.12    # NOTE(review): stored on the router but not read by route()
RETRIEVAL_STRONG = 0.40  # retrieval confidence for the CONFIDENT route
RETRIEVAL_WEAK = 0.20    # below this, retrieval counts as "no facts" (DEFER)
# v5: Hebbian structure learning
HEBBIAN_LR = 0.002 # learning rate for W_local updates
HEBBIAN_PRUNE_TH = 0.0005 # connections below this are zeroed (sparsity)
W_LOCAL_MAX = 0.5 # absolute clamp for W_local entries
W_LOCAL_DECAY = 0.9995 # slow decay on W_local to prevent saturation
# v5: Memory consolidation
CONSOLIDATION_RATE = 0.01 # how fast S transfers to M
STABILITY_TH = 0.15 # minimum strength for consolidation
MEMORY_INFLUENCE = 0.03 # how much M pulls field dynamics
# v5: Replay / dreaming
REPLAY_STEPS = 5 # steps per replay cycle
REPLAY_NOISE = 0.008 # noise magnitude for exploration
MAX_REPLAY_CYCLES = 3 # max replay() calls per dream session
# v5: Competition
COMPETITION_SUPPRESS = 0.05 # non-winners keep this fraction of activation
# =========================
# RESPONSE METADATA (v4, unchanged)
# =========================
@dataclass
class CognitiveResponse:
    """Everything generate_cognitive() knows about one answer.

    Carries the generated text plus the routing metadata (route label,
    resonance, retrieval confidence, retrieved facts, field diagnostics,
    latency) that produced it.
    """
    text: str
    route: str = "unknown"
    resonance: float = 0.0
    retrieval_confidence: float = 0.0
    retrieved_facts: list = dc_field(default_factory=list)
    field_diagnostics: dict = dc_field(default_factory=dict)
    latency_s: float = 0.0

    @property
    def is_confident(self) -> bool:
        """True when the router took the CONFIDENT path."""
        return self.route == "CONFIDENT"

    @property
    def should_verify(self) -> bool:
        """True for routes whose answers deserve external checking."""
        return self.route in {"UNCERTAIN", "DEFER"}
# =========================
# 1. PINNED EPISODIC STORE (unchanged from v4)
# =========================
class PinnedEpisodicStore:
    """
    Non-decaying semantic memory with contradiction/supersession handling.
    Stores (text, embedding, importance, superseded) tuples.
    """
    # A new fact whose cosine similarity to an old one exceeds this marks
    # the old one as superseded (likely contradicted/updated).
    SUPERSEDE_TH = 0.82
    # Score multiplier applied to superseded entries (strong down-weight).
    SUPERSEDE_W = 0.05
    def __init__(self, max_episodes: int = MAX_EPISODES):
        self.max_episodes = max_episodes
        # Parallel lists, all indexed by episode position.
        self.texts: list[str] = []
        self.embeds: list[torch.Tensor] = []   # unit-norm embeddings
        self.importance: list[float] = []
        self.superseded: list[bool] = []
    def __len__(self):
        return len(self.texts)
    def pin(self, text: str, embed: torch.Tensor):
        """Store *text* with its embedding.

        Re-pinning an identical text boosts its importance and refreshes
        its embedding; pinning a new text may supersede similar entries
        and may evict one entry when the store is full.
        """
        embed_norm = F.normalize(embed.float().unsqueeze(0), dim=1).squeeze(0)
        # Exact duplicate: boost importance, clear supersession, and nudge
        # the stored embedding toward the new one (EMA, then re-normalize).
        for i, t in enumerate(self.texts):
            if t == text:
                self.importance[i] += 1.0
                self.superseded[i] = False
                self.embeds[i] = F.normalize(
                    (0.9 * self.embeds[i] + 0.1 * embed_norm).unsqueeze(0), dim=1
                ).squeeze(0)
                return
        # New text: anything semantically very close is treated as
        # contradicted/updated — mark it superseded, slash its importance.
        if self.embeds:
            stack = torch.stack(self.embeds, dim=0)
            sims = torch.mv(stack, embed_norm)
            for i, sim in enumerate(sims.tolist()):
                if sim >= self.SUPERSEDE_TH and not self.superseded[i]:
                    self.superseded[i] = True
                    self.importance[i] = self.importance[i] * self.SUPERSEDE_W
        # Capacity eviction: prefer an already-superseded entry, else the
        # least important one.
        if len(self.texts) >= self.max_episodes:
            sup_indices = [i for i, s in enumerate(self.superseded) if s]
            if sup_indices:
                evict = sup_indices[0]
            else:
                evict = self.importance.index(min(self.importance))
            self.texts.pop(evict)
            self.embeds.pop(evict)
            self.importance.pop(evict)
            self.superseded.pop(evict)
        self.texts.append(text)
        self.embeds.append(embed_norm)
        self.importance.append(2.0)   # fresh pins start above the re-pin increment
        self.superseded.append(False)
    def retrieve(self, query_embed: torch.Tensor, topk: int = SEMANTIC_TOPK,
                 threshold: float = SEMANTIC_THRESHOLD) -> list[tuple[str, float]]:
        """Return up to *topk* (text, cosine) pairs with cosine >= threshold.

        Ranking uses cosine * log1p(importance) with superseded entries
        down-weighted; the score returned to callers is the raw cosine.
        """
        if not self.texts:
            return []
        q = F.normalize(query_embed.float().unsqueeze(0), dim=1).squeeze(0)
        stack = torch.stack(self.embeds, dim=0)
        sims = torch.mv(stack, q)
        imp = torch.tensor(self.importance, dtype=torch.float32, device=sims.device)
        scores = sims * torch.log1p(imp)
        for i, sup in enumerate(self.superseded):
            if sup:
                scores[i] = scores[i] * self.SUPERSEDE_W
        mask = sims >= threshold
        if not mask.any():
            # Nothing clears the threshold: fall back to the single best
            # entry if it clears a relaxed (half) threshold.
            best = scores.argmax().item()
            if sims[best].item() >= threshold * 0.5:
                return [(self.texts[best], sims[best].item())]
            return []
        # Masked entries have positive scores (sims >= threshold > 0 and
        # log1p(imp) > 0), so topk over zeroed scores selects only them.
        masked_scores = scores * mask.float()
        k = min(topk, int(mask.sum().item()))
        topk_vals, topk_idx = torch.topk(masked_scores, k)
        results, seen = [], set()
        for idx, score in zip(topk_idx.tolist(), topk_vals.tolist()):
            t = self.texts[idx]
            # Dedupe by text and re-check the raw-similarity threshold.
            if t not in seen and sims[idx].item() >= threshold:
                results.append((t, sims[idx].item()))
                seen.add(t)
        return results
    def retrieval_confidence(self, query_embed: torch.Tensor) -> float:
        """Max importance-weighted cosine over the store (0.0 if empty)."""
        if not self.texts:
            return 0.0
        q = F.normalize(query_embed.float().unsqueeze(0), dim=1).squeeze(0)
        stack = torch.stack(self.embeds, dim=0)
        sims = torch.mv(stack, q)
        imp = torch.tensor(self.importance, dtype=torch.float32, device=sims.device)
        weighted = sims * torch.log1p(imp)
        for i, sup in enumerate(self.superseded):
            if sup:
                weighted[i] *= self.SUPERSEDE_W
        return float(weighted.max().item()) if weighted.numel() > 0 else 0.0
    def all_texts(self) -> list[str]:
        """Copy of all stored texts (including superseded ones)."""
        return list(self.texts)
    def clear(self):
        """Drop every stored episode."""
        self.texts.clear()
        self.embeds.clear()
        self.importance.clear()
        self.superseded.clear()
    def diagnostics(self) -> dict:
        """Counts for UI display: total / active / superseded."""
        return {
            "total": len(self.texts),
            "active": sum(1 for s in self.superseded if not s),
            "superseded": sum(self.superseded),
        }
    def state_dict(self) -> dict:
        """Serializable snapshot (embeddings moved to CPU)."""
        return {
            "texts": self.texts,
            "embeds": [e.cpu() for e in self.embeds],
            "importance": self.importance,
            "superseded": self.superseded,
        }
    def load_state_dict(self, d: dict):
        """Restore from state_dict(); caller is responsible for moving
        embeddings back to the target device (see HybridLLM.load_world)."""
        self.texts = d["texts"]
        self.embeds = [e for e in d["embeds"]]
        self.importance = d["importance"]
        # Older snapshots may predate the superseded flag.
        self.superseded = d.get("superseded", [False] * len(d["texts"]))
# =========================
# 2. NEURAL FIELD WORLD MODEL (v5 — enhanced)
# =========================
class WorldModel(nn.Module):
    """Neural field: n_regions x state_dim states with learned lateral wiring.

    Short-term state ``S`` evolves via field_step/attend/update; ``W_local``
    (shaped by Hebbian updates) couples regions; ``M`` is the slowly
    consolidated long-term copy of stable patterns (v5).
    """
    def __init__(self, n_regions, state_dim, embed_dim, device):
        super().__init__()
        self.n = n_regions
        self.d = state_dim
        self.embed_dim = embed_dim
        self.device = device
        # Projects LLM token embeddings into field state space.
        self.input_proj = nn.Linear(embed_dim, state_dim, bias=False)
        # Maps a focus vector to PREFIX_LEN soft-prompt embeddings.
        self.to_prefix = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, PREFIX_LEN * embed_dim),
        )
        # Region-to-region coupling; small random init, then Hebbian-learned.
        self.W_local = nn.Parameter(
            torch.randn(n_regions, n_regions, device=device) * 0.01
        )
        self.to(device)
        # Short-term field state
        self.S = torch.zeros(n_regions, state_dim, device=device)   # region states
        self.strength = torch.zeros(n_regions, device=device)       # attractor strength per region
        self.memories = [""] * n_regions        # one text snippet per region
        self.thresholds = torch.ones(n_regions, device=device) * 0.5  # adaptive attention gates
        self.protected = torch.zeros(n_regions, dtype=torch.bool, device=device)  # taught regions
        self.step_count = 0
        # v5: Long-term consolidated memory
        self.M = torch.zeros(n_regions, state_dim, device=device)
        # v5: Hebbian learning stats
        self.hebbian_updates = 0
        self.consolidation_count = 0
        self.replay_count = 0
    # ----- v5: Hebbian structure learning -----
    def hebbian_structure_update(self, attn: torch.Tensor):
        """
        Update W_local based on co-activation patterns.
        Regions that fire together wire together.
        Stability controls:
          - Normalize outer product to prevent explosion
          - Sparsify: prune weak connections
          - Clamp: absolute limit on connection strength
          - Decay: slow decay prevents saturation
          - Zero diagonal: no self-connections
        """
        with torch.no_grad():
            # Co-activation: outer product of attention vector
            outer = torch.outer(attn, attn)
            # Normalize to unit energy per update
            outer_norm = outer.norm()
            if outer_norm > 1e-8:
                outer = outer / outer_norm
            # Apply learning rate and update
            self.W_local.data += HEBBIAN_LR * outer
            # Slow decay to prevent saturation over many updates
            self.W_local.data *= W_LOCAL_DECAY
            # Sparsify: zero out weak connections
            self.W_local.data = torch.where(
                self.W_local.data.abs() > HEBBIAN_PRUNE_TH,
                self.W_local.data,
                torch.zeros_like(self.W_local.data)
            )
            # Clamp to prevent runaway
            self.W_local.data.clamp_(-W_LOCAL_MAX, W_LOCAL_MAX)
            # Zero diagonal — no self-connections
            self.W_local.data.fill_diagonal_(0.0)
            self.hebbian_updates += 1
    # ----- v5: Memory consolidation -----
    def consolidate(self):
        """
        Transfer stable short-term patterns (S) into long-term memory (M).
        Only regions with sufficient strength AND protection status
        consolidate. This means only taught content enters long-term
        memory — noise and queries don't.
        M is normalized to prevent unbounded growth.
        """
        with torch.no_grad():
            # Stability signal: strength weighted by protection
            stability = self.strength.clone()
            stability[self.protected] *= 2.0  # protected regions consolidate faster
            # Only consolidate stable regions
            mask = (stability > STABILITY_TH).float().unsqueeze(1)
            # Exponential moving average: M slowly tracks S for stable regions
            self.M = self.M + CONSOLIDATION_RATE * mask * (self.S - self.M)
            # Normalize M to prevent unbounded growth (soft cap via tanh).
            norm = self.M.norm(dim=1, keepdim=True).clamp(min=1e-8)
            self.M = self.M / norm * torch.tanh(norm)
            self.consolidation_count += 1
    # ----- v5: Replay / dreaming -----
    def replay(self, steps: int = REPLAY_STEPS) -> dict:
        """
        Self-training loop without external input.
        Process:
          1. Add small noise for exploration
          2. Run field dynamics (W_local interaction + M influence)
          3. Apply attractors
          4. Hebbian update on emergent activation pattern
          5. Consolidate after replay
        Returns diagnostics dict for UI display.
        """
        with torch.no_grad():
            initial_norm = self.S.norm().item()
            initial_w_density = (self.W_local.data.abs() > HEBBIAN_PRUNE_TH).float().mean().item()
            for _ in range(steps):
                # Exploration noise
                noise = torch.randn_like(self.S) * REPLAY_NOISE
                self.S = self.S + noise
                # Let field dynamics run
                self.field_step()
                # Competition sharpens activation
                self.competition()
                # Attractors pull toward stored patterns
                self.apply_attractors()
                # Hebbian update on emergent patterns (softmax over row norms)
                norms = self.S.norm(dim=1)
                if norms.max() > 1e-8:
                    attn = F.softmax(norms / ATTN_TEMP, dim=0)
                    self.hebbian_structure_update(attn)
            # Consolidate after dreaming
            self.consolidate()
            self.replay_count += 1
            final_norm = self.S.norm().item()
            final_w_density = (self.W_local.data.abs() > HEBBIAN_PRUNE_TH).float().mean().item()
            return {
                "steps": steps,
                "norm_delta": round(final_norm - initial_norm, 4),
                "w_density_delta": round(final_w_density - initial_w_density, 6),
                "total_replays": self.replay_count,
            }
    # ----- v5: Competition -----
    def competition(self):
        """
        Winner-take-most: top-K regions keep full activation,
        others are suppressed to a fraction. Promotes specialization.
        """
        with torch.no_grad():
            scores = self.S.norm(dim=1)
            if scores.max() < 1e-8:
                return  # nothing to compete
            k = min(TOPK, self.n)
            topk = torch.topk(scores, k)
            # Build suppression mask: winners keep 1.0, the rest keep
            # COMPETITION_SUPPRESS of their activation.
            mask = torch.full((self.n,), COMPETITION_SUPPRESS, device=self.device)
            mask[topk.indices] = 1.0
            self.S = self.S * mask.unsqueeze(1)
    def field_step(self):
        """One step of autonomous dynamics. v5: Now includes long-term memory influence."""
        # Lateral interaction through learned coupling.
        interaction = torch.matmul(self.W_local, self.S)
        # Global inhibition: mean activity suppresses every region a little.
        gi = torch.mean(self.S, dim=0, keepdim=True) * INHIBITION_STRENGTH
        # v5: Long-term memory pulls field toward consolidated patterns
        memory_pull = MEMORY_INFLUENCE * self.M
        self.S = torch.tanh(self.S + FIELD_RATE * (interaction - gi) + memory_pull)
    def apply_attractors(self):
        """Strong regions self-reinforce: pull state along its own tanh."""
        mask = self.strength > ATTRACTOR_TH
        if mask.any():
            self.S[mask] = self.S[mask] + ATTRACTOR_PULL * torch.tanh(self.S[mask])
    def attend(self, x_embed):
        """Sparse top-K attention of the field over one input embedding.

        Returns (attn, x_proj): attn is an n-vector (softmax over the top-K
        scores, zeros elsewhere); x_proj is the input projected into state
        space. Side effect: homeostatic threshold adaptation.
        Assumes x_embed has shape (embed_dim,).
        """
        x_embed = x_embed.to(self.S.dtype)
        x_proj = self.input_proj(x_embed)
        # Match score per region, gated by adaptive thresholds.
        scores = torch.matmul(self.S, x_proj) - self.thresholds
        scores = scores + 0.2 * self.strength           # bias toward strong regions
        scores = scores + 0.3 * self.protected.float()  # and toward taught ones
        topk = torch.topk(scores, min(TOPK, self.n))
        attn = torch.zeros_like(scores)
        attn[topk.indices] = F.softmax(topk.values / ATTN_TEMP, dim=0)
        # Homeostasis: frequently-winning regions raise their own threshold.
        self.thresholds += PLASTICITY_SCOPE * (attn - HOMEOSTASIS_TARGET)
        self.thresholds = torch.clamp(self.thresholds, 0.1, 2.0)
        return attn, x_proj
    def update(self, attn, x_proj, protect: bool = False):
        """Write the attended input into S, then apply decay + soft norm.

        protect=True (teach mode) flags the winning regions so they decay
        almost not at all.
        """
        if protect:
            top_idx = attn.topk(min(TOPK, self.n)).indices
            self.protected[top_idx] = True
        # Each region absorbs the input in proportion to its attention.
        delta = LR * torch.outer(attn, x_proj)
        self.S = self.S + delta
        # Strength: leaky accumulator of received attention.
        self.strength = (self.strength + attn * 0.05) * 0.99
        # Sub-attractor regions decay up to ~6x faster (sigmoid gate).
        decay_mask = torch.sigmoid(10 * (ATTRACTOR_TH - self.strength))
        variable_decay = DECAY * (1.0 + decay_mask * 5.0)
        protected_decay = 1.0 - DECAY * 0.01   # near-zero decay for taught regions
        base_decay = 1.0 - variable_decay
        final_decay = torch.where(self.protected,
                                  torch.full_like(base_decay, protected_decay),
                                  base_decay)
        self.S = self.S * final_decay.unsqueeze(1)
        # Soft re-normalization: row norms squashed through tanh.
        norm = self.S.norm(dim=1, keepdim=True).clamp(min=1e-8)
        self.S = self.S / norm * torch.tanh(norm)
        self.step_count += 1
    def think(self, x_embed, steps: int = THINK_STEPS, protect: bool = False):
        """
        v5: Enhanced think loop with competition + Hebbian learning.
        Per step:
          1. field_step()  — W_local interaction + M influence
          2. competition() — suppress weak activations
          3. attend()      — compute attention to input
          4. update()      — update state + decay
          5. hebbian_structure_update() — learn W_local from co-activation
          6. apply_attractors()
        After all steps (if protect=True): consolidate S → M
        Returns the (attn, x_proj) pair from the final step.
        """
        for _ in range(steps):
            self.field_step()
            self.competition()
            attn, x_proj = self.attend(x_embed)
            self.update(attn, x_proj, protect=protect)
            self.hebbian_structure_update(attn)
            self.apply_attractors()
        # After teach: consolidate stable patterns into long-term memory
        if protect:
            self.consolidate()
        return attn, x_proj
    def store_memory(self, text: str, attn: torch.Tensor):
        """Attach *text* to a top-attended region.

        Prefers an empty slot (or one already holding this text) among the
        winners; otherwise overwrites the weakest winner's slot.
        """
        top_indices = attn.topk(min(TOPK, self.n)).indices
        for idx in top_indices:
            i = idx.item()
            if not self.memories[i] or self.memories[i] == text:
                self.memories[i] = text
                return
        strengths = self.strength[top_indices]
        weakest = top_indices[strengths.argmin()].item()
        self.memories[weakest] = text
    def merge_similar(self):
        """Merge near-duplicate regions (cosine > MERGE_TH); returns merge count.

        Protected regions are never merged. The absorbed region's state,
        strength and memory text are cleared; the survivor keeps the
        stronger region's memory text and the max strength.
        """
        with torch.no_grad():
            norms = self.S.norm(dim=1, keepdim=True).clamp(min=1e-8)
            S_hat = self.S / norms
            sim = torch.matmul(S_hat, S_hat.T)
            merged = set()
            for i in range(self.n):
                if i in merged: continue
                for j in range(i + 1, self.n):
                    if j in merged: continue
                    if sim[i, j] > MERGE_TH:
                        if self.protected[i] or self.protected[j]:
                            continue
                        self.S[i] = (self.S[i] + self.S[j]) / 2.0
                        # Keep the memory text of the stronger region.
                        if self.strength[j] > self.strength[i]:
                            self.memories[i] = self.memories[j]
                        self.strength[i] = max(
                            self.strength[i].item(), self.strength[j].item())
                        self.S[j] = 0.0
                        self.strength[j] = 0.0
                        self.memories[j] = ""
                        merged.add(j)
            return len(merged) if merged else 0
    def get_focus(self, attn):
        """Attention-weighted sum of region states: the field's focus vector."""
        return torch.sum(attn.unsqueeze(1) * self.S, dim=0)
    def get_prefix(self, S_focus, embed_dim, ref_embeds=None):
        """Map a focus vector to PREFIX_LEN soft-prompt embeddings.

        If *ref_embeds* is given, the prefix is moment-matched (mean/std)
        to those embeddings so it lands in-distribution for the LLM.
        """
        S_focus = torch.tanh(S_focus)
        prefix = self.to_prefix(S_focus).view(PREFIX_LEN, embed_dim)
        if ref_embeds is not None and ref_embeds.numel() > 0:
            p_std = prefix.std().clamp(min=1e-8)
            r_mean = ref_embeds.mean()
            r_std = ref_embeds.std().clamp(min=1e-8)
            prefix = (prefix - prefix.mean()) / p_std * r_std + r_mean
        return prefix * PREFIX_SCALE
    def recall(self, max_items: int = 3) -> list[str]:
        """Soft memory recall from strongest attractor regions."""
        # Scan twice as many candidates as requested to allow for
        # duplicate/empty memory slots.
        topk = torch.topk(self.strength, min(max_items * 2, self.n))
        recalled, seen = [], set()
        for idx, s in zip(topk.indices, topk.values):
            i = idx.item()
            if self.memories[i] and s.item() > 0.01 and self.memories[i] not in seen:
                recalled.append(self.memories[i])
                seen.add(self.memories[i])
                if len(recalled) >= max_items:
                    break
        return recalled
    def get_cognitive_resonance(self, x_embed: torch.Tensor) -> float:
        """Measures field familiarity with the input. Returns 0.0–1.0."""
        x_embed = x_embed.to(self.S.dtype)
        x_proj = self.input_proj(x_embed)
        raw_scores = torch.matmul(self.S, x_proj)
        # Strong and protected regions amplify their match scores.
        strength_weight = torch.clamp(self.strength, min=0.0)
        weighted = raw_scores * (1.0 + strength_weight * 5.0)
        protected_bonus = self.protected.float() * 2.0
        weighted = weighted + protected_bonus * raw_scores.clamp(min=0.0)
        k = min(TOPK, self.n)
        topk_vals, _ = torch.topk(weighted, k)
        energy = topk_vals.sum()
        # Normalize against a rough upper bound, then squash into (0, 1).
        max_energy = k * (1.0 + self.strength.max().item() * 5.0 + 2.0)
        max_energy = max(max_energy, 1e-8)
        resonance = float(torch.sigmoid(energy / max_energy * 4.0 - 2.0).item())
        return round(resonance, 4)
    def diagnostics(self):
        """Snapshot of field/Hebbian/consolidation/replay state for the UI."""
        with torch.no_grad():
            w_abs = self.W_local.data.abs()
            w_nonzero = (w_abs > HEBBIAN_PRUNE_TH).float()
            n_possible = self.n * (self.n - 1)  # exclude diagonal
            return {
                "active_regions": (self.strength > 0.01).sum().item(),
                "attractors": (self.strength > ATTRACTOR_TH).sum().item(),
                "protected": self.protected.sum().item(),
                "field_norm": self.S.norm().item(),
                "strength_max": self.strength.max().item(),
                "strongest_idx": self.strength.argmax().item(),
                "total_steps": self.step_count,
                "n_memories": sum(1 for m in self.memories if m),
                # v5: Hebbian / consolidation / replay stats
                "w_local_density": round(w_nonzero.sum().item() / max(n_possible, 1), 4),
                "w_local_max": round(w_abs.max().item(), 4),
                "w_local_mean": round(w_abs[w_abs > HEBBIAN_PRUNE_TH].mean().item(), 4) if (w_abs > HEBBIAN_PRUNE_TH).any() else 0.0,
                "w_local_connections": int(w_nonzero.sum().item()),
                "m_norm": round(self.M.norm().item(), 4),
                "m_active_regions": int((self.M.norm(dim=1) > 0.01).sum().item()),
                "hebbian_updates": self.hebbian_updates,
                "consolidations": self.consolidation_count,
                "replays": self.replay_count,
            }
# =========================
# 3. COGNITIVE ROUTER (v4, unchanged)
# =========================
class CognitiveRouter:
    """Decides how much to trust a query's answer sources.

    Two signals drive the decision: field resonance (is this domain
    familiar?) and episodic retrieval confidence (do we hold matching
    facts?). Crossing them yields four routes:

        resonance high + retrieval strong    -> CONFIDENT
        resonance high + retrieval weak      -> CAUTIOUS
        resonance low  + retrieval mid/high  -> UNCERTAIN
        resonance low  + retrieval none      -> DEFER
    """

    def __init__(self,
                 resonance_high: float = RESONANCE_HIGH,
                 resonance_low: float = RESONANCE_LOW,
                 retrieval_strong: float = RETRIEVAL_STRONG,
                 retrieval_weak: float = RETRIEVAL_WEAK):
        self.resonance_high = resonance_high
        self.resonance_low = resonance_low  # kept for interface compatibility; route() does not read it
        self.retrieval_strong = retrieval_strong
        self.retrieval_weak = retrieval_weak
        self.history: list[dict] = []

    def route(self, resonance: float, retrieval_conf: float) -> str:
        """Classify one query and log the decision to self.history."""
        if resonance >= self.resonance_high:
            chosen = ("CONFIDENT"
                      if retrieval_conf >= self.retrieval_strong
                      else "CAUTIOUS")
        elif retrieval_conf >= self.retrieval_weak:
            chosen = "UNCERTAIN"
        else:
            chosen = "DEFER"
        self.history.append({
            "resonance": resonance,
            "retrieval": retrieval_conf,
            "route": chosen,
        })
        return chosen

    def build_system_prompt(self, route: str, semantic_hits: list,
                            attractor_memories: list) -> str:
        """Assemble the system-prompt text appropriate for *route*."""
        hit_texts = list(dict.fromkeys(t for t, _ in semantic_hits))
        lines: list[str] = []
        if route == "CONFIDENT":
            if semantic_hits:
                lines.append("Known facts (high confidence): " + " | ".join(hit_texts))
            if attractor_memories:
                known = set(hit_texts)
                extras = [m for m in attractor_memories if m not in known]
                if extras:
                    lines.append("Additional context: " + " ".join(extras[:2]))
        elif route == "CAUTIOUS":
            if semantic_hits:
                lines.append("Possibly relevant facts (verify before stating): "
                             + " | ".join(hit_texts))
            lines.append(
                "Note: Your confidence in this domain is moderate. "
                "If you are not sure, say so explicitly."
            )
        elif route == "UNCERTAIN":
            if semantic_hits:
                lines.append("Retrieved facts (low domain familiarity — use cautiously): "
                             + " | ".join(hit_texts))
            lines.append(
                "Note: This topic may be outside your trained knowledge. "
                "Verify any specific claims before presenting them as fact."
            )
        else:  # DEFER
            lines.append(
                "You have no stored knowledge about this specific topic. "
                "Answer only from general knowledge and clearly state that "
                "you have no specific information stored about this."
            )
        return "\n".join(lines) if lines else ""

    def get_routing_stats(self) -> dict:
        """Aggregate route counts and mean signal values over the history."""
        if not self.history:
            return {"total": 0}
        from collections import Counter
        tally = Counter(entry["route"] for entry in self.history)
        n = len(self.history)
        return {
            "total": n,
            "confident": tally.get("CONFIDENT", 0),
            "cautious": tally.get("CAUTIOUS", 0),
            "uncertain": tally.get("UNCERTAIN", 0),
            "defer": tally.get("DEFER", 0),
            "avg_resonance": round(
                sum(entry["resonance"] for entry in self.history) / n, 4),
            "avg_retrieval": round(
                sum(entry["retrieval"] for entry in self.history) / n, 4),
        }
# =========================
# 4. SHARED UTILITIES
# =========================
@torch.no_grad()
def mean_pool_text(tokenizer, model, text: str, device) -> torch.Tensor:
    """Average the input-embedding vectors of *text*'s tokens (max 256)."""
    encoded = tokenizer(text, return_tensors="pt",
                        truncation=True, max_length=256)
    token_ids = encoded["input_ids"].to(device)
    token_embeds = model.get_input_embeddings()(token_ids).squeeze(0)
    return token_embeds.mean(dim=0)
def extract_answer(raw_text: str) -> str:
    """Return the text after the LAST 'assistant\\n' marker.

    Chat-template decodes embed role headers; everything before the final
    assistant turn is discarded. Falls back to the whole (stripped) text
    when no marker is present.
    """
    tag = "assistant\n"
    pos = raw_text.rfind(tag)
    if pos == -1:
        return raw_text.strip()
    return raw_text[pos + len(tag):].strip()
# =========================
# 5. HybridLLM — MAIN AGENT (v5)
# =========================
class HybridLLM:
def __init__(self):
print(f"Loading model on {'CUDA' if torch.cuda.is_available() else 'CPU'}...")
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
try:
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map="auto",
attn_implementation="eager",
)
except ValueError as ve:
print("[ERROR] Failed to use `device_map` when loading model:", ve)
print("Attempting CPU-only fallback load (this may be slow).")
try:
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map=None,
low_cpu_mem_usage=True,
)
self.model.eval()
self.embed_dim = self.model.get_input_embeddings().weight.shape[1]
self.device = next(self.model.parameters()).device
print(f" Fallback model loaded -> device={self.device} dtype={next(self.model.parameters()).dtype}")
except Exception as e2:
print("[ERROR] CPU fallback also failed:", e2)
raise
self.model.eval()
self.embed_dim = self.model.get_input_embeddings().weight.shape[1]
self.device = next(self.model.parameters()).device
print(f" Model loaded -> device={self.device} dtype={next(self.model.parameters()).dtype}")
self.world = WorldModel(N_REGIONS, STATE_DIM, self.embed_dim, self.device)
self.episodes = PinnedEpisodicStore(max_episodes=MAX_EPISODES)
self.router = CognitiveRouter()
self.call_count = 0
self.token_stats = {
"teach_input_tokens": 0,
"generate_input_tokens": 0,
"generate_output_tokens": 0,
"generate_calls": 0,
"teach_calls": 0,
}
def embed_token(self, token_id: int) -> torch.Tensor:
t = torch.tensor([token_id], device=self.device)
return self.model.get_input_embeddings()(t).squeeze(0)
@torch.no_grad()
def mean_pool_text(self, text: str) -> torch.Tensor:
return mean_pool_text(self.tokenizer, self.model, text, self.device)
def _process_input(self, text: str, protect: bool = False,
think_steps: int = THINK_STEPS):
raw_ids = self.tokenizer(text, return_tensors="pt",
truncation=True, max_length=512
)["input_ids"].to(self.device)[0]
x_embed = None
for token_id in raw_ids:
x_embed = self.embed_token(int(token_id))
attn, x_proj = self.world.attend(x_embed)
self.world.update(attn, x_proj, protect=protect)
attn, x_proj = self.world.think(x_embed, steps=think_steps, protect=protect)
self.world.store_memory(text, attn)
S_focus = self.world.get_focus(attn)
self.call_count += 1
if self.call_count % MERGE_EVERY == 0:
self.world.merge_similar()
return attn, S_focus, x_embed
def teach(self, text: str, verbose: bool = False, auto_dream: bool = True):
n_tok = len(self.tokenizer(text)["input_ids"])
self.token_stats["teach_input_tokens"] += n_tok
self.token_stats["teach_calls"] += 1
before_sup = set(self.episodes.texts[i]
for i, s in enumerate(self.episodes.superseded) if s)
attn, S_focus, _ = self._process_input(
text, protect=True, think_steps=TEACH_THINK_STEPS)
mean_embed = self.mean_pool_text(text)
self.episodes.pin(text, mean_embed)
after_sup = set(self.episodes.texts[i]
for i, s in enumerate(self.episodes.superseded) if s)
newly_superseded = after_sup - before_sup
if newly_superseded:
for i, mem in enumerate(self.world.memories):
if mem in newly_superseded:
self.world.memories[i] = ""
# v5: Auto-dream after teaching to reinforce structure
dream_result = None
if auto_dream:
dream_result = self.world.replay(steps=REPLAY_STEPS)
if verbose:
d = self.world.diagnostics()
ed = self.episodes.diagnostics()
print(f" [Teach] norm={d['field_norm']:.2f} "
f"protected={d['protected']}/{N_REGIONS} "
f"episodes={ed['total']} (active={ed['active']} "
f"superseded={ed['superseded']})"
+ (f" cleared={newly_superseded}" if newly_superseded else ""))
if dream_result:
print(f" [Dream] steps={dream_result['steps']} "
f"norm_delta={dream_result['norm_delta']:.4f} "
f"w_density_delta={dream_result['w_density_delta']:.6f}")
return dream_result
def dream(self, cycles: int = 1, steps_per_cycle: int = REPLAY_STEPS,
verbose: bool = False) -> list[dict]:
"""
v5: Explicit dream session. Can be called from Gradio UI.
Returns list of per-cycle diagnostics.
"""
cycles = min(cycles, MAX_REPLAY_CYCLES)
results = []
for i in range(cycles):
result = self.world.replay(steps=steps_per_cycle)
results.append(result)
if verbose:
print(f" [Dream cycle {i+1}/{cycles}] "
f"norm_delta={result['norm_delta']:.4f} "
f"w_density_delta={result['w_density_delta']:.6f}")
return results
def generate(self, text: str, max_new_tokens: int = 80,
verbose: bool = False) -> str:
resp = self.generate_cognitive(text, max_new_tokens=max_new_tokens,
verbose=verbose)
return resp.text
def generate_cognitive(self, text: str, max_new_tokens: int = 80,
verbose: bool = False) -> CognitiveResponse:
t0 = time.perf_counter()
attn, S_focus, x_embed = self._process_input(text, protect=False)
query_embed = self.mean_pool_text(text)
resonance = self.world.get_cognitive_resonance(x_embed)
retrieval_conf = self.episodes.retrieval_confidence(query_embed)
route = self.router.route(resonance, retrieval_conf)
semantic_hits = self.episodes.retrieve(
query_embed, topk=SEMANTIC_TOPK, threshold=SEMANTIC_THRESHOLD)
superseded_texts = {
self.episodes.texts[i]
for i, sup in enumerate(self.episodes.superseded) if sup
}
attractor_memories = [m for m in self.world.recall()
if m not in superseded_texts]
system_content = self.router.build_system_prompt(
route, semantic_hits, attractor_memories)
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": text})
if verbose:
d = self.world.diagnostics()
print(f" [Route] {route} resonance={resonance:.4f} "
f"retrieval={retrieval_conf:.4f}")
print(f" [Field] norm={d['field_norm']:.2f} "
f"active={d['active_regions']}/{N_REGIONS} "
f"attractors={d['attractors']} "
f"protected={d['protected']}")
print(f" [v5] W_local density={d['w_local_density']:.4f} "
f"M_norm={d['m_norm']:.4f} "
f"hebbian_updates={d['hebbian_updates']}")
print(f" [Episodes] total={len(self.episodes)}")
if semantic_hits:
for t, s in semantic_hits:
print(f" hit(score={s:.3f}): {t[:70]}")
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
input_embeds = self.model.get_input_embeddings()(prompt_ids)
if route in ("CONFIDENT", "CAUTIOUS"):
prefix = self.world.get_prefix(S_focus, self.embed_dim,
ref_embeds=input_embeds.squeeze(0))
prefix_emb = prefix.unsqueeze(0).to(input_embeds.dtype)
full_embeds = torch.cat([prefix_emb, input_embeds], dim=1)
else:
full_embeds = input_embeds
attn_mask = torch.ones(full_embeds.shape[:2],
device=self.device, dtype=torch.long)
output = self.model.generate(
inputs_embeds=full_embeds,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
n_in = prompt_ids.shape[1] + (PREFIX_LEN if route in ("CONFIDENT", "CAUTIOUS") else 0)
n_out = output.shape[1] - n_in
self.token_stats["generate_input_tokens"] += n_in
self.token_stats["generate_output_tokens"] += max(0, n_out)
self.token_stats["generate_calls"] += 1
raw_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
elapsed = time.perf_counter() - t0
return CognitiveResponse(
text=raw_text,
route=route,
resonance=resonance,
retrieval_confidence=retrieval_conf,
retrieved_facts=semantic_hits,
field_diagnostics=self.world.diagnostics(),
latency_s=round(elapsed, 3),
)
def token_report(self) -> dict:
s = self.token_stats
total_in = s["teach_input_tokens"] + s["generate_input_tokens"]
total_out = s["generate_output_tokens"]
return {
**s,
"total_input_tokens": total_in,
"total_output_tokens": total_out,
"avg_input_per_gen": round(s["generate_input_tokens"] / max(1, s["generate_calls"]), 1),
"avg_output_per_gen": round(s["generate_output_tokens"] / max(1, s["generate_calls"]), 1),
}
def reset_world(self):
self.world.S.zero_()
self.world.strength.zero_()
self.world.step_count = 0
self.world.memories = [""] * self.world.n
self.world.thresholds.fill_(0.5)
self.world.protected.fill_(False)
# v5: reset new state
self.world.M.zero_()
self.world.W_local.data = torch.randn_like(self.world.W_local.data) * 0.01
self.world.W_local.data.fill_diagonal_(0.0)
self.world.hebbian_updates = 0
self.world.consolidation_count = 0
self.world.replay_count = 0
self.episodes.clear()
self.router.history.clear()
self.call_count = 0
for k in self.token_stats:
self.token_stats[k] = 0
def save_world(self, path: str = "world.pt"):
    """Serialize the full cognitive state (field, wiring, episodes) to `path`.

    The checkpoint carries a version tag so load_world() can upgrade
    older (pre-v5) saves gracefully.
    """
    w = self.world
    checkpoint = {
        "S": w.S,
        "strength": w.strength,
        "W_local": w.W_local,
        "step_count": w.step_count,
        "memories": w.memories,
        "protected": w.protected,
        "thresholds": w.thresholds,
        "episodes": self.episodes.state_dict(),
        # v5 additions
        "M": w.M,
        "hebbian_updates": w.hebbian_updates,
        "consolidation_count": w.consolidation_count,
        "replay_count": w.replay_count,
        "version": 5,
    }
    torch.save(checkpoint, path)
    print(f" [Saved world -> {path}] episodes={len(self.episodes)} version=5")
def load_world(self, path: str = "world.pt"):
    """Restore cognitive state from a checkpoint on disk.

    Backward compatible: every key except S/strength is optional, so
    pre-v5 checkpoints load with M left at its current (zero) value.
    """
    if not os.path.exists(path):
        print(" [No saved world found]")
        return
    # NOTE(review): weights_only=False unpickles arbitrary Python objects
    # (needed for memories/episodes) — only load checkpoints you trust.
    data = torch.load(path, map_location=self.device, weights_only=False)
    w = self.world
    w.S = data["S"].to(self.device)
    w.strength = data["strength"].to(self.device)
    if "W_local" in data:
        w.W_local = nn.Parameter(data["W_local"].to(self.device))
    if "step_count" in data:
        w.step_count = data["step_count"]
    if "memories" in data:
        w.memories = data["memories"]
    if "protected" in data:
        w.protected = data["protected"].to(self.device)
    if "thresholds" in data:
        w.thresholds = data["thresholds"].to(self.device)
    if "episodes" in data:
        self.episodes.load_state_dict(data["episodes"])
        self.episodes.embeds = [e.to(self.device) for e in self.episodes.embeds]
    # v5: consolidated memory may be absent in older saves
    if "M" in data:
        w.M = data["M"].to(self.device)
    else:
        print(" [v5 upgrade: M initialized to zero]")
    for counter in ("hebbian_updates", "consolidation_count", "replay_count"):
        if counter in data:
            setattr(w, counter, data[counter])
    ver = data.get("version", 4)
    print(f" [Loaded world <- {path}] episodes={len(self.episodes)} version={ver}")
# =========================
# 6. RAG BASELINE (unchanged)
# =========================
class RAGBaseline:
    """Pure semantic retrieval baseline. No Hebbian field, no soft prefix."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.device = next(model.parameters()).device
        self.episodes = PinnedEpisodicStore(max_episodes=MAX_EPISODES)

    @torch.no_grad()
    def _mean_pool(self, text: str) -> torch.Tensor:
        # Delegate to the shared mean-pooling helper used by all agents.
        return mean_pool_text(self.tokenizer, self.model, text, self.device)

    def reset_world(self):
        """Drop every pinned episode."""
        self.episodes.clear()

    def teach(self, text: str, verbose: bool = False):
        """Pin a fact verbatim alongside its pooled embedding."""
        self.episodes.pin(text, self._mean_pool(text))
        if verbose:
            print(f" [RAG teach] pinned: '{text[:60]}' total={len(self.episodes)}")

    def generate(self, text: str, max_new_tokens: int = 80,
                 verbose: bool = False) -> str:
        """Answer a query with retrieved facts injected as a system message."""
        hits = self.episodes.retrieve(
            self._mean_pool(text), topk=SEMANTIC_TOPK, threshold=SEMANTIC_THRESHOLD)
        messages = []
        if hits:
            # De-duplicate while keeping retrieval order.
            unique_facts = list(dict.fromkeys(fact for fact, _ in hits))
            messages.append({"role": "system",
                             "content": "Known facts: " + " | ".join(unique_facts)})
        messages.append({"role": "user", "content": text})
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_new_tokens,
                do_sample=True, temperature=0.7, top_p=0.9)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)
# =========================
# 7. FIELD ONLY — ABLATION (unchanged)
# =========================
class FieldOnly:
    """Ablation: Hebbian Neural Field + Soft Prefix ONLY. No episodic store."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.device = next(model.parameters()).device
        self.embed_dim = model.get_input_embeddings().weight.shape[1]
        self.world = WorldModel(N_REGIONS, STATE_DIM, self.embed_dim, self.device)
        self.call_count = 0

    def reset_world(self):
        """Zero every piece of field state and re-randomize local wiring."""
        w = self.world
        w.S.zero_()
        w.strength.zero_()
        w.step_count = 0
        w.memories = [""] * w.n
        w.thresholds.fill_(0.5)
        w.protected.fill_(False)
        w.M.zero_()
        w.W_local.data = torch.randn_like(w.W_local.data) * 0.01
        w.W_local.data.fill_diagonal_(0.0)
        self.call_count = 0

    def _embed_token(self, token_id: int) -> torch.Tensor:
        """Look up the input embedding for one token id."""
        ids = torch.tensor([token_id], device=self.device)
        return self.model.get_input_embeddings()(ids).squeeze(0)

    def _process_input(self, text: str, protect: bool = False,
                       think_steps: int = THINK_STEPS):
        """Stream tokens through the field, then let it settle via think().

        NOTE(review): if `text` tokenizes to zero tokens, `last_embed`
        stays None when handed to think() — confirm WorldModel tolerates
        that before relying on empty input.
        """
        token_ids = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )["input_ids"].to(self.device)[0]
        last_embed = None
        for tid in token_ids:
            last_embed = self._embed_token(int(tid))
            attn, x_proj = self.world.attend(last_embed)
            self.world.update(attn, x_proj, protect=protect)
        # Settle the field on the final token's embedding.
        attn, x_proj = self.world.think(last_embed, steps=think_steps, protect=protect)
        self.world.store_memory(text, attn)
        S_focus = self.world.get_focus(attn)
        self.call_count += 1
        if self.call_count % MERGE_EVERY == 0:
            self.world.merge_similar()
        return attn, S_focus

    def teach(self, text: str, verbose: bool = False):
        """Encode a fact into the field with protection (no episodic pin)."""
        self._process_input(text, protect=True, think_steps=TEACH_THINK_STEPS)
        if verbose:
            d = self.world.diagnostics()
            print(f" [FieldOnly teach] norm={d['field_norm']:.2f} "
                  f"protected={d['protected']}/{N_REGIONS}")

    def generate(self, text: str, max_new_tokens: int = 80,
                 verbose: bool = False) -> str:
        """Generate with the field's soft prefix prepended to the prompt."""
        attn, S_focus = self._process_input(text, protect=False)
        prompt = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False, add_generation_prompt=True)
        prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
        input_embeds = self.model.get_input_embeddings()(prompt_ids)
        prefix = self.world.get_prefix(
            S_focus, self.embed_dim, ref_embeds=input_embeds.squeeze(0))
        full_embeds = torch.cat(
            [prefix.unsqueeze(0).to(input_embeds.dtype), input_embeds], dim=1)
        attn_mask = torch.ones(
            full_embeds.shape[:2], device=self.device, dtype=torch.long)
        with torch.no_grad():
            output = self.model.generate(
                inputs_embeds=full_embeds,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
# =========================
# 8. CONTEXT WINDOW BASELINE (unchanged)
# =========================
class ContextBaseline:
    """Naive context-stuffing baseline."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.device = next(model.parameters()).device
        self._history: list[dict] = []

    def reset_world(self):
        """Forget the accumulated conversation."""
        self._history.clear()

    def teach(self, text: str, verbose: bool = False):
        """Append a taught fact as a user turn plus a stub acknowledgement."""
        self._history += [
            {"role": "user", "content": text},
            {"role": "assistant", "content": "(acknowledged)"},
        ]

    def generate(self, text: str, max_new_tokens: int = 80,
                 verbose: bool = False) -> str:
        """Generate with as much history as fits under ~1800 prompt tokens."""
        messages = self._history + [{"role": "user", "content": text}]
        # Drop the oldest turn-pair until the rendered prompt fits.
        while len(messages) > 2:
            rendered = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            if len(self.tokenizer(rendered)["input_ids"]) < 1800:
                break
            messages = messages[2:]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids, attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_new_tokens,
                do_sample=True, temperature=0.7, top_p=0.9)
        decoded = self.tokenizer.decode(out[0], skip_special_tokens=True)
        # decoded contains the prompt too; the last-200-chars tail is a
        # crude stand-in for the reply when updating history.
        self._history.append({"role": "user", "content": text})
        self._history.append({"role": "assistant", "content": decoded[-200:]})
        if len(self._history) > 20:
            self._history = self._history[-20:]
        return decoded
# =========================
# MAIN LOOP (v5)
# =========================
def main():
    """Interactive REPL for the v5 cognitive agent.

    Exact commands (case-insensitive): save, load, clear, stats, episodes,
    route, exit, plus "dream"/"dream:N" and the "teach: <fact>" prefix.
    Anything else is routed through generate_cognitive().

    Fixes vs. previous version:
    - "dream" used startswith("dream"), so any chat message beginning with
      that word (e.g. "dreams are odd") was swallowed as a command; now only
      exact "dream" or "dream:N" match, consistent with the other commands.
    - KeyboardInterrupt (Ctrl-C) now exits cleanly like EOFError instead of
      tracebacking out of the loop.
    - Dropped the unused `results` local from the dream branch.
    """
    agent = HybridLLM()
    print()
    print("=" * 60)
    print(" Proto-Cognitive Hybrid LLM v5")
    print(" Neural Field | Episodic Memory | Cognitive Router")
    print(" + Hebbian Learning | Consolidation | Replay")
    print("=" * 60)
    print(" Commands: save | load | clear | stats | episodes | route")
    print(" dream | dream:N | exit")
    print(" Prefix teach: to encode a fact: teach: <fact text>")
    print()
    while True:
        try:
            raw = input("YOU: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not raw:
            continue
        text = raw
        lowered = text.lower()
        if lowered == "exit":
            break
        elif lowered == "save":
            agent.save_world()
        elif lowered == "load":
            agent.load_world()
        elif lowered == "clear":
            agent.reset_world()
            print(" [World cleared]")
        elif lowered == "route":
            rs = agent.router.get_routing_stats()
            print(f" Routing stats: {rs}")
        elif lowered == "dream" or lowered.startswith("dream:"):
            parts = text.split(":")
            cycles = int(parts[1]) if len(parts) > 1 and parts[1].strip().isdigit() else 1
            agent.dream(cycles=cycles, verbose=True)
            d = agent.world.diagnostics()
            print(f" [Post-dream] W_local density={d['w_local_density']:.4f} "
                  f"M_norm={d['m_norm']:.4f} "
                  f"connections={d['w_local_connections']}")
        elif lowered == "stats":
            d = agent.world.diagnostics()
            print(f" Field norm : {d['field_norm']:.4f}")
            print(f" Active regions : {d['active_regions']}/{N_REGIONS}")
            print(f" Attractors : {d['attractors']}")
            print(f" Protected : {d['protected']}")
            print(f" Memories : {d['n_memories']}")
            print(f" Strongest : idx={d['strongest_idx']} ({d['strength_max']:.4f})")
            print(f" Total steps : {d['total_steps']}")
            print(f" --- v5 ---")
            print(f" W_local density: {d['w_local_density']:.4f} ({d['w_local_connections']} connections)")
            print(f" W_local max : {d['w_local_max']:.4f} mean={d['w_local_mean']:.4f}")
            print(f" M norm : {d['m_norm']:.4f} ({d['m_active_regions']} active regions)")
            print(f" Hebbian updates: {d['hebbian_updates']}")
            print(f" Consolidations : {d['consolidations']}")
            print(f" Replay cycles : {d['replays']}")
            ed = agent.episodes.diagnostics()
            print(f" Pinned episodes: {ed['total']} (active={ed['active']} superseded={ed['superseded']})")
            rs = agent.router.get_routing_stats()
            print(f" Router stats : {rs}")
        elif lowered == "episodes":
            print(f" Pinned episodes ({len(agent.episodes)}):")
            for i, (t, imp) in enumerate(
                    zip(agent.episodes.texts, agent.episodes.importance)):
                print(f" [{i:3d}] imp={imp:.1f} {t[:80]}")
        elif lowered.startswith("teach:"):
            fact = text[6:].strip()
            if fact:
                agent.teach(fact, verbose=True)
                print(f" [Encoded + pinned: '{fact[:60]}']")
        else:
            resp = agent.generate_cognitive(text, verbose=True)
            answer = extract_answer(resp.text)
            print(f"\n [{resp.route}] resonance={resp.resonance:.3f} "
                  f"retrieval={resp.retrieval_confidence:.3f}")
            print(f"\nAI: {answer}")
        print("-" * 60)
# Entry point: run the interactive REPL only when executed as a script,
# not when this module is imported (e.g. by the Gradio app or benchmarks).
if __name__ == "__main__":
    main()