"""Run the arithmetic association test case against each model variant.

For every model, the two teaching facts are taught, the filler prompts are
generated, and the generated answer to the question is printed and written
to the log file named by ``LOG``.
"""
import time
import torch
from ben2 import ASSOC_TESTS
from app import HybridLLM, RAGBaseline, ContextBaseline

LOG = "arithmetic_case_run.log"


def _exercise(runner, taught, filler_prompts, prompt):
    """Reset *runner*, teach it, feed the fillers, and return its answer.

    Facts are taught with ``verbose=True``; filler generations run with
    ``verbose=False`` and their results are discarded. The final question
    is generated with ``verbose=True`` and its result is returned.
    """
    runner.reset_world()
    for fact in taught:
        runner.teach(fact, verbose=True)
    for filler in filler_prompts:
        # Filler output is not used; only the call itself matters here.
        runner.generate(filler, max_new_tokens=20, verbose=False)
    return runner.generate(prompt, max_new_tokens=80, verbose=True)


def _append_result(label, answer):
    """Append one model's generated answer to the log file."""
    with open(LOG, "a", encoding="utf-8") as log:
        log.write(f"--- {label} ---\n")
        log.write("Generated:\n")
        log.write(answer + "\n\n")


def main():
    """Run test case ``ASSOC_TESTS[3]`` on all three models and log the output.

    The log file is truncated and given a header first; each model's
    result is then appended as soon as that model finishes.
    """
    # HybridLLM loads tokenizer+model; both baselines are built from the
    # tokenizer and model attributes of that same instance.
    base = HybridLLM()
    candidates = {
        "HybridLLM": base,
        "RAGBaseline": RAGBaseline(base.tokenizer, base.model),
        "ContextBaseline": ContextBaseline(base.tokenizer, base.model),
    }

    case = ASSOC_TESTS[3]  # arithmetic case
    taught = [case["teach_a"], case["teach_b"]]
    filler_prompts = case["fillers"]
    prompt = case["question"]

    # Mode "w" starts a fresh log for this run.
    with open(LOG, "w", encoding="utf-8") as log:
        log.write(
            f"Arithmetic case run at {time.ctime()}\n"
            f"Question: {prompt}\n\n"
        )

    for label, runner in candidates.items():
        print(f"\n--- Running model: {label} ---")
        answer = _exercise(runner, taught, filler_prompts, prompt)
        print(f"\n{label} OUTPUT:\n{answer}\n")
        _append_result(label, answer)


if __name__ == "__main__":
    main()