from pathlib import Path
from PIL import Image
import numpy as np
import cv2
import uuid

def _make_curved_strip(
    overlay_rgba,
    target_w,
    target_h,
    edge_ratio=0.22,
    squeeze=0.72
):
    """
    Simulate cylindrical curvature on a flat overlay.

    The overlay is resized to ``(target_w, target_h)`` and split into a left
    edge band, a center band and a right edge band. Each edge band is
    horizontally scaled by ``squeeze``, the three bands are reassembled, and
    the result is stretched back to ``target_w``. With ``squeeze < 1`` this
    compresses content near the left/right edges and widens the center.

    Args:
        overlay_rgba: PIL RGBA image to curve.
        target_w: Output width in pixels.
        target_h: Output height in pixels.
        edge_ratio: Fraction of ``target_w`` used for EACH edge band. It is
            clamped so the two bands never overlap and the center keeps at
            least one pixel.
        squeeze: Horizontal scale applied to each edge band.

    Returns:
        A PIL RGBA image of size ``(target_w, target_h)``.
    """
    overlay = overlay_rgba.resize((target_w, target_h), Image.LANCZOS)

    # Clamp the band so that two bands plus a >= 1 px center fit in
    # target_w. Without this, edge_ratio >= 0.5 produced a right crop box
    # whose right edge lies left of its left edge.
    band = min(int(target_w * edge_ratio), (target_w - 1) // 2)
    if band <= 0:
        # Too narrow to split (or edge_ratio <= 0). Cropping zero-width
        # edges and forcing them to 1 px would add pixels that never
        # existed in the overlay, so return the plain resize instead.
        return overlay

    center_w = target_w - 2 * band

    left = overlay.crop((0, 0, band, target_h))
    center = overlay.crop((band, 0, band + center_w, target_h))
    right = overlay.crop((band + center_w, 0, target_w, target_h))

    squeezed_w = max(1, int(band * squeeze))
    left = left.resize((squeezed_w, target_h), Image.LANCZOS)
    right = right.resize((squeezed_w, target_h), Image.LANCZOS)

    curved = Image.new(
        "RGBA", (2 * squeezed_w + center_w, target_h), (0, 0, 0, 0)
    )
    curved.paste(left, (0, 0), left)
    curved.paste(center, (squeezed_w, 0), center)
    curved.paste(right, (squeezed_w + center_w, 0), right)

    # Stretch back to the requested width; the center absorbs the slack.
    return curved.resize((target_w, target_h), Image.LANCZOS)

def _apply_back_mug_squeeze(img, squeeze_factor=0.82):
    """
    Horizontally compress an image to fake perspective on the back mug.

    The image is scaled to ``squeeze_factor`` of its width (height is kept,
    so it stays vertically straight; no rotation). It is then pasted
    left-aligned onto a transparent canvas of the original size. The output
    therefore has the input's dimensions, and any area right of the
    squeezed content is fully transparent.

    Args:
        img: PIL RGBA image (its alpha band is used as the paste mask).
        squeeze_factor: Width scale; values < 1 compress.

    Returns:
        A new PIL RGBA image with the same size as ``img``.
    """
    w, h = img.size
    # Never ask PIL for a zero-width resize (narrow strips or small
    # factors); it rejects non-positive sizes.
    new_w = max(1, int(w * squeeze_factor))

    squeezed = img.resize((new_w, h), Image.LANCZOS)

    canvas = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    canvas.paste(squeezed, (0, 0), squeezed)

    return canvas

def _find_print_regions(mask_np, mask_thresh, min_area=20):
    """
    Locate the connected print areas of a grayscale mask.

    Args:
        mask_np: 2-D uint8 mask array.
        mask_thresh: Pixels strictly above this value count as printable.
        min_area: Components with fewer pixels are ignored as noise.

    Returns:
        ``(labels, regions)``. ``labels`` is OpenCV's per-pixel label image.
        ``regions`` is a list of ``(label, cx, x, y, w, h)`` tuples sorted
        left to right by centroid x.
    """
    _, binary = cv2.threshold(mask_np, mask_thresh, 255, cv2.THRESH_BINARY)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )

    regions = []
    for lbl in range(1, num_labels):  # label 0 is the background
        x, y, w, h, area = stats[lbl]
        if area < min_area:
            continue
        regions.append((lbl, centroids[lbl][0], int(x), int(y), int(w), int(h)))

    # Left -> right ordering (stable for equal centroids).
    regions.sort(key=lambda r: r[1])
    return labels, regions


def _slice_overlay(overlay, left_pct, right_pct):
    """Crop the full-height band [left_pct, right_pct] (% of width) of ``overlay``."""
    ow, oh = overlay.size
    # Keep at least one real column inside the overlay even for
    # left_pct >= 100 or negative percentages.
    lx = max(0, min(int((left_pct / 100.0) * ow), ow - 1))
    rx = int((right_pct / 100.0) * ow)
    rx = max(lx + 1, min(rx, ow))
    return overlay.crop((lx, 0, rx, oh))


def generate_mug_final_image(
    base_image_path,
    mask_path,
    overlay_path,
    output_dir,
    slice_ranges=None,
    edge_ratio=0.22,
    squeeze=0.72,
    back_squeeze_factor=0.82,
    mask_thresh=10
):
    """
    Composite an overlay onto every mug print area found in a mask.

    Each connected component of the thresholded mask (at least 20 px) is one
    print area. Areas are ordered left to right by centroid. Area ``i``
    receives the vertical band ``slice_ranges[i]`` of the overlay (percent of
    overlay width); extra areas reuse the last range. Each band is:

    - stretched to fill the area's bounding box;
    - given cylindrical curvature;
    - horizontally squeezed (no rotation), for the right-most area only and
      only when there are two or more areas.

    The combined layer's alpha is scaled by the original mask before
    compositing onto the base.

    Args:
        base_image_path: Path of the product photo.
        mask_path: Path of the print-area mask (resized to the base if needed).
        overlay_path: Path of the artwork.
        output_dir: Directory to write the PNG into (not created here).
        slice_ranges: Sequence of ``(left_pct, right_pct)`` pairs. ``None``
            or empty uses ``[(0, 40), (30, 80), (65, 100)]``.
        edge_ratio: Forwarded to ``_make_curved_strip``.
        squeeze: Forwarded to ``_make_curved_strip``.
        back_squeeze_factor: Forwarded to ``_apply_back_mug_squeeze``.
        mask_thresh: Binary threshold (0-255) used to detect print areas.

    Returns:
        The written PNG path as a string.
    """
    # Default wrap logic: front -> middle -> back. An empty sequence would
    # make slice_ranges[-1] raise IndexError, so treat it like None.
    if not slice_ranges:
        slice_ranges = [(0, 40), (30, 80), (65, 100)]

    with Image.open(base_image_path) as im:
        base = im.convert("RGBA")
    with Image.open(mask_path) as im:
        mask = im.convert("L")
    with Image.open(overlay_path) as im:
        overlay = im.convert("RGBA")

    if mask.size != base.size:
        mask = mask.resize(base.size, Image.LANCZOS)

    W, H = mask.size
    mask_np = np.array(mask)

    labels, regions = _find_print_regions(mask_np, mask_thresh)
    kept_labels = np.array([r[0] for r in regions])
    last = len(regions) - 1

    result = Image.new("RGBA", (W, H), (0, 0, 0, 0))

    for i, (lbl, _, x, y, w, h) in enumerate(regions):
        left_pct, right_pct = slice_ranges[min(i, len(slice_ranges) - 1)]

        # WIDTH-wise slice, stretched to fill the printable bounding box.
        overlay_slice = _slice_overlay(overlay, left_pct, right_pct)
        overlay_slice = overlay_slice.resize((w, h), Image.LANCZOS)

        curved = _make_curved_strip(
            overlay_slice,
            w,
            h,
            edge_ratio=edge_ratio,
            squeeze=squeeze
        )

        # Back/right mug: perspective squeeze only. A lone print area is
        # not a "back" mug; squeezing it would leave part of it uncovered.
        if i == last and len(regions) > 1:
            curved = _apply_back_mug_squeeze(
                curved,
                squeeze_factor=back_squeeze_factor
            )

        # Bounding boxes of neighbouring components can overlap. Clear the
        # alpha over pixels owned by OTHER kept regions so this paste does
        # not overwrite artwork already placed there.
        others = kept_labels[kept_labels != lbl]
        if others.size:
            foreign = np.isin(labels[y:y + h, x:x + w], others)
            if foreign.any():
                curved_np = np.array(curved)
                curved_np[foreign, 3] = 0
                curved = Image.fromarray(curved_np)

        result.paste(curved, (x, y), curved)

    # Clamp alpha by the original (soft) mask.
    result_np = np.array(result)
    result_np[:, :, 3] = (
        result_np[:, :, 3].astype(np.float32) *
        (mask_np.astype(np.float32) / 255.0)
    ).astype(np.uint8)

    final = Image.alpha_composite(
        base,
        Image.fromarray(result_np, "RGBA")
    )

    out_path = Path(output_dir) / f"{uuid.uuid4().hex}_final.png"
    final.convert("RGB").save(out_path, quality=98)

    return str(out_path)

def paste_by_mask_resize(base, mask, overlay):
    """
    Blend ``overlay`` into ``base`` inside the bounding box of ``mask``.

    The mask is resized to the base if needed. Its bounding box is the
    extent of pixels above 0.1 (normalized). The overlay is resized to that
    box and linearly blended into all four RGBA channels, weighted by the
    mask value at each pixel.

    Args:
        base: PIL image (converted to RGBA).
        mask: PIL image (converted to L).
        overlay: PIL image (converted to RGBA).

    Returns:
        A new PIL RGBA image.

    Raises:
        ValueError: If no mask pixel exceeds 0.1 after normalization.
    """
    base = base.convert("RGBA")
    overlay = overlay.convert("RGBA")
    mask = mask.convert("L")

    if mask.size != base.size:
        mask = mask.resize(base.size, Image.LANCZOS)

    mask_np = np.asarray(mask, dtype=np.float32) / 255.0

    coords = cv2.findNonZero((mask_np > 0.1).astype(np.uint8))
    if coords is None:
        raise ValueError("Mask empty")

    x, y, w, h = cv2.boundingRect(coords)

    # Resize the PIL overlay directly; converting it to a float array and
    # back to uint8 first was a no-op.
    overlay_np = np.asarray(
        overlay.resize((w, h), Image.LANCZOS), dtype=np.float32
    )

    result = np.array(base, dtype=np.float32)
    # (h, w, 1) weights broadcast across the 4 RGBA channels.
    m = mask_np[y:y + h, x:x + w, None]

    result[y:y + h, x:x + w] = result[y:y + h, x:x + w] * (1 - m) + overlay_np * m
    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8), "RGBA")

def generate_final_image_from_mask(
    base_image_path,
    mask_path,
    overlay_path,
    output_dir,
    unique_name=None
):
    """
    Render a product mock-up and return the path of the saved PNG.

    If ``unique_name`` contains "mug" or "cup" (case-insensitive), rendering
    is delegated to ``generate_mug_final_image`` with its default settings.
    Any other product is composited with a single ``paste_by_mask_resize``
    call and written to ``output_dir`` under a random file name.
    """
    product = (unique_name or "").lower()
    if any(token in product for token in ("mug", "cup")):
        return generate_mug_final_image(
            base_image_path,
            mask_path,
            overlay_path,
            output_dir,
        )

    # Flat (non-mug) products: one bounding-box paste, no curvature.
    final = paste_by_mask_resize(
        Image.open(base_image_path).convert("RGBA"),
        Image.open(mask_path).convert("L"),
        Image.open(overlay_path).convert("RGB"),
    )
    target = Path(output_dir) / f"{uuid.uuid4().hex}_final.png"
    final.convert("RGB").save(target, quality=98)
    return str(target)
