| import time
|
| import torch
|
|
|
| from ben2 import ASSOC_TESTS
|
| from app import HybridLLM, RAGBaseline, ContextBaseline
|
|
|
# Default path of the run log written by main(): the file is truncated and a
# header is written at the start of a run, then one entry is appended per
# model. The path is relative, so it resolves against the current working
# directory.
LOG: str = "arithmetic_case_run.log"
|
|
|
def main(test_index=3, log_path=None, filler_tokens=20, answer_tokens=80):
    """Run one associative-recall test case against every model variant.

    Builds a ``HybridLLM`` plus a ``RAGBaseline`` and a ``ContextBaseline``
    constructed from the hybrid's tokenizer and model. For each variant it
    resets the world state, teaches the test's two facts, feeds the filler
    prompts, and asks the test question. Each answer is printed and appended
    to the log file as soon as that model finishes.

    Args:
        test_index: Index into ``ASSOC_TESTS`` of the case to run. Defaults
            to 3, the case this script has always run. The log header text
            ("Arithmetic case run at ...") is fixed regardless of this value.
        log_path: File the run log is written to. ``None`` (the default)
            means the module-level ``LOG`` path. The file is truncated at the
            start of the run.
        filler_tokens: ``max_new_tokens`` passed to each filler generation.
        answer_tokens: ``max_new_tokens`` passed to the final answer
            generation.
    """
    # Resolved here rather than as a default so the signature does not bind
    # LOG at definition time.
    if log_path is None:
        log_path = LOG

    hybrid = HybridLLM()
    models = {
        "HybridLLM": hybrid,
        "RAGBaseline": RAGBaseline(hybrid.tokenizer, hybrid.model),
        "ContextBaseline": ContextBaseline(hybrid.tokenizer, hybrid.model),
    }

    test = ASSOC_TESTS[test_index]
    facts = [test["teach_a"], test["teach_b"]]
    fillers = test["fillers"]
    question = test["question"]

    # Start a fresh log with a header. Per-model results are appended in the
    # loop below, so a crash part-way through still leaves earlier results.
    with open(log_path, "w", encoding="utf-8") as f:
        f.write(f"Arithmetic case run at {time.ctime()}\n")
        f.write(f"Question: {question}\n\n")

    for name, model in models.items():
        out = _run_model(
            name, model, facts, fillers, question, filler_tokens, answer_tokens
        )
        print(f"\n{name} OUTPUT:\n{out}\n")
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(_format_log_entry(name, out))


def _run_model(name, model, facts, fillers, question, filler_tokens, answer_tokens):
    """Reset *model*, teach *facts*, play *fillers*, and return its answer to *question*."""
    print(f"\n--- Running model: {name} ---")
    model.reset_world()

    for fact in facts:
        model.teach(fact, verbose=True)

    # NOTE(review): the original script's indentation was lost. The fillers
    # are assumed to run once, after all facts are taught. Confirm against
    # the intended test protocol.
    for filler in fillers:
        # Filler generations are deliberately discarded.
        model.generate(filler, max_new_tokens=filler_tokens, verbose=False)

    return model.generate(question, max_new_tokens=answer_tokens, verbose=True)


def _format_log_entry(name, out):
    """Return the log text recorded for one model's generated answer."""
    # An f-string also tolerates a non-str ``out`` (e.g. None) instead of
    # raising TypeError on string concatenation.
    return f"--- {name} ---\nGenerated:\n{out}\n\n"
|
|
|
# Run the experiment only when this file is executed directly; importing the
# module defines main() without running it.
if __name__ == '__main__':
    main()
|
|
|