"""워터폴 차트 생성"""
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

from ..data_models import ChartData
from ..pdf.styles import CHART_COLORS
from .combo_chart import setup_japanese_font


def create_waterfall_chart(chart_data: ChartData) -> BytesIO:
    """워터폴 차트 생성

    Args:
        chart_data: 차트 데이터
            datasets[0]: {"data": [...], "colors": ["increase", "decrease", "total", ...]}

    Returns:
        PNG 이미지 바이트 버퍼
    """
    setup_japanese_font()

    labels = chart_data.labels

    if not chart_data.datasets:
        return BytesIO()

    dataset = chart_data.datasets[0]
    values = dataset.get("data", [])
    color_types = dataset.get("colors", ["increase"] * len(values))

    # 색상 매핑
    color_map = {
        "increase": "#16a34a",   # 녹색 (증가)
        "decrease": "#dc2626",   # 빨간색 (감소)
        "total": "#2563eb",      # 파란색 (합계)
    }

    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(labels))
    width = 0.6

    # 누적 값 계산
    cumulative = 0
    bottoms = []
    bar_values = []
    bar_colors = []

    for i, (val, ctype) in enumerate(zip(values, color_types)):
        if ctype == "total":
            # 합계는 0부터 시작
            bottoms.append(0)
            bar_values.append(val)
        else:
            if val >= 0:
                bottoms.append(cumulative)
                bar_values.append(val)
                cumulative += val
            else:
                cumulative += val
                bottoms.append(cumulative)
                bar_values.append(abs(val))

        bar_colors.append(color_map.get(ctype, "#64748b"))

    # 막대 그리기
    bars = ax.bar(
        x,
        bar_values,
        width,
        bottom=bottoms,
        color=bar_colors,
        edgecolor="white",
        linewidth=1,
    )

    # 연결선 그리기
    for i in range(len(values) - 1):
        if color_types[i + 1] != "total":
            curr_top = bottoms[i] + bar_values[i]
            ax.plot(
                [x[i] + width / 2, x[i + 1] - width / 2],
                [curr_top, curr_top],
                color="#64748b",
                linestyle="--",
                linewidth=1,
                alpha=0.5,
            )

    # 값 레이블
    for i, (bar, val, ctype) in enumerate(zip(bars, values, color_types)):
        height = bar.get_height()
        y_pos = bar.get_y() + height / 2

        # 포맷팅
        if abs(val) >= 10000:
            label = f"{val/10000:.0f}万"
        else:
            label = f"{val:,.0f}"

        ax.annotate(
            label,
            xy=(bar.get_x() + bar.get_width() / 2, y_pos),
            ha="center",
            va="center",
            fontsize=9,
            fontweight="bold",
            color="white",
        )

    ax.set_xlabel("")
    ax.set_ylabel("金額")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")

    # Y축 포맷
    ax.yaxis.set_major_formatter(
        plt.FuncFormatter(lambda x, p: f"{x/10000:.0f}万" if x != 0 else "0")
    )

    # 0 기준선
    ax.axhline(y=0, color="#64748b", linewidth=1)

    # 타이틀
    if chart_data.title:
        plt.title(chart_data.title, fontsize=12, fontweight="bold")

    # 그리드
    ax.grid(axis="y", alpha=0.3)
    ax.set_axisbelow(True)

    plt.tight_layout()

    # 이미지로 저장
    buffer = BytesIO()
    plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
    buffer.seek(0)
    plt.close(fig)

    return buffer
