""" agent-harness v1 — FORGE Safety Wrapper + Injection Prevention ============================================================== Middleware layer sitting between agents and the outside world. Every tool input is scanned before execution. Every tool result is sanitised before the LLM sees it. Suspicious agents get circuit-broken and christof gets alerted. Features -------- POST /api/scan/input Scan tool args before execution (injection + policy) POST /api/scan/output Sanitise tool result before feeding back to LLM POST /api/validate Check action against FORGE policy (rate limits, allowlist) GET /api/flags Recent flagged events with full payloads GET /api/circuit/{agent} Circuit breaker state for an agent POST /api/circuit/{agent}/reset Reset after operator review GET /api/stats Aggregated safety stats GET /api/rates Current rate limit counters per agent/tool GET /api/health Health check GET /mcp/sse POST /mcp MCP server Safety layers ------------- Layer 1 — Pattern scanner : regex patterns for all known injection attack formats Layer 2 — Token scanner : LLM token injection for qwen / mistral / llama / phi Layer 3 — Bash danger check : destructive shell commands in vault_exec payloads Layer 4 — Tool policy : each agent has an allowed tool list (from FORGE) Layer 5 — Rate limiter : per-agent per-tool N calls / 60s window Layer 6 — Circuit breaker : 3 flags in 60s → pause agent → alert christof Layer 7 — Output sanitiser : strip injected instructions from tool results """ import asyncio, json, os, re, sqlite3, time, uuid from collections import defaultdict, deque from contextlib import asynccontextmanager from pathlib import Path from typing import Optional import uvicorn from fastapi import FastAPI, HTTPException, Query, Request from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- 
# Where the audit log lives and which port the service binds to.
DB_PATH = Path(os.getenv("HARNESS_DB", "/tmp/harness.db"))
PORT = int(os.getenv("PORT", "7860"))
HARNESS_KEY = os.getenv("HARNESS_KEY", "")

# Base URLs of the sibling services this harness talks to.
TRACE_URL = os.getenv("TRACE_URL", "https://chris4k-agent-trace.hf.space")
LEARN_URL = os.getenv("LEARN_URL", "https://chris4k-agent-learn.hf.space")
RELAY_URL = os.getenv("RELAY_URL", "https://chris4k-agent-relay.hf.space")
FORGE_URL = os.getenv("FORGE_URL", "https://chris4k-agent-forge.hf.space")

# Circuit breaker: trip after N flags within WINDOW seconds
CB_THRESHOLD = int(os.getenv("CB_THRESHOLD", "3"))
CB_WINDOW = int(os.getenv("CB_WINDOW", "60"))

# Rate limit defaults (per agent, per tool, per 60s)
DEFAULT_RATE_LIMIT = int(os.getenv("DEFAULT_RATE_LIMIT", "30"))

# ---------------------------------------------------------------------------
# Injection & danger patterns
# ---------------------------------------------------------------------------
# Each entry: (pattern_id, regex, severity, description)
# The scanners apply these with re.IGNORECASE | re.MULTILINE, and the output
# route also joins all injection regexes into one alternation, so every regex
# here must stay valid when wrapped as "(a|b|c)".
INJECTION_PATTERNS = [
    # Classic instruction override
    ("override_instructions",
     r"ignore\s+(previous|all|above|prior|system|original)\s+instructions",
     "critical", "Instruction override attempt"),
    ("admin_mode",
     r"you\s+are\s+now\s+in\s+(admin|developer|debug|god|sudo|root|maintenance)\s+mode",
     "critical", "Fake mode switch"),
    ("forget_instructions",
     r"forget\s+(everything|all|your|prior|previous|above|the\s+system)",
     "high", "Memory wipe attempt"),
    ("disregard_prompt",
     r"(disregard|override|ignore|bypass|skip)\s+(your|the)\s+(previous|system|original|current)\s+(prompt|instructions?|rules?|constraints?|guidelines?)",
     "critical", "System prompt override"),
    ("roleplay_escape",
     r"(act|pretend|roleplay|play|behave|respond)\s+as\s+(if\s+you\s+(are|were)|though\s+you|a\s+different|an?\s+(evil|unconstrained|unfiltered|jailbroken))",
     "high", "Roleplay jailbreak"),
    # "DAN" is matched case-sensitively and as a whole word: under IGNORECASE a
    # bare "DAN" also hits ordinary words such as "dangerous" or "standard",
    # which would block harmless content at critical severity.
    ("jailbreak_dan",
     r"((?-i:\bDAN\b)|do\s+anything\s+now|developer\s+mode|jailbreak|jail\s+break|unrestricted\s+mode)",
     "critical", "Known jailbreak phrase"),
    ("new_instructions",
     r"(new|updated|revised|corrected|actual|real|true)\s+(instructions?|rules?|prompt|system\s+prompt|directives?)",
     "medium", "Instruction replacement"),
    ("end_of_prompt",
     r"(end\s+of\s+(system\s+)?prompt|ignore\s+above|---+\s*(new|actual|real)\s*(instructions?|prompt)?)",
     "high", "End-of-prompt injection marker"),

    # Token injection (model-specific special tokens)
    ("token_qwen_llama",
     r"<\|im_start\||<\|im_end\||<\|endoftext\||<\|fim_prefix\|>",
     "critical", "Qwen/Llama special token injection"),
    # The Llama-2 system markers are <<SYS>> and <</SYS>>. A bare "<>" would
    # flag every SQL/shell "not equal" operator instead.
    ("token_mistral",
     r"<\|system\||<\|user\||<\|assistant\||\[INST\]|\[/INST\]|<<SYS>>|<</SYS>>",
     "critical", "Mistral/Llama2 token injection"),
    ("token_phi",
     r"<\|endoftext\|>|<\|assistant\|>|<\|user\|>|<\|system\|>",
     "critical", "Phi token injection"),
    # Every alternative must consume text: an empty "^" branch matches the
    # start of any input and would flag all content as critical.
    ("fake_system_tag",
     r"^SYSTEM:\s|^\[SYSTEM\]|^###\s*System:|^<system>",
     "critical", "Fake SYSTEM tag injection"),
    ("fake_human_tag",
     r"^Human:\s*ignore|^User:\s*ignore|^\[Human\].*ignore",
     "high", "Fake human turn injection"),

    # Data exfiltration patterns (in web/tool output)
    ("exfil_curl",
     r"curl\s+(-[a-zA-Z\s]*\s+)?https?://(?!huggingface\.co|ki-fusion-labs\.de|anthropic\.com)",
     "high", "Exfiltration via curl to unknown host"),
    ("exfil_wget",
     r"wget\s+https?://(?!huggingface\.co|ki-fusion-labs\.de)",
     "high", "Exfiltration via wget to unknown host"),
    ("exfil_nc",
     r"\bnc\s+-[a-z]*\s+\d{1,3}\.\d{1,3}|\bnetcat\b",
     "critical", "Netcat exfiltration attempt"),

    # Prompt leakage / extraction
    ("extract_prompt",
     r"(print|output|display|show|reveal|tell\s+me|repeat|echo)\s+(your\s+)?(system\s+)?(prompt|instructions?|rules?|guidelines?|configuration)",
     "medium", "System prompt extraction attempt"),
    ("what_are_instructions",
     r"what\s+(are|were)\s+(your|the)\s+(instructions?|rules?|guidelines?|system\s+prompt)",
     "low", "Prompt inspection probe"),
]

# Bash danger patterns — for vault_exec inputs only
BASH_DANGER_PATTERNS = [
    # Covers both flag orders (-rf / -fr) and both long-option orders.
    ("rm_rf",
     r"rm\s+-[a-zA-Z]*r[a-zA-Z]*f|rm\s+-[a-zA-Z]*f[a-zA-Z]*r|rm\s+--force\s+--recursive|rm\s+--recursive\s+--force",
     "critical", "Recursive force delete"),
    ("fork_bomb",
     r":\(\)\{.*:\|:&\};:|fork\s*bomb",
     "critical", "Fork bomb"),
    ("dd_zero",
     r"dd\s+.*if=/dev/zero.*of=/dev|dd\s+.*of=/dev/(sd[a-z]|nvme)",
     "critical", "Disk wipe via dd"),
    ("chmod_777",
     r"chmod\s+(-[rR]\s+)?777\s+/",
     "high", "World-writable root path"),
    ("dangerous_pipes",
     r"mkfs\.|format\s+[cC]:|shutdown\s+-[hr]|halt\s*$|reboot\s*$",
     "critical", "Dangerous system command"),
    ("crontab_exfil",
     r"crontab\s+.*curl|crontab\s+.*wget|echo.*crontab",
     "high", "Crontab-based persistence"),
    ("history_clear",
     r"history\s+-[cC]|rm\s+~/.bash_history|unset\s+HISTFILE",
     "medium", "History clearing"),
]

# Output sanitiser: patterns to strip from tool results before LLM re-ingests.
# The sanitiser applies these with re.IGNORECASE | re.DOTALL | re.MULTILINE, so
# "." also matches newlines; line-scoped patterns therefore use [^\n].
OUTPUT_STRIP_PATTERNS = [
    r"ignore\s+(previous|all|above)\s+instructions[^.]*\.",
    r"you\s+are\s+now\s+in\s+\w+\s+mode[^.]*\.",
    r"<\|im_start\|>.*?<\|im_end\|>",
    # The trailing lazy group matches nothing, so only the token is removed.
    r"<\|system\|>.*?",
    r"\[INST\].*?\[/INST\]",
    r"<<SYS>>.*?<</SYS>>",
    # [^\n]+ keeps the removal to one line; ".+" under DOTALL would delete
    # everything from the marker to the end of the tool output.
    r"^SYSTEM:\s*[^\n]+$",
    r"^Human:\s*ignore[^\n]+$",
]

ALL_PATTERNS = INJECTION_PATTERNS + BASH_DANGER_PATTERNS

# Numeric rank of each severity label, used to compare findings.
SEVERITY_SCORE = {"low": 1, "medium": 2, "high": 3, "critical": 4}

# ---------------------------------------------------------------------------
# In-memory state (rate limiter + circuit breaker)
# Fast enough for our scale. The state is process-local: counters and breakers
# start empty again after a container restart.
# ---------------------------------------------------------------------------

# _rate_counters[agent][tool] = deque of call timestamps (epoch seconds)
_rate_counters: dict = defaultdict(lambda: defaultdict(deque))


def _new_breaker() -> dict:
    """Return the initial circuit-breaker record for one agent."""
    return {
        "tripped": False,
        "trip_count": 0,
        "tripped_at": 0.0,
        "flags": deque(maxlen=20),
        "reset_at": 0.0,
        "reset_by": "",
    }


# _circuit_breakers[agent] = record as built by _new_breaker()
_circuit_breakers: dict = defaultdict(_new_breaker)

# Policy cache (from FORGE)
_policy_cache: dict = {}
_policy_ts: float = 0.0
POLICY_TTL = 300
# After a failed FORGE fetch, wait this many seconds before trying again.
POLICY_RETRY_S = 30


# ---------------------------------------------------------------------------
# Stdlib HTTP helpers
# ---------------------------------------------------------------------------
def _get(url: str, params: Optional[dict] = None, timeout: int = 8) -> dict:
    """GET ``url`` and return the decoded JSON body.

    ``params`` entries whose value is None are dropped from the query string.
    Never raises: any failure is returned as ``{"error": "<message>"}``.
    """
    import urllib.request, urllib.parse
    if params:
        query = {k: v for k, v in params.items() if v is not None}
        url = url + "?" + urllib.parse.urlencode(query)
    try:
        with urllib.request.urlopen(url, timeout=timeout) as r:
            return json.loads(r.read())
    except Exception as e:
        return {"error": str(e)}


def _post(url: str, data: dict, timeout: int = 15) -> dict:
    """POST ``data`` as JSON to ``url`` and return the decoded JSON body.

    Network and decoding failures are returned as ``{"error": "<message>"}``.
    Serialising ``data`` happens before the guarded section, so a payload that
    is not JSON-serialisable still raises to the caller.
    """
    import urllib.request
    req = urllib.request.Request(
        url,
        data=json.dumps(data).encode(),
        headers={"Content-Type": "application/json"},
        method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as r:
            return json.loads(r.read())
    except Exception as e:
        return {"error": str(e)}


# ---------------------------------------------------------------------------
# Database (flags audit log)
# ---------------------------------------------------------------------------
def get_db():
    """Open a new SQLite connection to DB_PATH; the caller must close it."""
    conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")
    return conn


def init_db():
    """Create the ``flags`` and ``rate_violations`` tables if they are missing."""
    conn = get_db()
    try:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS flags (
                id TEXT PRIMARY KEY,
                agent TEXT NOT NULL DEFAULT 'unknown',
                tool TEXT NOT NULL DEFAULT '',
                scan_type TEXT NOT NULL DEFAULT 'input',
                pattern_id TEXT NOT NULL,
                severity TEXT NOT NULL DEFAULT 'medium',
                description TEXT NOT NULL DEFAULT '',
                content_snip TEXT NOT NULL DEFAULT '',
                action_taken TEXT NOT NULL DEFAULT 'blocked',
                task_id TEXT NOT NULL DEFAULT '',
                rlhf_sent INTEGER NOT NULL DEFAULT 0,
                ts REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_fl_ts ON flags(ts DESC);
            CREATE INDEX IF NOT EXISTS idx_fl_agent ON flags(agent);
            CREATE INDEX IF NOT EXISTS idx_fl_severity ON flags(severity);
            CREATE INDEX IF NOT EXISTS idx_fl_pattern ON flags(pattern_id);
            CREATE TABLE IF NOT EXISTS rate_violations (
                id TEXT PRIMARY KEY,
                agent TEXT NOT NULL,
                tool TEXT NOT NULL,
                limit_val INTEGER NOT NULL,
                count INTEGER NOT NULL,
                ts REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_rv_ts ON rate_violations(ts DESC);
        """)
        conn.commit()
    finally:
        # Close even when the schema script fails, so the handle is not leaked.
        conn.close()


def now() -> float:
    """Return the current time as epoch seconds."""
    return time.time()


def now_iso() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    import datetime
    return datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")


def jresp(data, status=200):
    """Wrap ``data`` in a JSONResponse with the given HTTP status code."""
    return JSONResponse(content=data, status_code=status)


# ---------------------------------------------------------------------------
# TRACE + RLHF helpers
# ---------------------------------------------------------------------------
def emit_trace(agent: str, event_type: str, payload: dict,
               status: str = "ok", task_id: str = ""):
    """Send one event to the TRACE service; best effort, never raises."""
    try:
        _post(f"{TRACE_URL}/api/trace", {
            "agent": agent,
            "event_type": event_type,
            "status": status,
            "task_id": task_id,
            "payload": payload,
        }, timeout=2)
    except Exception:
        # Tracing must never break a scan (e.g. a non-serialisable payload).
        pass


def send_rlhf(agent: str, score_delta: float, task_id: str, reason: str):
    """Push negative reward to agent-learn when injection is detected.

    Best effort, never raises. ``task_id`` and ``reason`` are accepted but are
    not part of the request that is sent.
    """
    try:
        _post(f"{LEARN_URL}/api/q/update", {
            "agent": agent,
            "state": {"context": "harness_flag"},
            "action": "tool_call",
            "reward": score_delta,
        }, timeout=3)
    except Exception:
        pass


def alert_relay(agent: str, message: str, subject: str):
    """Send alert to christof via RELAY.

    Best effort, never raises. ``agent`` is accepted but is not part of the
    request; callers put the agent name into ``subject`` / ``message``.
    """
    try:
        _post(f"{RELAY_URL}/api/notify", {
            "channel": "telegram",
            "from": "harness",
            "subject": subject,
            "message": message,
        }, timeout=5)
    except Exception:
        pass


# ---------------------------------------------------------------------------
# Policy from FORGE
# ---------------------------------------------------------------------------
DEFAULT_POLICY = {
    # tool → list of agents allowed to call it (empty list = anyone)
    "tool_allowlist": {
        "vault_exec": ["coder", "planner"],
        "memory_forget":["christof"],
        "relay_send": [],
    },
    # agent → list of tools they're NOT allowed to use
    "tool_blocklist_by_agent": {
        "researcher": ["vault_exec"],
        "monitor": ["vault_exec", "vault_write"],
    },
    # per-agent per-tool rate limits (calls per 60s)
    "rate_limits": {
        "default": 30,
        "vault_exec": 10,
        "memory_store": 60,
        "web_search": 20,
        "fetch_url": 15,
        "anthropic_call": 5,
        "slot_reserve": 3,
    },
    # RLHF penalties by severity
    "rlhf_penalty": {
        "low": -0.5,
        "medium": -1.0,
        "high": -2.0,
        "critical": -4.0,
    },
    # Whitelist: these content patterns are safe even if they look dangerous
    "safe_content_allowlist": [
        "INJECTION_PATTERNS",  # harness itself references these
        "re.search(",          # code writing about regex
    ],
}


def get_policy() -> dict:
    """Return the active policy: DEFAULT_POLICY overlaid with FORGE's payload.

    A successful fetch is cached for POLICY_TTL seconds. A failed fetch is
    cached as well, for POLICY_RETRY_S seconds, so an unreachable FORGE does
    not add a blocking HTTP timeout to every scan. While FORGE is down the last
    policy it delivered stays in force; DEFAULT_POLICY is used only when none
    was ever fetched. The overlay is shallow: a top-level key in the payload
    replaces the default value for that key as a whole.
    """
    global _policy_cache, _policy_ts
    if _policy_cache and (now() - _policy_ts) < POLICY_TTL:
        return _policy_cache

    r = _get(f"{FORGE_URL}/api/capabilities/harness_policy", timeout=5)
    payload = r.get("payload") if isinstance(r, dict) else None
    if isinstance(payload, dict):
        _policy_cache = {**DEFAULT_POLICY, **payload}
        _policy_ts = now()
        return _policy_cache

    if not _policy_cache:
        _policy_cache = dict(DEFAULT_POLICY)
    # Back-date the timestamp so the cache expires after POLICY_RETRY_S.
    _policy_ts = now() - POLICY_TTL + POLICY_RETRY_S
    return _policy_cache


# ---------------------------------------------------------------------------
# Core scanning logic
# ---------------------------------------------------------------------------
def _is_safe_allowlisted(text: str) -> bool:
    """Return True when ``text`` contains any ``safe_content_allowlist`` entry.

    NOTE(review): this is a plain substring test and scan_content() skips all
    scanning on a hit, so any payload that merely contains an allowlisted
    string (e.g. "re.search(") bypasses the scanner. Confirm this is intended.
    """
    allowlist = get_policy().get("safe_content_allowlist", [])
    return any(safe_str in text for safe_str in allowlist)


def scan_patterns(text: str, patterns: list, scan_type: str = "input") -> list[dict]:
    """Run all patterns against text. Returns list of findings.

    ``patterns`` holds ``(pattern_id, regex, severity, description)`` tuples.
    Only the first match of each pattern is reported, with its offsets into
    ``text`` and a snippet padded by 20 characters on either side. Regexes that
    fail to compile are skipped. ``scan_type`` is accepted but not used here.
    """
    if not text:
        return []
    findings = []
    for pattern_id, regex, severity, description in patterns:
        try:
            m = re.search(regex, text, re.IGNORECASE | re.MULTILINE)
        except re.error:
            continue
        if m:
            findings.append({
                "pattern_id": pattern_id,
                "severity": severity,
                "description": description,
                "match_start": m.start(),
                "match_end": m.end(),
                "snippet": text[max(0, m.start() - 20):m.end() + 20],
            })
    return findings


def max_severity(findings: list) -> str:
    """Return the highest severity label among ``findings``, or "none"."""
    if not findings:
        return "none"
    return max((f["severity"] for f in findings),
               key=lambda s: SEVERITY_SCORE.get(s, 0))


def scan_content(text: str, tool: str = "", scan_type: str = "input") -> dict:
    """Full scan: injection + bash danger (if tool=vault_exec).

    Returns a dict with ``clean``, ``findings``, ``severity`` and ``action``.
    ``action`` is "allow" when nothing matched, "block" when the worst finding
    is high or critical, and "flag" otherwise.
    """
    if not text or not text.strip():
        return {"clean": True, "findings": [], "severity": "none", "action": "allow"}

    # Skip if content is in the safe allowlist
    if _is_safe_allowlisted(text):
        return {"clean": True, "findings": [], "severity": "none",
                "action": "allow", "note": "allowlisted"}

    # Injection scan — always
    findings = scan_patterns(text, INJECTION_PATTERNS, scan_type)

    # Bash danger — only for vault_exec inputs
    if tool in ("vault_exec", "bash") and scan_type == "input":
        findings += scan_patterns(text, BASH_DANGER_PATTERNS, scan_type)

    if not findings:
        return {"clean": True, "findings": [], "severity": "none", "action": "allow"}

    sev = max_severity(findings)
    # Determine action based on severity
    action = "block" if sev in ("high", "critical") else "flag"
    return {
        "clean": False,
        "findings": findings,
        "severity": sev,
        "action": action,
        "finding_count": len(findings),
    }


def sanitise_output(text: str) -> tuple[str, int]:
    """Strip injection patterns from tool output before LLM re-ingestion.

    Returns (sanitised_text, count_replacements). Empty or missing input yields
    ``("", 0)`` instead of raising.
    """
    if not text:
        return "", 0
    count = 0
    for pattern in OUTPUT_STRIP_PATTERNS:
        try:
            text, n = re.subn(
                pattern, "[HARNESS: content removed]", text,
                flags=re.IGNORECASE | re.DOTALL | re.MULTILINE)
        except re.error:
            continue
        count += n
    return text, count


# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
def check_rate_limit(agent: str, tool: str) -> dict:
    """Count one call of ``tool`` by ``agent`` against its 60-second budget.

    Allowed calls are recorded and return ``{"allowed": True, ...}``. A call
    over the limit is not recorded in the window; it is written to the
    ``rate_violations`` table and returns ``allowed: False`` with the number of
    seconds until the oldest counted call leaves the window.
    """
    limits = get_policy().get("rate_limits", {})
    limit = limits.get(tool, limits.get("default", DEFAULT_RATE_LIMIT))
    window_start = now() - 60.0
    q = _rate_counters[agent][tool]

    # Purge old entries
    while q and q[0] < window_start:
        q.popleft()

    count = len(q)
    if count >= limit:
        # Log violation
        conn = get_db()
        try:
            conn.execute(
                "INSERT INTO rate_violations (id,agent,tool,limit_val,count,ts) VALUES (?,?,?,?,?,?)",
                (uuid.uuid4().hex[:10], agent, tool, limit, count + 1, now()))
            conn.commit()
        finally:
            conn.close()
        return {
            "allowed": False,
            "reason": f"Rate limit: {count}/{limit} calls in 60s for {agent}:{tool}",
            "limit": limit,
            "count": count,
            "resets_in": round(60.0 - (now() - q[0]) if q else 0, 1),
        }

    q.append(now())
    return {"allowed": True, "limit": limit, "count": count + 1}


# ---------------------------------------------------------------------------
# Circuit breaker
# ---------------------------------------------------------------------------
def record_flag_for_circuit(agent: str, severity: str, pattern_id: str):
    """Track flags per agent. Trip breaker if threshold exceeded.

    Only high and critical flags inside the last CB_WINDOW seconds count toward
    CB_THRESHOLD. Tripping emits a trace event and sends a relay alert.
    """
    cb = _circuit_breakers[agent]

    # Already tripped — don't pile on
    if cb["tripped"]:
        return

    window_start = now() - CB_WINDOW
    cb["flags"].append({"ts": now(), "severity": severity, "pattern_id": pattern_id})

    # Count recent flags
    recent_high = [
        f for f in cb["flags"]
        if f["ts"] > window_start
        and SEVERITY_SCORE.get(f["severity"], 0) >= SEVERITY_SCORE["high"]
    ]
    if len(recent_high) < CB_THRESHOLD:
        return

    cb["tripped"] = True
    cb["trip_count"] += 1
    cb["tripped_at"] = now()
    emit_trace("harness", "error", {
        "event": "circuit_breaker_tripped",
        "agent": agent,
        "flag_count": len(recent_high),
        "window_s": CB_WINDOW,
        "threshold": CB_THRESHOLD,
    }, status="error")
    alert_relay(
        agent=agent,
        subject=f"[HARNESS] Circuit breaker tripped: {agent}",
        message=(
            f"Agent *{agent}* tripped the safety circuit breaker.\n"
            f"{len(recent_high)} high/critical flags in {CB_WINDOW}s.\n"
            f"Agent is now *PAUSED*. Review flags and reset:\n"
            f"`POST /api/circuit/{agent}/reset`"
        ))


def get_circuit_state(agent: str) -> dict:
    """Return a JSON-friendly snapshot of ``agent``'s circuit breaker.

    This is a pure read: an agent without a breaker record is reported with the
    initial values and no record is created for it, so querying arbitrary names
    neither grows the in-memory table nor adds agents to the listings.
    """
    cb = _circuit_breakers.get(agent)
    if cb is None:
        cb = _new_breaker()
    window_start = now() - CB_WINDOW
    recent_flags = [f for f in cb["flags"] if f["ts"] > window_start]
    return {
        "agent": agent,
        "tripped": cb["tripped"],
        "trip_count": cb["trip_count"],
        "tripped_at": cb["tripped_at"] or None,
        "recent_flags":len(recent_flags),
        "threshold": CB_THRESHOLD,
        "window_s": CB_WINDOW,
        "reset_at": cb["reset_at"] or None,
        "reset_by": cb["reset_by"],
    }


def reset_circuit(agent: str, reset_by: str = "operator") -> dict:
    """Clear ``agent``'s breaker and flag history, and emit a trace event."""
    cb = _circuit_breakers[agent]
    was_tripped = cb["tripped"]
    cb["tripped"] = False
    cb["reset_at"] = now()
    cb["reset_by"] = reset_by
    cb["flags"].clear()
    emit_trace("harness", "custom", {
        "event": "circuit_breaker_reset",
        "agent": agent,
        "reset_by": reset_by,
        "was_tripped": was_tripped,
    })
    return {"ok": True, "agent": agent, "was_tripped": was_tripped, "reset_by": reset_by}


# ---------------------------------------------------------------------------
# Tool policy check
# ---------------------------------------------------------------------------
def check_tool_policy(agent: str, tool: str) -> dict:
    """Decide whether ``agent`` may call ``tool`` under the active policy.

    The per-agent blocklist is checked first. A tool with a non-empty allowlist
    is then restricted to the agents on it; a missing or empty allowlist means
    anyone may use the tool. Returns ``{"allowed": bool}`` plus a ``reason``
    when the call is denied.
    """
    policy = get_policy()

    # Blocklist by agent
    blocklist = policy.get("tool_blocklist_by_agent", {}).get(agent, [])
    if tool in blocklist:
        return {"allowed": False,
                "reason": f"Tool '{tool}' is blocked for agent '{agent}'"}

    # Allowlist (if defined, only listed agents may use it)
    allowlist = policy.get("tool_allowlist", {}).get(tool)
    if allowlist and agent not in allowlist:
        return {"allowed": False,
                "reason": f"Tool '{tool}' is restricted to: {allowlist}"}

    return {"allowed": True}


# ---------------------------------------------------------------------------
# Log flag to DB + emit trace + RLHF
# ---------------------------------------------------------------------------
def log_flag(agent: str, tool: str, scan_type: str, finding: dict,
             action_taken: str, content_snip: str, task_id: str = ""):
    """Persist one finding and fan it out to TRACE, RLHF and the breaker.

    Writes a row to the ``flags`` table (snippet capped at 200 characters),
    emits a trace event, sends the severity's RLHF penalty and records the flag
    for circuit-breaker accounting. Returns the generated flag id.
    """
    flag_id = uuid.uuid4().hex[:12]
    severity = finding["severity"]

    conn = get_db()
    try:
        conn.execute(
            """INSERT INTO flags (id,agent,tool,scan_type,pattern_id,severity,description, content_snip,action_taken,task_id,ts) VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
            (flag_id, agent, tool, scan_type, finding["pattern_id"], severity,
             finding["description"], content_snip[:200], action_taken, task_id, now()))
        conn.commit()
    finally:
        # Release the handle even when the insert fails.
        conn.close()

    # Emit to TRACE
    emit_trace(
        "harness",
        "error" if severity in ("high","critical") else "custom",
        {
            "flag_id": flag_id,
            "agent": agent,
            "tool": tool,
            "scan_type": scan_type,
            "pattern_id": finding["pattern_id"],
            "severity": severity,
            "description": finding["description"],
            "action": action_taken,
            "injection_detected": True,
            "snippet": content_snip[:100],
        },
        status="error" if action_taken == "block" else "ok",
        task_id=task_id)

    # RLHF penalty (fire and forget)
    penalty = get_policy().get("rlhf_penalty", {}).get(severity, -1.0)
    send_rlhf(agent, penalty, task_id, finding["description"])

    # Circuit breaker tracking
    record_flag_for_circuit(agent, severity, finding["pattern_id"])
    return flag_id


def get_stats() -> dict:
    """Aggregate flag and rate-violation counts from the audit database.

    The "24h" figures cover the last 86400 seconds. ``tripped_circuits`` comes
    from the in-memory breaker table rather than the database.
    """
    since24 = now() - 86400
    conn = get_db()
    try:
        total = conn.execute("SELECT COUNT(*) FROM flags").fetchone()[0]
        blocked = conn.execute(
            "SELECT COUNT(*) FROM flags WHERE action_taken='block'").fetchone()[0]
        flagged = conn.execute(
            "SELECT COUNT(*) FROM flags WHERE action_taken='flag'").fetchone()[0]
        by_sev = conn.execute(
            "SELECT severity, COUNT(*) as c FROM flags GROUP BY severity").fetchall()
        by_pat = conn.execute(
            "SELECT pattern_id, COUNT(*) as c FROM flags GROUP BY pattern_id ORDER BY c DESC LIMIT 10"
        ).fetchall()
        today = conn.execute(
            "SELECT COUNT(*) FROM flags WHERE ts>?", (since24,)).fetchone()[0]
        rate_vio = conn.execute(
            "SELECT COUNT(*) FROM rate_violations WHERE ts>?", (since24,)).fetchone()[0]
    finally:
        # Release the handle even when one of the queries fails.
        conn.close()

    tripped = [a for a, cb in _circuit_breakers.items() if cb["tripped"]]
    return {
        "total_flags": total,
        "blocked": blocked,
        "flagged": flagged,
        "flags_24h": today,
        "rate_violations_24h":rate_vio,
        "tripped_circuits": tripped,
        "by_severity": {r["severity"]: r["c"] for r in by_sev},
        "top_patterns": [{"pattern": r["pattern_id"], "count": r["c"]} for r in by_pat],
        "pattern_count": len(ALL_PATTERNS),
    }


# ---------------------------------------------------------------------------
# FastAPI lifecycle
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the database schema before the app starts serving."""
    init_db()
    yield


app = FastAPI(title="agent-harness", lifespan=lifespan)


# ---------------------------------------------------------------------------
# API Routes
# ---------------------------------------------------------------------------
@app.get("/api/health")
async def health():
    """Liveness probe: version, pattern count and currently tripped agents."""
    tripped = [a for a, cb in _circuit_breakers.items() if cb["tripped"]]
    return jresp({
        "ok": True,
        "version": "1.0.0",
        "pattern_count": len(ALL_PATTERNS),
        "tripped_agents": tripped,
        "cb_threshold": CB_THRESHOLD,
        "cb_window_s": CB_WINDOW,
    })
def _as_text(value) -> str:
    """Return ``value`` as the text the scanners should inspect.

    Strings pass through and None becomes "". Any other JSON value (tool
    arguments are often objects) is serialised, so its string fields are still
    scanned instead of crashing the regex engine.
    """
    if isinstance(value, str):
        return value
    if value is None:
        return ""
    try:
        return json.dumps(value, ensure_ascii=False, default=str)
    except (TypeError, ValueError):
        return str(value)


async def _json_object(request: Request) -> dict:
    """Parse the request body as a JSON object, or answer HTTP 400.

    The scan routes must not treat an unreadable body as "nothing to scan", so
    malformed JSON and non-object payloads are rejected rather than defaulted.
    """
    try:
        body = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400,
                            detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400,
                            detail="Request body must be a JSON object")
    return body


def _is_tripped(agent: str) -> bool:
    """Report the breaker state without creating a record for unknown agents."""
    cb = _circuit_breakers.get(agent)
    return bool(cb and cb["tripped"])


@app.post("/api/scan/input")
async def api_scan_input(request: Request):
    """
    Scan tool args BEFORE execution. Called by PULSE before every tool call.
    Returns: {"clean": bool, "action": "allow|flag|block", "findings": [...]}

    Checks run cheapest first: circuit breaker, tool policy, rate limit, then
    the content scan. Every finding is logged via log_flag().
    """
    body = await _json_object(request)
    agent = str(body.get("agent", "unknown"))
    tool = str(body.get("tool", ""))
    content = _as_text(body.get("content", ""))
    task_id = str(body.get("task_id", ""))

    # 1. Circuit breaker check (fastest — in memory)
    if _is_tripped(agent):
        return jresp({
            "clean": False,
            "action": "block",
            "reason": "circuit_breaker_tripped",
            "message": f"Agent {agent} is paused. Operator must reset circuit breaker.",
            "findings":[],
        })

    # 2. Tool policy check
    policy_check = check_tool_policy(agent, tool)
    if not policy_check["allowed"]:
        emit_trace("harness", "custom", {
            "event": "tool_policy_violation",
            "agent": agent,
            "tool": tool,
            "reason": policy_check["reason"],
        }, task_id=task_id)
        return jresp({
            "clean": False,
            "action": "block",
            "reason": "tool_policy",
            "message": policy_check["reason"],
            "findings":[],
        })

    # 3. Rate limit check
    rate = check_rate_limit(agent, tool)
    if not rate["allowed"]:
        emit_trace("harness", "custom", {
            "event": "rate_limit_exceeded",
            "agent": agent,
            "tool": tool,
            "count": rate["count"],
            "limit": rate["limit"],
        }, task_id=task_id)
        return jresp({
            "clean": False,
            "action": "block",
            "reason": "rate_limit",
            "message": rate["reason"],
            "resets_in":rate.get("resets_in", 0),
            "findings": [],
        })

    # 4. Content scan (injection + bash danger)
    result = scan_content(content, tool=tool, scan_type="input")
    if not result["clean"]:
        # Log each finding
        flag_ids = []
        for finding in result["findings"]:
            snippet = content[max(0, finding["match_start"] - 30):finding["match_end"] + 30]
            flag_ids.append(
                log_flag(agent, tool, "input", finding, result["action"], snippet, task_id))
        return jresp({
            "clean": False,
            "action": result["action"],
            "severity": result["severity"],
            "findings": result["findings"],
            "flag_ids": flag_ids,
            "message": f"[HARNESS] {result['action'].upper()}: {result['findings'][0]['description']}",
        })

    return jresp({
        "clean": True,
        "action": "allow",
        "severity": "none",
        "findings": [],
        "rate": {"count": rate["count"], "limit": rate["limit"]},
    })


@app.post("/api/scan/output")
async def api_scan_output(request: Request):
    """
    Sanitise tool result BEFORE feeding back to LLM.
    Strips injected instructions. Always returns sanitised content.
    """
    body = await _json_object(request)
    agent = str(body.get("agent", "unknown"))
    tool = str(body.get("tool", ""))
    content = _as_text(body.get("content", ""))
    task_id = str(body.get("task_id", ""))

    # Sanitise
    sanitised, n_removed = sanitise_output(content)

    # Also scan for remaining dangerous patterns that survived sanitisation
    result = scan_content(sanitised, tool=tool, scan_type="output")

    if n_removed > 0:
        emit_trace("harness", "custom", {
            "event": "output_sanitised",
            "agent": agent,
            "tool": tool,
            "removals": n_removed,
        }, task_id=task_id)

    if not result["clean"]:
        for finding in result["findings"]:
            # The offsets refer to the sanitised text that was scanned, not to
            # the original content, whose length differs after any removal.
            snippet = sanitised[max(0, finding["match_start"] - 30):finding["match_end"] + 30]
            log_flag(agent, tool, "output", finding, "sanitised", snippet, task_id)
        # Replace remaining dangerous content
        sanitised = re.sub(
            r"(" + "|".join(p[1] for p in INJECTION_PATTERNS) + ")",
            "[HARNESS: removed]",
            sanitised,
            flags=re.IGNORECASE | re.MULTILINE)

    return jresp({
        "content": sanitised,
        "clean": result["clean"] and n_removed == 0,
        "removals": n_removed,
        "findings": result["findings"],
        "severity": result["severity"],
        "original_len":len(content),
        "sanitised_len":len(sanitised),
    })


@app.post("/api/validate")
async def api_validate(request: Request):
    """
    Combined validation: circuit breaker + tool policy + rate limit.
    Lightweight — no content scanning. For quick pre-flight checks.
    """
    body = await _json_object(request)
    agent = str(body.get("agent", "unknown"))
    tool = str(body.get("tool", ""))

    # Circuit breaker
    if _is_tripped(agent):
        return jresp({"allowed": False, "reason": "circuit_breaker_tripped",
                      "circuit_state": get_circuit_state(agent)})

    # Tool policy
    pol = check_tool_policy(agent, tool)
    if not pol["allowed"]:
        return jresp({"allowed": False, "reason": "tool_policy", "detail": pol["reason"]})

    # Rate limit (dry-run: neither increments nor creates a counter)
    limits = get_policy().get("rate_limits", {})
    limit = limits.get(tool, limits.get("default", DEFAULT_RATE_LIMIT))
    window_start = now() - 60.0
    timestamps = _rate_counters.get(agent, {}).get(tool, ())
    count = sum(1 for t in timestamps if t >= window_start)

    return jresp({
        "allowed": count < limit,
        "reason": "ok" if count < limit else "rate_limit",
        "rate_limit": limit,
        "rate_current": count,
        "circuit_state": get_circuit_state(agent),
    })


@app.get("/api/flags")
async def api_flags(
    agent: Optional[str] = None,
    severity: Optional[str] = None,
    scan_type: Optional[str] = None,
    limit: int = 100
):
    """Return recent flags, newest first, optionally filtered.

    ``limit`` is clamped to 0..500. SQLite treats a negative LIMIT as "no
    limit", so a negative value must not reach the query.
    """
    where = ["1=1"]
    args = []
    for column, value in (("agent", agent), ("severity", severity), ("scan_type", scan_type)):
        if value:
            where.append(f"{column}=?")
            args.append(value)
    args.append(max(0, min(limit, 500)))

    conn = get_db()
    try:
        rows = conn.execute(
            f"SELECT * FROM flags WHERE {' AND '.join(where)} ORDER BY ts DESC LIMIT ?",
            args).fetchall()
    finally:
        conn.close()
    return jresp({"flags": [dict(r) for r in rows], "count": len(rows)})


@app.get("/api/circuit/{agent}")
async def api_circuit_get(agent: str):
    """Return the circuit-breaker state for one agent."""
    return jresp(get_circuit_state(agent))


@app.get("/api/circuit")
async def api_circuit_all():
    """Return the state of every known breaker and the names of tripped ones."""
    all_agents = list(_circuit_breakers.keys())
    return jresp({
        "agents": [get_circuit_state(a) for a in all_agents],
        "tripped": [a for a in all_agents if _circuit_breakers[a]["tripped"]],
    })


@app.post("/api/circuit/{agent}/reset")
async def api_circuit_reset(agent: str, request: Request):
    """Reset an agent's breaker after operator review.

    The JSON body is optional; ``reset_by`` defaults to "operator" when the
    body is absent, unreadable or not an object.

    NOTE(review): HARNESS_KEY is not referenced by any route visible in this
    file, so this endpoint appears to be unauthenticated — confirm.
    """
    body = {}
    if request.headers.get("content-type","").startswith("application/json"):
        try:
            body = await request.json()
        except Exception:
            # An empty or malformed body only affects the audit label.
            body = {}
    reset_by = body.get("reset_by", "operator") if isinstance(body, dict) else "operator"
    return jresp(reset_circuit(agent, reset_by))


@app.get("/api/rates")
async def api_rates():
    """Current rate counter state — useful for debugging."""
    window_start = now() - 60.0
    result = {}
    for agent, tools in _rate_counters.items():
        per_tool = {}
        for tool, q in tools.items():
            recent = [t for t in q if t > window_start]
            per_tool[tool] = {
                "count_60s": len(recent),
                "oldest": round(now() - recent[0], 1) if recent else None,
            }
        result[agent] = per_tool
    return jresp({"rates": result, "window_s": 60})


@app.get("/api/patterns")
async def api_patterns():
    """List all active detection patterns."""
    def describe(patterns):
        return [{"id": p[0], "severity": p[2], "description": p[3]} for p in patterns]

    return jresp({
        "injection_patterns": describe(INJECTION_PATTERNS),
        "bash_danger_patterns": describe(BASH_DANGER_PATTERNS),
        "output_strip_patterns": len(OUTPUT_STRIP_PATTERNS),
        "total": len(ALL_PATTERNS),
    })


@app.post("/api/test/scan")
async def api_test_scan(request: Request):
    """Test endpoint — scan arbitrary text without logging or RLHF."""
    body = await _json_object(request)
    content = _as_text(body.get("content", ""))
    tool = str(body.get("tool", ""))
    result = scan_content(content, tool=tool, scan_type="input")
    # Sanitise output too
    sanitised, removals = sanitise_output(content)
    return jresp({
        **result,
        "sanitised": sanitised,
        "sanitised_removals": removals,
        "note": "test mode — no logging or RLHF"
    })


@app.get("/api/stats")
async def api_stats():
    """Return aggregated safety statistics."""
    return jresp(get_stats())
# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------
MCP_TOOLS = [
    {
        "name": "scan_input",
        "description": "Scan tool args for injection before execution.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "agent": {"type": "string"},
                "tool": {"type": "string"},
                "content": {"type": "string"},
                "task_id": {"type": "string"},
            },
            "required": ["agent", "content"],
        },
    },
    {
        "name": "scan_output",
        "description": "Sanitise tool result before feeding back to LLM.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "agent": {"type": "string"},
                "tool": {"type": "string"},
                "content": {"type": "string"},
            },
            "required": ["agent", "content"],
        },
    },
    {
        "name": "validate",
        "description": "Pre-flight check: circuit breaker + policy + rate limit.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "agent": {"type": "string"},
                "tool": {"type": "string"},
            },
            "required": ["agent", "tool"],
        },
    },
    {
        "name": "reset_circuit",
        "description": "Reset circuit breaker for a paused agent.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "agent": {"type": "string"},
                "reset_by": {"type": "string"},
            },
            "required": ["agent"],
        },
    },
    {
        "name": "list_flags",
        "description": "List recent safety flags.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "agent": {"type": "string"},
                "severity": {"type": "string"},
                "limit": {"type": "integer"},
            },
        },
    },
]


def handle_mcp(method: str, params: dict, req_id) -> Optional[dict]:
    """Dispatch one JSON-RPC request and return its response object.

    Returns None for ``notifications/initialized``, which has no reply. A null
    or non-object ``params`` / ``arguments`` is treated as empty. The MCP
    ``scan_input`` and ``scan_output`` tools call the scanners directly; unlike
    the REST scan routes they do not log flags or feed the circuit breaker.
    """
    base = {"jsonrpc": "2.0", "id": req_id}
    if not isinstance(params, dict):
        params = {}

    def text_result(obj) -> dict:
        # MCP tool results carry their payload as one JSON-encoded text part.
        return {**base, "result": {"content": [{"type":"text","text":json.dumps(obj)}]}}

    if method == "initialize":
        return {**base, "result": {
            "protocolVersion": "2024-11-05",
            "serverInfo": {"name": "agent-harness", "version": "1.0.0"},
            "capabilities": {"tools": {}},
        }}

    if method == "tools/list":
        return {**base, "result": {"tools": MCP_TOOLS}}

    if method == "tools/call":
        name = params.get("name", "")
        args = params.get("arguments", {})
        if not isinstance(args, dict):
            args = {}

        if name == "scan_input":
            return text_result(scan_content(
                args.get("content",""), tool=args.get("tool",""), scan_type="input"))

        if name == "scan_output":
            s, n = sanitise_output(args.get("content",""))
            return text_result({"sanitised":s,"removals":n})

        if name == "validate":
            agent = args.get("agent","")
            # .get() keeps this a pure read: no breaker record is created.
            cb = _circuit_breakers.get(agent)
            tripped = bool(cb and cb["tripped"])
            pol = check_tool_policy(agent, args.get("tool",""))
            return text_result({
                "allowed": not tripped and pol["allowed"],
                "tripped": tripped,
                "policy": pol,
            })

        if name == "reset_circuit":
            return text_result(
                reset_circuit(args.get("agent",""), args.get("reset_by","mcp")))

        if name == "list_flags":
            # SQLite treats a negative LIMIT as "no limit": validate and clamp.
            try:
                limit = int(args.get("limit", 20))
            except (TypeError, ValueError):
                limit = 20
            limit = max(0, min(limit, 500))
            conn = get_db()
            try:
                rows = conn.execute(
                    "SELECT id,agent,tool,pattern_id,severity,description,action_taken,ts FROM flags ORDER BY ts DESC LIMIT ?",
                    (limit,)).fetchall()
            finally:
                conn.close()
            return text_result([dict(r) for r in rows])

        return {**base, "error": {"code": -32601, "message": f"Unknown tool: {name}"}}

    if method == "notifications/initialized":
        return None

    return {**base, "error": {"code": -32601, "message": f"Unknown method: {method}"}}


@app.get("/mcp/sse")
async def mcp_sse(request: Request):
    """Server-sent event stream: announce the tools, then ping every 15s."""
    async def gen():
        yield f"data: {json.dumps({'jsonrpc':'2.0','method':'connected','params':{}})}\n\n"
        yield f"data: {json.dumps({'jsonrpc':'2.0','method':'notifications/tools','params':{'tools':MCP_TOOLS}})}\n\n"
        while True:
            if await request.is_disconnected():
                break
            yield ": ping\n\n"
            await asyncio.sleep(15)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={"Cache-Control":"no-cache","Connection":"keep-alive","X-Accel-Buffering":"no"})


@app.post("/mcp")
async def mcp_rpc(request: Request):
    """JSON-RPC endpoint accepting a single request or a batch.

    Unparseable bodies get a -32700 error. Payloads, or batch items, that are
    not JSON objects get a -32600 "Invalid Request" error instead of crashing.
    """
    def invalid_request() -> dict:
        return {"jsonrpc":"2.0","id":None,
                "error":{"code":-32600,"message":"Invalid Request"}}

    def dispatch(item):
        if not isinstance(item, dict):
            return invalid_request()
        return handle_mcp(item.get("method",""), item.get("params",{}), item.get("id"))

    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"jsonrpc":"2.0","id":None,
                             "error":{"code":-32700,"message":"Parse error"}})

    if isinstance(body, list):
        # Notifications yield no response and are left out of the batch reply.
        return JSONResponse([r for r in (dispatch(x) for x in body) if r])
    if not isinstance(body, dict):
        return JSONResponse(invalid_request())

    r = dispatch(body)
    return JSONResponse(r or {"jsonrpc":"2.0","id":body.get("id"),"result":{}})


# ---------------------------------------------------------------------------
# SPA
# ---------------------------------------------------------------------------
# NOTE(review): this string holds only text labels, with no HTML markup visible
# here — confirm the dashboard template was not lost in an earlier step.
SPA = r""" 🛡 HARNESS — FORGE Safety Layer
FORGE Safety Layer
Flags
Blocked
24h
Circuits
🚩 Flags
⚡ Circuit
🔎 Patterns
⚙︎ Test
TimeAgentSeverityPattern / DescriptionAction
Loading...
Circuit Breaker Status
No agent data yet
Rate Limit Counters (60s window)
No rate data
Injection Patterns (0)
Bash Danger Patterns (0)
🔎 Scan Input
🏥 Sanitise Output
⚡ Stats
"""


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the dashboard page."""
    return HTMLResponse(content=SPA, media_type="text/html; charset=utf-8")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT)