"""KPI tree data structure."""

from __future__ import annotations

import re
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Iterator

from src.config import KPICategory, KPINodeConfig, NodeType


@dataclass
class KPINode:
    """A single node of a KPI tree.

    Holds the static definition of a KPI (id, display name, node type,
    optional formula, unit, category, description), the runtime ``value`` and
    ``target`` assigned to it, and links to its ``children`` and ``parent``.
    """

    id: str
    name: str
    node_type: NodeType
    formula: str | None = None
    unit: str = ""
    category: KPICategory = KPICategory.FINANCE
    description: str = ""

    # Runtime values (assigned after construction, e.g. by KPITree.set_values).
    value: float | None = None
    target: float | None = None

    # Tree structure.
    children: list[KPINode] = field(default_factory=list)
    # The back-reference is excluded from repr AND from the generated __eq__:
    # comparing parents walks back down into ``children`` and, for two
    # structurally identical trees, recurses forever (RecursionError).
    parent: KPINode | None = field(default=None, repr=False, compare=False)

    @property
    def achievement_rate(self) -> float | None:
        """Return ``value / target * 100``.

        Returns None when either value is missing or when the target is zero
        (the ratio is undefined).
        """
        if self.value is None or self.target is None:
            return None
        if self.target == 0:
            return None
        return self.value / self.target * 100

    @property
    def gap(self) -> float | None:
        """Return ``value - target``, or None when either value is missing."""
        if self.value is None or self.target is None:
            return None
        return self.value - self.target

    @property
    def is_calculated(self) -> bool:
        """Return True if this node has a formula."""
        return self.formula is not None

    @property
    def is_leaf(self) -> bool:
        """Return True if this node has no children."""
        return not self.children

    @property
    def is_root(self) -> bool:
        """Return True if this node has no parent."""
        return self.parent is None

    @property
    def depth(self) -> int:
        """Return the number of ancestors between this node and the root (root = 0)."""
        # Iterative walk so very deep chains cannot hit the recursion limit.
        levels = 0
        ancestor = self.parent
        while ancestor is not None:
            levels += 1
            ancestor = ancestor.parent
        return levels

    def add_child(self, child: KPINode) -> None:
        """Attach ``child`` as the last child of this node.

        A node has a single parent. If ``child`` is currently attached
        elsewhere, it is detached from that parent first, so ``parent`` and
        ``children`` stay consistent. Re-adding a child that is already
        attached here is a no-op (no duplicate entry is created).

        Args:
            child: Node to attach.

        Raises:
            ValueError: If ``child`` is this node or one of its ancestors,
                because attaching it would create a cycle.
        """
        ancestor: KPINode | None = self
        while ancestor is not None:
            if ancestor is child:
                raise ValueError(
                    f"Adding '{child.id}' under '{self.id}' would create a cycle"
                )
            ancestor = ancestor.parent

        old_parent = child.parent
        if old_parent is self and any(c is child for c in self.children):
            return
        if old_parent is not None:
            # Remove by identity: the dataclass __eq__ could match a different
            # but field-equal node, so list.remove() is not safe here.
            for index, sibling in enumerate(old_parent.children):
                if sibling is child:
                    del old_parent.children[index]
                    break

        child.parent = self
        self.children.append(child)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this node and its subtree into nested dictionaries.

        ``description`` and the ``parent`` back-reference are not included.
        """
        return {
            "id": self.id,
            "name": self.name,
            "type": self.node_type.value,
            "formula": self.formula,
            "unit": self.unit,
            "category": self.category.value,
            "value": self.value,
            "target": self.target,
            "achievement_rate": self.achievement_rate,
            "gap": self.gap,
            "children": [c.to_dict() for c in self.children],
        }


@dataclass
class KPITree:
    """A KPI tree: a root node plus an index of nodes by id.

    When built with ``from_config``, ``root`` is the unique node of type
    ``NodeType.KGI`` and ``nodes`` maps every configured id to its node.
    """

    root: KPINode
    nodes: dict[str, KPINode] = field(default_factory=dict)

    @classmethod
    def from_config(cls, node_configs: list[KPINodeConfig]) -> KPITree:
        """Build tree from configuration.

        Args:
            node_configs: List of node configurations. Each entry's
                ``children`` lists the ids of its direct children.

        Returns:
            Constructed KPI tree.

        Raises:
            ValueError: If tree structure is invalid. This covers duplicate
                node ids, an unknown child id, and a node listed as a child
                more than once. It also covers a parent/child cycle, zero or
                several KGI nodes, and a KGI node that is itself a child.
        """
        # Step 1: Create all nodes. A duplicate id would otherwise silently
        # replace the earlier node.
        nodes: dict[str, KPINode] = {}
        for cfg in node_configs:
            if cfg.id in nodes:
                raise ValueError(f"Duplicate node id '{cfg.id}'")
            nodes[cfg.id] = KPINode(
                id=cfg.id,
                name=cfg.name,
                node_type=cfg.type,
                formula=cfg.formula,
                unit=cfg.unit,
                category=cfg.category,
                description=cfg.description,
            )

        # Step 2: Build parent-child relationships.
        for cfg in node_configs:
            node = nodes[cfg.id]
            for child_id in cfg.children:
                if child_id not in nodes:
                    raise ValueError(f"Child node '{child_id}' not found for '{cfg.id}'")
                child = nodes[child_id]
                if child.parent is not None:
                    raise ValueError(
                        f"Node '{child_id}' is listed as a child more than once "
                        f"(under '{child.parent.id}' and '{cfg.id}')"
                    )
                # Linking would close a loop if the child is already an
                # ancestor of (or identical to) its new parent. Every accepted
                # link keeps the graph acyclic, so this upward walk terminates.
                ancestor: KPINode | None = node
                while ancestor is not None:
                    if ancestor is child:
                        raise ValueError(
                            f"Cycle detected: adding '{child_id}' under '{cfg.id}'"
                        )
                    ancestor = ancestor.parent
                node.add_child(child)

        # Step 3: Find the root (the single KGI node).
        root_nodes = [n for n in nodes.values() if n.node_type == NodeType.KGI]
        if not root_nodes:
            raise ValueError("No KGI node found")
        if len(root_nodes) > 1:
            raise ValueError("Multiple KGI nodes found")

        root = root_nodes[0]
        if root.parent is not None:
            raise ValueError(
                f"KGI node '{root.id}' must not be a child of '{root.parent.id}'"
            )

        return cls(root=root, nodes=nodes)

    def set_values(self, values: dict[str, float]) -> None:
        """Set actual values for nodes.

        Ids that are not present in the tree are ignored.

        Args:
            values: Dictionary mapping node ID to value.
        """
        for node_id, value in values.items():
            if node_id in self.nodes:
                self.nodes[node_id].value = value

    def set_targets(self, targets: dict[str, float]) -> None:
        """Set target values for nodes.

        Ids that are not present in the tree are ignored.

        Args:
            targets: Dictionary mapping node ID to target.
        """
        for node_id, target in targets.items():
            if node_id in self.nodes:
                self.nodes[node_id].target = target

    def get_node(self, node_id: str) -> KPINode | None:
        """Return the node with the given id, or None if it is unknown."""
        return self.nodes.get(node_id)

    def get_leaves(self) -> list[KPINode]:
        """Return every indexed node that has no children.

        Node type is not checked. Nodes unreachable from ``root`` are included.
        """
        return [n for n in self.nodes.values() if n.is_leaf]

    def get_calculated_nodes(self) -> list[KPINode]:
        """Return every indexed node that has a formula."""
        return [n for n in self.nodes.values() if n.is_calculated]

    def iter_breadth_first(self) -> Iterator[KPINode]:
        """Yield nodes reachable from ``root`` in breadth-first (level) order."""
        # deque.popleft() is O(1); list.pop(0) shifts the whole list each time.
        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            yield node
            queue.extend(node.children)

    def iter_depth_first(self) -> Iterator[KPINode]:
        """Yield nodes reachable from ``root`` in depth-first pre-order."""
        stack = [self.root]
        while stack:
            node = stack.pop()
            yield node
            # Push children reversed so the first child is popped (visited) first.
            stack.extend(reversed(node.children))

    def to_dict(self) -> dict[str, Any]:
        """Serialize the tree as nested dictionaries starting at ``root``."""
        return self.root.to_dict()

    def to_text(self, indent: int = 2) -> str:
        """Render the tree as an indented bullet list.

        Each node yields a line ``- name (id): value unit``. The value uses
        thousands separators and no decimals, and ``-`` when unset. Nodes
        with a formula get an extra ``Formula:`` line.

        Args:
            indent: Number of spaces per depth level.
        """
        lines: list[str] = []

        def _render(node: KPINode, depth: int = 0) -> None:
            prefix = " " * (depth * indent)
            value_str = f"{node.value:,.0f}" if node.value is not None else "-"
            unit_str = node.unit if node.unit else ""

            lines.append(f"{prefix}- {node.name} ({node.id}): {value_str} {unit_str}")

            if node.formula:
                lines.append(f"{prefix}  Formula: {node.formula}")

            for child in node.children:
                _render(child, depth + 1)

        _render(self.root)
        return "\n".join(lines)

    def validate(self) -> list[str]:
        """Validate tree structure.

        Reports indexed nodes that are not reachable from ``root``. Also
        reports ``{name}`` placeholders in formulas that do not name an
        indexed node; each unknown name is reported once per formula.

        Returns:
            List of validation errors (empty when the tree is valid).
        """
        errors: list[str] = []

        # Collect ids reachable from the root. The traversal is iterative and
        # skips ids already seen, so a malformed cyclic structure cannot
        # recurse forever.
        visited: set[str] = set()
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node.id in visited:
                continue
            visited.add(node.id)
            stack.extend(node.children)

        orphans = set(self.nodes) - visited
        if orphans:
            errors.append(f"Orphan nodes found: {orphans}")

        # Check formula references; compile the pattern once, outside the loop.
        ref_pattern = re.compile(r"\{(\w+)\}")
        for node in self.nodes.values():
            if not node.formula:
                continue
            # dict.fromkeys de-duplicates while keeping first-seen order.
            for ref in dict.fromkeys(ref_pattern.findall(node.formula)):
                if ref not in self.nodes:
                    errors.append(f"Unknown variable '{ref}' in formula for '{node.id}'")

        return errors
