from pathlib import Path
from typing import Dict, List, Tuple
import uuid

from PIL import Image, ImageOps, ImageChops
from psd_tools import PSDImage
from psd_tools.api.layers import Layer

from app.utils.image import create_thumbnail_from_path
from app.core.logging import setup_logger
from app.core.constants import KEYWORDS, LOGO_LAYER, COLOR_LAYER_KEYWORDS

# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)

# Bounding box as (x1, y1, x2, y2) integers; _safe_bbox only ever returns
# boxes with x2 > x1 and y2 > y1.
BBox = Tuple[int, int, int, int]

def _safe_bbox(layer: Layer) -> BBox | None:
    """
    Safely extract a valid bounding box from a PSD layer.

    Returns:
        (x1, y1, x2, y2) if valid, otherwise None.
    """
    bbox = getattr(layer, "bbox", None)
    if not bbox:
        return None

    try:
        if isinstance(bbox, (tuple, list)):
            x1, y1, x2, y2 = map(int, bbox)
        else:
            x1, y1, x2, y2 = (
                int(bbox.x1),
                int(bbox.y1),
                int(bbox.x2),
                int(bbox.y2),
            )

        if x2 > x1 and y2 > y1:
            return x1, y1, x2, y2
    except Exception:
        pass

    return None

def _is_excluded_layer(name: str) -> bool:
    """
    Tell whether a layer name marks a logo/watermark layer to be excluded.

    The name is lower-cased (a falsy name becomes ``""``) and tested for
    substring containment of each entry in ``LOGO_LAYER``.

    NOTE(review): assumes LOGO_LAYER is a collection of lower-case keyword
    strings; if it were a single str, iteration would test individual
    characters — confirm against app.core.constants.
    """
    lowered = name.lower() if name else ""
    for keyword in LOGO_LAYER:
        if keyword in lowered:
            return True
    return False

def _is_logo_layer(layer: Layer) -> bool:
    """
    Return True if ``layer``'s name matches a logo/watermark keyword.

    Delegates to ``_is_excluded_layer`` so that design-layer exclusion and
    logo-layer collection always share one matching rule. A missing or
    ``None`` layer name is treated as an empty string.
    """
    return _is_excluded_layer(getattr(layer, "name", "") or "")

def _build_pure_mask(psd, layers, canvas_size) -> Image.Image:
    """
    Merge the coverage of several layers into a single inverted mask.

    Convention: BLACK = printable, WHITE = non-printable.

    Each layer's alpha is rendered in isolation via ``get_layer_alpha``,
    inverted, and multiplied into an all-white ``L`` canvas. A pixel
    therefore stays pure white only where every layer is fully transparent;
    opaque coverage from any layer darkens it.

    Args:
        psd: Opened PSD document.
        layers: Iterable of layers whose coverage forms the mask.
        canvas_size: ``(width, height)`` of the mask.

    Returns:
        ``L``-mode PIL image of size ``canvas_size`` (all white if ``layers``
        is empty).
    """
    result = Image.new("L", canvas_size, 255)
    for current in layers:
        inverted = ImageOps.invert(get_layer_alpha(psd, current, canvas_size))
        result = ImageChops.multiply(result, inverted)
    return result


def _set_layer_visibility(psd_layers, visible_layers: set[Layer]):
    for layer in psd_layers:
        try:
            layer.visible = layer in visible_layers
        except Exception:
            pass

def detect_design_layers(psd: PSDImage) -> List[Tuple[Layer, BBox]]:
    """
    Detect candidate design layers inside a PSD file.

    Rules:
    - Only effectively visible layers: the layer itself and every ancestor
      group below the document root must be visible, because a hidden group
      hides its children in the composite.
    - Excludes logos / stamps (names matched by ``_is_excluded_layer``).
    - Matches smart objects OR names containing any ``KEYWORDS`` entry
      (the layer name is lower-cased before matching).
    - Requires a valid, non-empty bounding box (see ``_safe_bbox``).
    - Removes parent container layers: a candidate whose bbox encloses a
      strictly smaller candidate's bbox is dropped, so only the innermost
      candidates remain.

    Layers that raise while being inspected are skipped (logged at DEBUG).

    Args:
        psd: Opened PSD document.

    Returns:
        List of (layer, bbox) tuples, in ``psd.descendants()`` order.
    """

    def _is_shown(node) -> bool:
        # Walk up to (but not including) the document; any hidden ancestor
        # group suppresses this layer when compositing.
        while node is not None and node is not psd:
            if hasattr(node, "visible") and not node.visible:
                return False
            node = getattr(node, "parent", None)
        return True

    def _area(box: BBox) -> int:
        x1, y1, x2, y2 = box
        return (x2 - x1) * (y2 - y1)

    def _encloses(outer: BBox, inner: BBox) -> bool:
        ox1, oy1, ox2, oy2 = outer
        ix1, iy1, ix2, iy2 = inner
        return (
            ox1 <= ix1 and oy1 <= iy1
            and ox2 >= ix2 and oy2 >= iy2
            and _area(outer) > _area(inner)
        )

    candidates: List[Tuple[Layer, BBox]] = []

    for layer in psd.descendants():
        try:
            if not _is_shown(layer):
                continue

            name = getattr(layer, "name", "") or ""
            # Explicitly exclude logo / watermark layers
            if _is_excluded_layer(name):
                continue

            lname = name.lower()
            matched = (
                getattr(layer, "kind", None) == "smartobject"
                or any(k in lname for k in KEYWORDS)
            )
            if not matched:
                continue

            bbox = _safe_bbox(layer)
            if bbox:
                candidates.append((layer, bbox))
        except Exception:
            # Best-effort scan: one unreadable layer must not abort detection.
            logger.debug("Skipping PSD layer during design detection", exc_info=True)
            continue

    # Remove parent container layers. A box never strictly encloses itself
    # (equal area), and the smallest-area candidate cannot enclose any other,
    # so the result is non-empty whenever candidates is non-empty.
    return [
        (lyr, box)
        for lyr, box in candidates
        if not any(_encloses(box, other) for _, other in candidates)
    ]

def get_layer_alpha(psd: PSDImage, layer: Layer, canvas_size) -> Image.Image:
    """
    Render the alpha channel of a single PSD layer in isolation.

    All layers are hidden. Then ``layer`` is shown together with its ancestor
    groups, because a hidden parent group would suppress it during
    compositing. If ``layer`` is a group, its descendants are also returned
    to the visibility they had on entry so the group has content to render.
    Every layer's visibility is restored before returning, even on error.

    Args:
        psd: Opened PSD document; its layers' ``visible`` flags are mutated
            temporarily.
        layer: Layer (or group) to render.
        canvas_size: ``(width, height)`` of the returned mask; the render is
            resized with LANCZOS when its size differs.

    Returns:
        Grayscale PIL Image (L mode), same size as canvas.
    """
    saved = {node: getattr(node, "visible", True) for node in psd.descendants()}

    def _set_visible(node, value) -> None:
        # Best-effort: layers whose visibility cannot be changed are skipped.
        try:
            node.visible = value
        except Exception:
            pass

    try:
        # Hide all layers
        for node in saved:
            _set_visible(node, False)

        # Show the target layer and re-enable its ancestor groups up to the
        # document root; otherwise nested layers composite to nothing.
        _set_visible(layer, True)
        parent = getattr(layer, "parent", None)
        while parent is not None and parent is not psd:
            _set_visible(parent, True)
            parent = getattr(parent, "parent", None)

        # Group target: bring its children back to their entry visibility.
        descendants = getattr(layer, "descendants", None)
        if callable(descendants):
            for child in descendants():
                _set_visible(child, saved.get(child, True))

        rendered = psd.composite().convert("RGBA")
        alpha = rendered.getchannel("A")

        if alpha.size != canvas_size:
            alpha = alpha.resize(canvas_size, Image.Resampling.LANCZOS)

        return alpha

    finally:
        # Restore original visibility
        for node, was_visible in saved.items():
            _set_visible(node, was_visible)

def get_layer_rgba(psd: PSDImage, layer: Layer, canvas_size) -> Image.Image:
    """
    Render a single PSD layer in isolation as an RGBA image.

    All layers are hidden. Then ``layer`` is shown together with its ancestor
    groups, because a hidden parent group would suppress it during
    compositing. If ``layer`` is a group, its descendants are also returned
    to the visibility they had on entry. Every layer's visibility is
    restored before returning, even on error.

    Args:
        psd: Opened PSD document; its layers' ``visible`` flags are mutated
            temporarily.
        layer: Layer (or group) to render.
        canvas_size: ``(width, height)`` of the returned image; the render is
            resized with LANCZOS when its size differs.

    Returns:
        RGBA PIL Image of size ``canvas_size``.
    """
    saved = {node: getattr(node, "visible", True) for node in psd.descendants()}

    def _set_visible(node, value) -> None:
        # Best-effort: layers whose visibility cannot be changed are skipped.
        try:
            node.visible = value
        except Exception:
            pass

    try:
        for node in saved:
            _set_visible(node, False)

        # Show the target and its ancestor chain up to the document root.
        _set_visible(layer, True)
        parent = getattr(layer, "parent", None)
        while parent is not None and parent is not psd:
            _set_visible(parent, True)
            parent = getattr(parent, "parent", None)

        # Group target: bring its children back to their entry visibility.
        descendants = getattr(layer, "descendants", None)
        if callable(descendants):
            for child in descendants():
                _set_visible(child, saved.get(child, True))

        rendered = psd.composite().convert("RGBA")
        if rendered.size != canvas_size:
            rendered = rendered.resize(canvas_size, Image.Resampling.LANCZOS)
        return rendered
    finally:
        for node, was_visible in saved.items():
            _set_visible(node, was_visible)

def generate_psd_assets(psd_path: Path, output_dir: Path) -> Dict[str, str | None]:
    """
    Generate derived image assets from a PSD template.

    Writes uniquely named PNG files (uuid4 prefix) into ``output_dir``,
    which is created if missing:

    - base image: the composite with design and logo layers hidden. Layers
      already hidden in the template stay hidden. A thumbnail is produced via
      ``create_thumbnail_from_path(..., (400, 400), preserve_aspect=True)``.
    - main black/white masks from the detected design layers. In the black
      mask, BLACK = printable and WHITE = non-printable; the white mask is its
      inverse.
    - logo black/white masks, only when visible logo layers exist.
    - color preview (isolated RGBA render of the color layer) and a binary
      white mask of its non-transparent pixels, only when a color layer exists.

    Layer visibility on the in-memory PSD is restored before returning.

    Args:
        psd_path: Path of the PSD template to read.
        output_dir: Directory that receives the generated files.

    Returns:
        Mapping of asset name to file path string; optional assets map to None.

    Raises:
        RuntimeError: if no design layers are detected, or if more than one
            color layer is found. Both checks run before any file is written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Generating PSD assets: %s", psd_path)

    psd = PSDImage.open(psd_path)
    canvas_size = psd.size
    all_layers = list(psd.descendants())

    # --- classify layers ONCE ---
    design_layers = [l for l, _ in detect_design_layers(psd)]
    if not design_layers:
        raise RuntimeError("No design layers detected in PSD")

    logo_layers = [
        l for l in all_layers
        if getattr(l, "visible", True) and _is_logo_layer(l)
    ]

    color_layers = [
        l for l in all_layers
        if any(k in (getattr(l, "name", "") or "").lower() for k in COLOR_LAYER_KEYWORDS)
        and _safe_bbox(l)
    ]
    # Validate up front so a rejected template leaves no partial output behind.
    if len(color_layers) > 1:
        raise RuntimeError("Only one color layer is supported")

    logger.debug(
        "PSD layers: %d design, %d logo, %d color",
        len(design_layers), len(logo_layers), len(color_layers),
    )

    # snapshot visibility
    vis_state = {l: getattr(l, "visible", True) for l in all_layers}

    try:
        # --- BASE IMAGE ---
        # Start from the template's own visibility: layers the designer hid
        # must stay hidden. Only design and logo layers are hidden additionally.
        hidden = set(design_layers) | set(logo_layers)
        originally_visible = {l for l, vis in vis_state.items() if vis}
        _set_layer_visibility(all_layers, originally_visible - hidden)

        base_image = psd.composite().convert("RGBA")
        base_path = output_dir / f"{uuid.uuid4().hex}_base_image.png"
        base_image.save(base_path)

        base_thumbnail_path = create_thumbnail_from_path(
            base_path, output_dir, (400, 400), preserve_aspect=True
        )

        # --- DESIGN MASKS ---
        pure_mask = _build_pure_mask(psd, design_layers, canvas_size)
        main_black_mask_path = output_dir / f"{uuid.uuid4().hex}_main_black_mask.png"
        pure_mask.save(main_black_mask_path)

        main_white_mask_path = output_dir / f"{uuid.uuid4().hex}_main_white_mask.png"
        ImageOps.invert(pure_mask).save(main_white_mask_path)

        # --- LOGO MASKS ---
        logo_black_mask_path = None
        logo_white_mask_path = None
        if logo_layers:
            logo_black_mask_path = output_dir / f"{uuid.uuid4().hex}_logo_black_mask.png"
            logo_white_mask_path = output_dir / f"{uuid.uuid4().hex}_logo_white_mask.png"
            logo_pure = _build_pure_mask(psd, logo_layers, canvas_size)
            logo_pure.save(logo_black_mask_path)
            ImageOps.invert(logo_pure).save(logo_white_mask_path)

        # --- COLOR LAYER (at most one, validated above) ---
        color_preview_path = None
        color_white_mask_path = None
        if color_layers:
            color_layer = color_layers[0]
            color_preview_path = output_dir / f"{uuid.uuid4().hex}_color_preview.png"
            color_white_mask_path = output_dir / f"{uuid.uuid4().hex}_color_white_mask.png"

            layer_rgba = get_layer_rgba(psd, color_layer, canvas_size)
            layer_rgba.save(color_preview_path)

            # Binarize coverage: any non-zero alpha -> 255, fully transparent -> 0.
            alpha = layer_rgba.getchannel("A")
            alpha.point(lambda p: 255 if p > 0 else 0, mode="L").save(color_white_mask_path)

        return {
            "base_image": str(base_path),
            "base_thumbnail": str(base_thumbnail_path),
            "main_black_mask_path": str(main_black_mask_path),
            "main_white_mask_path": str(main_white_mask_path),
            "logo_black_mask_path": str(logo_black_mask_path) if logo_black_mask_path else None,
            "logo_white_mask_path": str(logo_white_mask_path) if logo_white_mask_path else None,
            "color_preview_path": str(color_preview_path) if color_preview_path else None,
            "color_white_mask_path": str(color_white_mask_path) if color_white_mask_path else None,
        }

    finally:
        # restore visibility
        for layer, vis in vis_state.items():
            try:
                layer.visible = vis
            except Exception:
                pass
