File size: 1,888 Bytes
cf02581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import gradio as gr
import pandas as pd

from runtime import JNUTSBRuntime

# Single shared runtime, built once at import time and reused by every
# run_demo call. NOTE(review): presumably from_config_dir loads the model /
# config files found in the directory containing this script — confirm in
# runtime.py.
runtime = JNUTSBRuntime.from_config_dir(Path(__file__).parent)

# Sample stock series used to prefill the stock-CSV textbox: a header row with
# `timestamp` and `target` columns followed by five daily rows.
DEFAULT_STOCK = """timestamp,target
2024-12-01,71000
2024-12-02,71800
2024-12-03,70400
2024-12-04,70900
2024-12-05,72100
"""

# Sample news used to prefill the news-JSON textbox: a JSON array of objects
# with `date` and `title` keys. The Korean titles read "Samsung Electronics
# launches new HBM product" and "Concerns over a semiconductor market slowdown".
DEFAULT_NEWS = """[
  {"date": "2024-12-01", "title": "삼성전자 HBM 신제품 출시"},
  {"date": "2024-12-02", "title": "반도체 업황 둔화 우려"}
]"""


def run_demo(
    stock_csv: str | None,
    news_json: str | None,
    prediction_length: int,
    use_llm_extractor: bool,
) -> Any:
    """Parse the raw UI text inputs and run a JNU-TSB prediction.

    Args:
        stock_csv: CSV text from the stock textbox. Empty, whitespace-only,
            or ``None`` means "no stock input" and is passed on as ``None``.
        news_json: JSON text from the news textbox. Empty, whitespace-only,
            or ``None`` means "no news input" and is passed on as ``None``.
        prediction_length: Forecast horizon from the slider; defensively
            coerced to ``int`` before being handed to the runtime.
        use_llm_extractor: Whether to request the Polyglot-Ko extractor;
            coerced to ``bool``.

    Returns:
        Whatever ``runtime.predict`` returns. It is rendered by a ``gr.JSON``
        component, so it is presumably JSON-serializable.

    Raises:
        gr.Error: If either textbox holds malformed CSV/JSON. The message is
            shown to the user in the UI instead of an opaque traceback.
    """
    from io import StringIO

    stock = None
    # `x and x.strip()` also guards against a None value, which would
    # otherwise crash on `.strip()`.
    if stock_csv and stock_csv.strip():
        try:
            stock = pd.read_csv(StringIO(stock_csv))
        except (pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
            raise gr.Error(f"주가 CSV를 해석할 수 없습니다: {exc}") from exc

    news = None
    if news_json and news_json.strip():
        try:
            news = json.loads(news_json)
        except json.JSONDecodeError as exc:
            raise gr.Error(f"뉴스 JSON을 해석할 수 없습니다: {exc}") from exc

    return runtime.predict(
        inputs={"stock": stock, "news": news},
        prediction_length=int(prediction_length),
        use_llm_extractor=bool(use_llm_extractor),
    )


# Build the Gradio UI. Gradio Blocks lays components out in the order they are
# created inside the Blocks/Row contexts, so statement order here determines
# the page layout.
with gr.Blocks(title="JNU-TSB") as demo:
    # Title: "JNU-TSB: Korean news-based Time-Series Bridge".
    gr.Markdown("# JNU-TSB: 한국어 뉴스 기반 Time-Series Bridge")
    # Disclaimer: "Educational/research demo of the Chronos-2 + Polyglot-Ko +
    # 3-way router architecture. Predictions are not investment advice."
    gr.Markdown(
        "Chronos-2 + Polyglot-Ko + 3-way router 구조의 교육/연구용 데모입니다. "
        "예측 결과는 투자 조언이 아닙니다."
    )
    # Side-by-side raw-text inputs ("Stock CSV", "News JSON"), prefilled with
    # DEFAULT_STOCK and DEFAULT_NEWS.
    with gr.Row():
        stock_box = gr.Textbox(label="주가 CSV", value=DEFAULT_STOCK, lines=9)
        news_box = gr.Textbox(label="뉴스 JSON", value=DEFAULT_NEWS, lines=9)
    # Run options: forecast horizon (integer 1-30, default 3) and a toggle
    # labelled "Use Polyglot-Ko extractor" (off by default).
    with gr.Row():
        pred_len = gr.Slider(label="예측 길이 prediction_length", minimum=1, maximum=30, value=3, step=1)
        use_llm = gr.Checkbox(label="Polyglot-Ko 추출기 사용", value=False)
    # "Run JNU-TSB" button and a JSON viewer labelled "Result".
    btn = gr.Button("JNU-TSB 실행")
    out = gr.JSON(label="결과")
    # On click, pass the four input values positionally to run_demo (order
    # matches its parameters) and render its return value in `out`.
    btn.click(run_demo, inputs=[stock_box, news_box, pred_len, use_llm], outputs=out)

# Start the server only when executed as a script; a plain import builds
# `runtime` and `demo` without launching anything.
if __name__ == "__main__":
    demo.launch()