# ruff: noqa: ANN401
# pyright: reportAny=false, reportExplicitAny=false
"""Tree walker for Pydantic models.

Creates a printout of a Pydantic model structure as a tree.
This is useful for visualizing the structure of complex data models.
"""

import types
from collections.abc import Generator
from typing import Any, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from rich.console import Console
from rich.markup import escape
from rich.tree import Tree

console = Console()
# Palette cycled by nesting depth. Every entry must be a Rich color name;
# Rich has no plain "orange" (only "orange1", "orange3", "orange4",
# "dark_orange"), so "orange1" is used for that slot.
COLORS = ["cyan", "green", "yellow", "orange1", "red", "magenta", "blue"]
# Color used for the type-annotation part of a field label.
TYPE_COLOR = "bright_black"

# Type variable bound to pydantic's BaseModel.
Model = TypeVar("Model", bound=BaseModel)


def create_label(name: str, color: str, type_str: str | None = None) -> str:
    """Create a Rich-markup label for a tree node.

    Args:
        name: Text shown in ``color``. It is escaped, so names containing
            square brackets (such as ``Page[int]``) are displayed literally
            instead of being parsed as markup tags.
        color: Rich color/style name applied to ``name``.
        type_str: Optional type description appended after ``": "`` in
            ``TYPE_COLOR``. It is inserted verbatim, so callers must escape
            it themselves if it may contain markup-like text.

    Returns:
        The label as a Rich markup string.
    """
    label = f"[{color}]{escape(name)}[/{color}]"
    if type_str:
        label += f": [{TYPE_COLOR}]{type_str}[/{TYPE_COLOR}]"
    return label


def is_pydantic_model(type_: Any) -> bool:
    """Return True if ``type_`` is a class deriving from ``BaseModel``.

    Anything ``issubclass`` rejects with ``TypeError`` (for example instances
    or generic aliases such as ``list[int]``) is reported as "not a model"
    instead of propagating the error, so callers can probe arbitrary
    annotation objects safely.
    """
    try:
        outcome = issubclass(type_, BaseModel)
    except TypeError:
        # Not a class at all, so it cannot be a model.
        outcome = False
    return outcome


def get_model_fields(type_: type[Any]) -> Generator[tuple[str, Any]]:
    """Yield ``(field_name, annotation)`` pairs for a Pydantic model class.

    Pairs come out in ``model_fields`` order. For anything that is not a
    Pydantic model the generator yields nothing.
    """
    if not is_pydantic_model(type_):
        return
    yield from (
        (field_name, field_info.annotation)
        for field_name, field_info in type_.model_fields.items()
    )


def get_type_name(type_: Any) -> str:
    """Get the name of a type, or its string representation if it has no name."""
    try:
        return type_.__name__
    except AttributeError:
        return str(type_)


def type_to_string(type_: type) -> str:
    """Convert a type to a string representation.

    This function can handle generic types like `list[int]` and `Union[str, int]`,
    which is essential for displaying complex type hints in a readable format.
    """
    origin = get_origin(type_)
    if origin is None:
        return get_type_name(type_)

    args_str = ", ".join(type_to_string(t) for t in get_args(type_))
    base_type = get_type_name(origin)
    return f"{base_type}[{args_str}]"


def extract_model_types(type_: type) -> Generator[type[Any]]:
    """Yield every Pydantic model class found inside a (possibly nested) type.

    The type is walked depth-first in argument order, and ``type_`` itself is
    checked before its arguments. For instance ``list[Union[ModelA, str,
    ModelB]]`` yields ``ModelA`` and then ``ModelB``, which lets the tree
    builder reach models hidden inside generic containers.
    """
    pending = [type_]
    while pending:
        current = pending.pop()
        if is_pydantic_model(current):
            yield current
        # Push arguments in reverse so the first argument is popped next,
        # reproducing a left-to-right pre-order traversal.
        pending.extend(reversed(get_args(current)))


def build_tree(
    model: type[Any],
    tree: Tree,
    level: int = 0,
    ancestors: frozenset[type[Any]] = frozenset(),
) -> None:
    """Recursively build a tree representation of a model.

    Each field of ``model`` is added to ``tree`` as ``name: type``.
    - A field whose annotation is exactly a model is expanded directly
      beneath the field.
    - Models found inside other annotations (``list[A]``, ``A | B``,
      ``dict[str, A]``, ...) each get an intermediate node named after the
      model, and that node is then expanded.

    Args:
        model: The model class whose fields are rendered.
        tree: The Rich tree node the fields are attached to.
        level: Current depth; selects the label color from ``COLORS``
            (wrapping around when deeper than the palette).
        ancestors: Models already being expanded on the path from the root
            to ``model``. A model met again on this path is marked
            ``(recursive)`` instead of being expanded, so self-referential
            models terminate.
    """
    path = ancestors | {model}
    color = COLORS[level % len(COLORS)]
    # Modulo keeps deeply nested levels inside the palette (the original
    # indexed COLORS[level + 1] directly and could raise IndexError).
    sub_color = COLORS[(level + 1) % len(COLORS)]

    for name, field_type in get_model_fields(model):
        type_str = escape(type_to_string(field_type))
        child_tree = tree.add(create_label(name, color, type_str))
        # dict.fromkeys drops duplicates (e.g. tuple[A, A]) while keeping
        # first-seen order.
        model_types = list(dict.fromkeys(extract_model_types(field_type)))

        if len(model_types) == 1 and model_types[0] is field_type:
            # The field is a bare model: expand it directly under the field.
            _expand_model(model_types[0], child_tree, level + 1, path)
            continue

        for model_type in model_types:
            sub_tree = child_tree.add(create_label(model_type.__name__, sub_color))
            _expand_model(model_type, sub_tree, level + 2, path)


def _expand_model(
    model: type[Any], tree: Tree, level: int, path: frozenset[type[Any]]
) -> None:
    """Expand ``model`` under ``tree``, or mark it if it is already on ``path``."""
    if model in path:
        tree.add(create_label("(recursive)", TYPE_COLOR))
        return
    build_tree(model, tree, level, path)


def display_tree(model: type[Any]) -> None:
    """Render ``model`` as a colored Rich tree and print it to the console.

    The root node shows the model's class name. ``build_tree`` then adds its
    fields beneath it, and each extra level of indentation is one step
    deeper into the model.
    """
    root_label = create_label(model.__name__, COLORS[0])
    root = Tree(root_label)
    build_tree(model, root)
    console.print(root)


if __name__ == "__main__":

    class InnerNode(BaseModel):
        """Inner node model for testing."""

        name: str
        id_number: int

    class AnotherInnerNode(BaseModel):
        """Another inner node model for testing."""

        age: int
        weight: float

    class Node(BaseModel):
        """Node model for testing."""

        name: str
        inner: InnerNode
        many_inner: list[InnerNode]
        union_inner: InnerNode | AnotherInnerNode
        many_union_inner: list[InnerNode | AnotherInnerNode]
        dict_union_inner: dict[str, InnerNode]

    class Root(BaseModel):
        """Root model for testing."""

        name: dict[str, InnerNode]
        child: Node

    display_tree(Root)
