"""KPI calculation engine with safe formula evaluation."""

from __future__ import annotations

import ast
import logging
import re
from typing import Any

from .tree import KPITree, KPINode


class FormulaError(Exception):
    """Error in formula parsing or evaluation."""


class FormulaParser:
    """Safe formula parser using AST.

    Only allows basic arithmetic operations: +, -, *, /
    Variables are referenced as {variable_name}

    During evaluation each ``{variable}`` is bound to its value through a
    private, parenthesised identifier rather than being pasted into the
    formula text as a number. This keeps non-finite values (``nan``/``inf``)
    usable and makes malformed formulas such as ``{a}{b}`` or ``2{a}`` fail
    with a FormulaError instead of silently concatenating digits.
    """

    # Allowed AST node types (references to bound variables are accepted
    # separately in _validate_ast).
    ALLOWED_NODES = (
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.Constant,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.USub,
        ast.UAdd,
    )

    # Stem of the private identifiers that stand in for {variable} references.
    _PLACEHOLDER_STEM = "_kpi_var_"

    def __init__(self) -> None:
        """Initialize parser."""
        self._variable_pattern = re.compile(r"\{(\w+)\}")

    def parse_variables(self, formula: str) -> list[str]:
        """Extract variable names from formula.

        Args:
            formula: Formula string with {variable} syntax.

        Returns:
            List of variable names in order of appearance (duplicates kept).
        """
        return self._variable_pattern.findall(formula)

    def substitute_variables(self, formula: str, values: dict[str, float]) -> str:
        """Replace variables with values (textual helper; not used by evaluate).

        Args:
            formula: Formula string with {variable} syntax.
            values: Dictionary mapping variable names to values.

        Returns:
            Formula with variables replaced by ``str(value)``.

        Raises:
            FormulaError: If a variable is missing.
        """
        result = formula

        for var in self.parse_variables(formula):
            if var not in values:
                raise FormulaError(f"Missing variable: {var}")
            result = result.replace(f"{{{var}}}", str(values[var]))

        return result

    def evaluate(self, formula: str, values: dict[str, float]) -> float:
        """Evaluate formula with given values.

        Args:
            formula: Formula string with {variable} syntax.
            values: Dictionary mapping variable names to values.

        Returns:
            Calculated result. Division by zero yields 0.0.

        Raises:
            FormulaError: If a variable is missing or non-numeric, or the
                formula is syntactically invalid, uses disallowed syntax,
                contains a non-numeric literal, or is nested too deeply.
        """
        expression, bindings = self._bind_variables(formula, values)

        try:
            tree = ast.parse(expression, mode="eval")
        except (SyntaxError, ValueError, RecursionError) as e:
            # ValueError covers e.g. NUL bytes on some Python versions;
            # RecursionError covers pathologically deep nesting.
            raise FormulaError(f"Invalid formula syntax: {e}") from e

        self._validate_ast(tree, frozenset(bindings))

        try:
            return self._eval_node(tree.body, bindings)
        except RecursionError as e:
            raise FormulaError("Formula is nested too deeply to evaluate") from e

    def _bind_variables(
        self, formula: str, values: dict[str, float]
    ) -> tuple[str, dict[str, float]]:
        """Swap each {variable} for a private identifier bound to its value.

        Args:
            formula: Formula string with {variable} syntax.
            values: Dictionary mapping variable names to values.

        Returns:
            The rewritten expression and a mapping of identifier -> value.

        Raises:
            FormulaError: If a variable is missing or not convertible to float.
        """
        # Pick a stem that never occurs in the user's text, so a user-written
        # name can never collide with a generated one.
        stem = self._PLACEHOLDER_STEM
        while stem in formula:
            stem = "_" + stem

        slots: dict[str, str] = {}
        bindings: dict[str, float] = {}

        def _replace(match: re.Match[str]) -> str:
            var = match.group(1)
            if var not in values:
                raise FormulaError(f"Missing variable: {var}")
            name = slots.get(var)
            if name is None:
                name = f"{stem}{len(slots)}"
                slots[var] = name
                bindings[name] = self._to_number(var, values[var])
            # Parentheses keep the placeholder a standalone token, so text
            # glued to it (e.g. "{a}{b}", "2{a}") cannot merge with it.
            return f"({name})"

        return self._variable_pattern.sub(_replace, formula), bindings

    @staticmethod
    def _to_number(var: str, value: Any) -> float:
        """Convert a variable's value to float, raising FormulaError on failure."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as e:
            raise FormulaError(
                f"Invalid value for variable {var}: {value!r}"
            ) from e

    def _validate_ast(
        self, tree: ast.AST, allowed_names: frozenset[str] = frozenset()
    ) -> None:
        """Validate that AST only contains allowed node types.

        Args:
            tree: AST to validate.
            allowed_names: Identifiers that may appear as variable references.

        Raises:
            FormulaError: If tree contains disallowed nodes.
        """
        for node in ast.walk(tree):
            if isinstance(node, self.ALLOWED_NODES):
                continue
            if isinstance(node, ast.Name) and node.id in allowed_names:
                continue
            # Load is only the context marker carried by a name-like node;
            # that node is checked on its own, so accepting Load admits
            # nothing new.
            if isinstance(node, ast.Load):
                continue
            raise FormulaError(
                f"Invalid expression: {type(node).__name__} is not allowed"
            )

    def _eval_node(
        self, node: ast.AST, variables: dict[str, float] | None = None
    ) -> float:
        """Evaluate a single AST node.

        Args:
            node: AST node to evaluate.
            variables: Values of the bound variable identifiers.

        Returns:
            Evaluated value.

        Raises:
            FormulaError: If node type is not supported, a literal is not a
                real number, or a name is not a bound variable.
        """
        if isinstance(node, ast.Constant):
            value = node.value
            # ast.Constant also covers str/bytes/None/complex/Ellipsis; only
            # real-number literals are meaningful in a formula.
            if not isinstance(value, (int, float)):
                raise FormulaError(f"Invalid constant: {value!r}")
            try:
                return float(value)
            except OverflowError as e:
                raise FormulaError("Numeric literal out of range") from e

        if isinstance(node, ast.Name):
            if variables is None or node.id not in variables:
                raise FormulaError(f"Unknown variable reference: {node.id}")
            return variables[node.id]

        if isinstance(node, ast.BinOp):
            left = self._eval_node(node.left, variables)
            right = self._eval_node(node.right, variables)

            if isinstance(node.op, ast.Add):
                return left + right
            if isinstance(node.op, ast.Sub):
                return left - right
            if isinstance(node.op, ast.Mult):
                return left * right
            if isinstance(node.op, ast.Div):
                if right == 0:
                    return 0.0  # Division by zero deliberately yields 0.0
                return left / right

        if isinstance(node, ast.UnaryOp):
            operand = self._eval_node(node.operand, variables)
            if isinstance(node.op, ast.USub):
                return -operand
            if isinstance(node.op, ast.UAdd):
                return operand

        raise FormulaError(f"Unsupported node type: {type(node).__name__}")


class KPICalculator:
    """KPI calculation engine.

    Evaluates node formulas over a KPI tree, writing computed values back to
    formula nodes (``node.value``) and returning all values as a dict.
    """

    def __init__(self, tree: KPITree):
        """Initialize with KPI tree.

        Args:
            tree: KPI tree to calculate.
        """
        self.tree = tree
        self.parser = FormulaParser()

    def calculate(self) -> dict[str, float]:
        """Calculate all KPIs in the tree.

        Nodes are processed in dependency order (see _get_calculation_order),
        so a formula's inputs are computed before the formula itself.

        - Nodes with a formula are evaluated, and the result is stored on
          ``node.value``.
        - Nodes without a formula contribute their existing ``value``, or
          0.0 when it is None.
        - A formula that fails to evaluate is logged as a warning and
          contributes 0.0.

        Returns:
            Dictionary mapping node ID to calculated value.
        """
        order = self._get_calculation_order()

        results: dict[str, float] = {}

        for node_id in order:
            node = self.tree.nodes[node_id]

            if node.formula:
                try:
                    value = self.parser.evaluate(node.formula, results)
                except FormulaError as e:
                    # Best-effort: one broken formula must not abort the tree.
                    logging.getLogger(__name__).warning(
                        "Failed to calculate %s: %s", node_id, e
                    )
                    results[node_id] = 0.0
                else:
                    node.value = value
                    results[node_id] = value
            elif node.value is not None:
                results[node_id] = node.value
            else:
                # No formula and no value available
                results[node_id] = 0.0

        return results

    def _get_calculation_order(self) -> list[str]:
        """Get nodes in topological order for calculation.

        Performs a depth-first walk from the root. A node is emitted only
        after its children and the nodes its formula references. This holds
        even when a referenced node is not a child, e.g. a sibling used in a
        ratio. The root is emitted last.

        Returns:
            List of node IDs in calculation order.
        """
        order: list[str] = []
        visited: set[str] = set()
        nodes = self.tree.nodes

        def visit(node: KPINode) -> None:
            if node.id in visited:
                return
            # Mark before recursing so cyclic references terminate. In that
            # case, the node on the cycle whose input is not yet computed
            # falls back to 0.0 in calculate() via a missing-variable error.
            visited.add(node.id)

            # Children first (structural dependencies)
            for child in node.children:
                visit(child)

            # Then formula inputs, which need not be children
            if node.formula:
                for dep_id in self.parser.parse_variables(node.formula):
                    if dep_id in nodes:
                        visit(nodes[dep_id])

            order.append(node.id)

        visit(self.tree.root)
        return order

    def simulate(self, changes: dict[str, float]) -> dict[str, float]:
        """Simulate KPI changes without modifying the tree.

        Every node's ``value`` is saved before applying ``changes`` and is
        restored afterwards, even if calculation raises. Note that calculate()
        recomputes nodes that have a formula. A change applied to such a node
        is therefore replaced by its formula result; overrides only take
        effect on nodes without a formula.

        Args:
            changes: Dictionary of node ID to new value. Unknown IDs are ignored.

        Returns:
            Dictionary mapping node ID to simulated value.
        """
        saved_values = {n.id: n.value for n in self.tree.nodes.values()}

        try:
            for node_id, value in changes.items():
                if node_id in self.tree.nodes:
                    self.tree.nodes[node_id].value = value

            return self.calculate()
        finally:
            # Always undo both the applied changes and the values that
            # calculate() wrote onto formula nodes.
            for node_id, value in saved_values.items():
                self.tree.nodes[node_id].value = value

    def get_formula_dependencies(self, node_id: str) -> list[str]:
        """Get all dependencies for a node's formula.

        Args:
            node_id: Node ID to check.

        Returns:
            Variable names referenced by the node's formula, in order of
            appearance. Returns an empty list if the node does not exist or
            has no formula.
        """
        node = self.tree.get_node(node_id)
        if not node or not node.formula:
            return []

        return self.parser.parse_variables(node.formula)
