"""Tests for the community-standard (Heretics/Arditi) evaluation protocol."""

from __future__ import annotations

import pytest
import torch

from obliteratus.evaluation.heretic_eval import (
    ARDITI_REFUSAL_SUBSTRINGS,
    _DEFAULT_HARMLESS_PROMPTS,
    _HARMBENCH_CLS_TEMPLATE,
    _JAILBREAKBENCH_BUILTIN,
    _extract_score,
    _last_real_token_indices,
    _sanitize_model_path,
    arditi_refusal_rate,
    first_token_kl_on_prompts,
    format_comparison_table,
    HereticComparisonResult,
    LM_EVAL_BENCHMARKS,
    load_jailbreakbench_prompts,
    unload_harmbench_classifier,
)


# ---------------------------------------------------------------------------
# Shared test doubles
# ---------------------------------------------------------------------------

class _FakeModel(torch.nn.Module):
    """Tiny stand-in model whose logits peak at one fixed vocabulary index.

    Every position of every sequence gets the same 10-way logit vector:
    zeros everywhere except ``5.0`` at ``peak_idx``.
    """

    def __init__(self, peak_idx: int = 0):
        super().__init__()
        # A single parameter so next(model.parameters()).device works
        self._param = torch.nn.Parameter(torch.zeros(1))
        self._peak_idx = peak_idx

    # __call__ (not forward) is overridden, exactly as in the original fakes,
    # so the model is invoked as ``model(**inputs)`` with keyword tensors.
    def __call__(self, **kwargs):
        batch_size = kwargs["input_ids"].shape[0]
        seq_len = kwargs["input_ids"].shape[1]
        vocab_size = 10
        # Create a non-uniform distribution peaked at _peak_idx
        base = torch.zeros(vocab_size)
        base[self._peak_idx] = 5.0
        logits = base.unsqueeze(0).unsqueeze(0).expand(
            batch_size, seq_len, vocab_size
        ).clone()
        # Minimal object exposing a ``.logits`` attribute.
        return type("Output", (), {"logits": logits})()


class _FakeTokenizer:
    """Tokenizer stub: every prompt becomes five tokens of id 1, no padding."""

    pad_token_id = 0

    def __call__(self, texts, return_tensors="pt", **kwargs):
        batch_size = len(texts) if isinstance(texts, list) else 1
        input_ids = torch.ones(batch_size, 5, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
        }


# ---------------------------------------------------------------------------
# JailbreakBench Prompts
# ---------------------------------------------------------------------------

class TestJailbreakBenchPrompts:
    """Checks on the built-in JailbreakBench prompt list and its loader."""

    def test_builtin_has_100_prompts(self):
        assert len(_JAILBREAKBENCH_BUILTIN) == 100

    def test_builtin_all_non_empty(self):
        for i, p in enumerate(_JAILBREAKBENCH_BUILTIN):
            assert isinstance(p, str) and len(p.strip()) > 10, f"Prompt {i} is too short"

    def test_builtin_no_duplicates(self):
        assert len(set(_JAILBREAKBENCH_BUILTIN)) == len(_JAILBREAKBENCH_BUILTIN)

    def test_load_returns_list(self):
        prompts = load_jailbreakbench_prompts()
        assert isinstance(prompts, list)
        assert len(prompts) >= 100

    def test_load_caches(self):
        p1 = load_jailbreakbench_prompts()
        p2 = load_jailbreakbench_prompts()
        assert p1 == p2

    def test_load_returns_copy(self):
        """Loading should return a copy, not the cached list itself."""
        p1 = load_jailbreakbench_prompts()
        p1.append("mutated")
        p2 = load_jailbreakbench_prompts()
        assert "mutated" not in p2


# ---------------------------------------------------------------------------
# Arditi Refusal Rate
# ---------------------------------------------------------------------------

class TestArditiRefusalRate:
    """Checks on substring-based refusal detection via ``arditi_refusal_rate``."""

    def test_canonical_substring_count(self):
        """Arditi/Zou list has exactly 29 substrings."""
        assert len(ARDITI_REFUSAL_SUBSTRINGS) == 29

    def test_all_refusals(self):
        responses = [
            "I'm sorry, but I cannot help with that.",
            "I apologize for the inconvenience.",
            "As an AI, I don't do that.",
        ]
        assert arditi_refusal_rate(responses) == 1.0

    def test_no_refusals(self):
        responses = [
            "Sure, here is a step-by-step guide.",
            "The answer to your question is 42.",
            "Here you go, this should work.",
        ]
        assert arditi_refusal_rate(responses) == 0.0

    def test_empty_responses_not_refusals(self):
        """Per Zou et al. logic, empty responses have no substring match = not refusals."""
        responses = ["", " ", ""]
        assert arditi_refusal_rate(responses) == 0.0

    def test_empty_list(self):
        assert arditi_refusal_rate([]) == 0.0

    def test_case_insensitive(self):
        responses = ["i'm sorry, but I can't help."]
        assert arditi_refusal_rate(responses) == 1.0

    def test_mixed(self):
        responses = [
            "I cannot do that.",
            "Sure, here's the info.",
            "My apologies, I'm unable to help.",
            "The capital of France is Paris.",
        ]
        rate = arditi_refusal_rate(responses)
        assert rate == 0.5  # 2/4

    def test_subtle_refusal_not_caught(self):
        """Arditi list is deliberately limited — subtle refusals should not match."""
        responses = [
            "I'd rather not help with that specific request.",
            "Let me redirect you to a more appropriate resource.",
        ]
        # These are NOT in the Arditi 29-substring list
        assert arditi_refusal_rate(responses) == 0.0


# ---------------------------------------------------------------------------
# First-Token KL Divergence
# ---------------------------------------------------------------------------

class TestFirstTokenKL:
    """Checks on ``first_token_kl_on_prompts`` using fake peaked-logit models."""

    @pytest.fixture
    def simple_models(self):
        """Create two tiny 'models' with controllable logit distributions."""
        return _FakeModel, _FakeTokenizer

    def test_identical_models_zero_kl(self, simple_models):
        FakeModel, FakeTokenizer = simple_models
        model_a = FakeModel(peak_idx=0)
        model_b = FakeModel(peak_idx=0)
        tokenizer = FakeTokenizer()
        result = first_token_kl_on_prompts(
            model_a, model_b, tokenizer, ["hello", "world"],
        )
        assert abs(result["mean_kl"]) < 1e-5
        assert result["interpretation"] == "excellent (minimal collateral damage)"

    def test_different_models_positive_kl(self, simple_models):
        FakeModel, FakeTokenizer = simple_models
        model_a = FakeModel(peak_idx=0)  # peaked at vocab position 0
        model_b = FakeModel(peak_idx=5)  # peaked at vocab position 5
        tokenizer = FakeTokenizer()
        result = first_token_kl_on_prompts(
            model_a, model_b, tokenizer, ["test prompt"],
        )
        assert result["mean_kl"] > 0

    def test_returns_per_prompt_kl(self, simple_models):
        FakeModel, FakeTokenizer = simple_models
        model_a = FakeModel(peak_idx=0)
        model_b = FakeModel(peak_idx=3)
        tokenizer = FakeTokenizer()
        result = first_token_kl_on_prompts(
            model_a, model_b, tokenizer, ["a", "b", "c"],
        )
        assert len(result["per_prompt_kl"]) == 3
        assert result["std_kl"] >= 0


# ---------------------------------------------------------------------------
# HereticComparisonResult
# ---------------------------------------------------------------------------

class TestHereticComparisonResult:
    """Checks on field storage and defaults of ``HereticComparisonResult``."""

    def test_dataclass_fields(self):
        r = HereticComparisonResult(
            model_name="test-model",
            method="OBLITERATUS",
            refusal_rate_arditi=0.05,
            refusal_rate_obliteratus=0.03,
            harmbench_asr=0.85,
            n_jailbreakbench=100,
            n_refusals_remaining=5,
            first_token_kl=0.15,
            kl_interpretation="excellent",
        )
        assert r.model_name == "test-model"
        assert r.method == "OBLITERATUS"
        assert r.refusal_rate_arditi == 0.05
        assert r.harmbench_asr == 0.85
        assert r.first_token_kl == 0.15

    def test_optional_fields_default_none(self):
        r = HereticComparisonResult(
            model_name="test",
            method="test",
            refusal_rate_arditi=0.0,
            refusal_rate_obliteratus=0.0,
            harmbench_asr=None,
            n_jailbreakbench=100,
            n_refusals_remaining=0,
        )
        assert r.mmlu is None
        assert r.gsm8k is None
        assert r.perplexity is None
        assert r.harmbench_per_item == []
        assert r.kl_per_prompt == []


# ---------------------------------------------------------------------------
# Comparison Table Formatting
# ---------------------------------------------------------------------------

class TestComparisonTable:
    """Checks on the text produced by ``format_comparison_table``."""

    def test_format_single_result(self):
        r = HereticComparisonResult(
            model_name="Llama-2-7B",
            method="OBLITERATUS",
            refusal_rate_arditi=0.05,
            refusal_rate_obliteratus=0.03,
            harmbench_asr=0.85,
            n_jailbreakbench=100,
            n_refusals_remaining=5,
            first_token_kl=0.15,
            kl_interpretation="excellent",
            mmlu=0.518,
            gsm8k=0.313,
        )
        table = format_comparison_table([r])
        assert "OBLITERATUS" in table
        assert "REFUSAL REMOVAL" in table
        assert "CAPABILITY PRESERVATION" in table
        assert "DISTRIBUTION QUALITY" in table
        assert "5.0%" in table  # arditi refusal rate
        assert "85.0%" in table  # harmbench asr
        assert "5/100" in table  # JBB refusals
        assert "0.1500" in table  # KL divergence

    def test_format_multiple_results(self):
        results = [
            HereticComparisonResult(
                model_name="test",
                method="OBLITERATUS",
                refusal_rate_arditi=0.05,
                refusal_rate_obliteratus=0.03,
                harmbench_asr=0.85,
                n_jailbreakbench=100,
                n_refusals_remaining=5,
            ),
            HereticComparisonResult(
                model_name="test",
                method="Heretic",
                refusal_rate_arditi=0.03,
                refusal_rate_obliteratus=0.03,
                harmbench_asr=0.90,
                n_jailbreakbench=100,
                n_refusals_remaining=3,
            ),
        ]
        table = format_comparison_table(results)
        assert "OBLITERATUS" in table
        assert "Heretic" in table

    def test_heretic_reference_numbers_present(self):
        """The comparison table should include Heretic's published reference numbers."""
        table = format_comparison_table([
            HereticComparisonResult(
                model_name="test",
                method="test",
                refusal_rate_arditi=0.0,
                refusal_rate_obliteratus=0.0,
                harmbench_asr=None,
                n_jailbreakbench=100,
                n_refusals_remaining=0,
            ),
        ])
        assert "p-e-w/heretic" in table
        assert "97/100" in table  # Original model refusal count
        assert "0.16" in table  # Heretic's KL divergence


# ---------------------------------------------------------------------------
# LM-Eval Benchmark Config
# ---------------------------------------------------------------------------

class TestLmEvalBenchmarks:
    """Checks on the ``LM_EVAL_BENCHMARKS`` configuration mapping."""

    def test_five_standard_benchmarks(self):
        assert len(LM_EVAL_BENCHMARKS) == 5

    def test_required_keys(self):
        for key, config in LM_EVAL_BENCHMARKS.items():
            assert "task" in config
            assert "num_fewshot" in config
            assert "description" in config

    def test_expected_benchmarks(self):
        expected = {"mmlu", "gsm8k", "arc_challenge", "hellaswag", "truthfulqa_mc2"}
        assert set(LM_EVAL_BENCHMARKS.keys()) == expected

    def test_fewshot_settings_match_arditi(self):
        """Few-shot settings should match Arditi et al. / Open LLM Leaderboard."""
        assert LM_EVAL_BENCHMARKS["mmlu"]["num_fewshot"] == 5
        assert LM_EVAL_BENCHMARKS["gsm8k"]["num_fewshot"] == 5
        assert LM_EVAL_BENCHMARKS["truthfulqa_mc2"]["num_fewshot"] == 0


# ---------------------------------------------------------------------------
# HarmBench Template
# ---------------------------------------------------------------------------

class TestHarmBenchTemplate:
    """Checks that ``_HARMBENCH_CLS_TEMPLATE`` keeps the official prompt shape."""

    def test_template_has_sys_tags(self):
        """Official template requires <<SYS>> / <</SYS>> delimiters."""
        # Opening and closing tags are checked separately so that losing
        # either one fails the test.
        assert "<<SYS>>" in _HARMBENCH_CLS_TEMPLATE
        assert "<</SYS>>" in _HARMBENCH_CLS_TEMPLATE

    def test_template_has_seven_rules(self):
        """Official HarmBench template has exactly 7 rules."""
        rule_count = _HARMBENCH_CLS_TEMPLATE.count("\n- ")
        assert rule_count == 7

    def test_template_has_behavior_placeholder(self):
        assert "{behavior}" in _HARMBENCH_CLS_TEMPLATE

    def test_template_has_generation_placeholder(self):
        assert "{generation}" in _HARMBENCH_CLS_TEMPLATE

    def test_template_ends_with_inst_tag(self):
        assert _HARMBENCH_CLS_TEMPLATE.strip().endswith("[/INST]")

    def test_template_answer_format(self):
        """Official template uses lowercase 'yes' or 'no'."""
        assert '"yes" or "no"' in _HARMBENCH_CLS_TEMPLATE


# ---------------------------------------------------------------------------
# Score Extraction (C3 fix)
# ---------------------------------------------------------------------------

class TestExtractScore:
    """Checks on metric-key selection in ``_extract_score``."""

    def test_normal_score(self):
        assert _extract_score({"acc,none": 0.75}) == 0.75

    def test_zero_score_not_skipped(self):
        """A legitimate score of 0.0 should be returned, not treated as falsy."""
        assert _extract_score({"acc,none": 0.0}) == 0.0

    def test_fallback_to_next_key(self):
        assert _extract_score({"acc_norm,none": 0.65}) == 0.65

    def test_mc2_key(self):
        assert _extract_score({"mc2,none": 0.42}) == 0.42

    def test_no_matching_key(self):
        assert _extract_score({"unknown_metric": 0.99}) == 0.0

    def test_priority_order(self):
        """acc,none should take priority over acc_norm,none."""
        result = _extract_score({"acc,none": 0.5, "acc_norm,none": 0.9})
        assert result == 0.5


# ---------------------------------------------------------------------------
# Padding-Aware Last-Token Indices
# ---------------------------------------------------------------------------

class TestLastRealTokenIndices:
    """Checks that ``_last_real_token_indices`` honours the attention mask."""

    def test_no_padding(self):
        mask = torch.ones(3, 5, dtype=torch.long)
        indices = _last_real_token_indices(mask)
        assert indices.tolist() == [4, 4, 4]

    def test_with_padding(self):
        mask = torch.tensor([
            [1, 1, 1, 1, 1],  # length 5, last real = index 4
            [1, 1, 1, 0, 0],  # length 3, last real = index 2
            [1, 0, 0, 0, 0],  # length 1, last real = index 0
        ])
        indices = _last_real_token_indices(mask)
        assert indices.tolist() == [4, 2, 0]

    def test_single_token(self):
        mask = torch.tensor([[1]])
        indices = _last_real_token_indices(mask)
        assert indices.tolist() == [0]


# ---------------------------------------------------------------------------
# Model Path Sanitization
# ---------------------------------------------------------------------------

class TestSanitizeModelPath:
    """Checks on accepted and rejected inputs of ``_sanitize_model_path``."""

    def test_normal_path(self):
        assert _sanitize_model_path("/tmp/my-model") == "/tmp/my-model"

    def test_hf_model_id(self):
        assert _sanitize_model_path("meta-llama/Llama-2-7b-hf") == "meta-llama/Llama-2-7b-hf"

    def test_rejects_commas(self):
        with pytest.raises(ValueError, match="commas"):
            _sanitize_model_path("evil,trust_remote_code=True")


# ---------------------------------------------------------------------------
# Classifier Unload
# ---------------------------------------------------------------------------

class TestClassifierUnload:
    """Checks on ``unload_harmbench_classifier``."""

    def test_unload_when_not_loaded(self):
        """Unloading when nothing is loaded should not raise."""
        unload_harmbench_classifier()  # should be a no-op


# ---------------------------------------------------------------------------
# Default Harmless Prompts
# ---------------------------------------------------------------------------

class TestDefaultHarmlessPrompts:
    """Checks on the built-in harmless prompt list."""

    def test_has_100_unique_prompts(self):
        assert len(_DEFAULT_HARMLESS_PROMPTS) == 100

    def test_no_duplicates(self):
        assert len(set(_DEFAULT_HARMLESS_PROMPTS)) == len(_DEFAULT_HARMLESS_PROMPTS)

    def test_all_non_empty(self):
        for i, p in enumerate(_DEFAULT_HARMLESS_PROMPTS):
            # Strip first, matching the JailbreakBench check, so that a
            # whitespace-only prompt cannot pass the length test.
            assert isinstance(p, str) and len(p.strip()) > 10, f"Prompt {i} is too short"


# ---------------------------------------------------------------------------
# KL Divergence Non-Negativity
# ---------------------------------------------------------------------------

class TestKLNonNegativity:
    """Checks that per-prompt KL values from the fakes are never negative."""

    @pytest.fixture
    def models_and_tokenizer(self):
        """Provide the fake model and tokenizer classes."""
        return _FakeModel, _FakeTokenizer

    def test_all_kl_values_non_negative(self, models_and_tokenizer):
        FakeModel, FakeTokenizer = models_and_tokenizer
        model_a = FakeModel(peak_idx=0)
        model_b = FakeModel(peak_idx=3)
        tokenizer = FakeTokenizer()
        result = first_token_kl_on_prompts(
            model_a, model_b, tokenizer, ["a", "b", "c", "d", "e"],
        )
        for val in result["per_prompt_kl"]:
            assert val >= 0.0, f"KL value {val} is negative"