"""
FastAPI router for RAG-based chat: submit a message for asynchronous agent processing, then poll the resulting task's status.
"""

import asyncio
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any, Optional

from models.rag import RAGChatRequest, RAGChatResponse, RAGChatStatusResponse
from core.dependencies import get_rag_crew, get_task_manager, get_logger, get_conversation_handler

# Every endpoint below is mounted under the /rag-chat prefix and grouped under
# the "rag" tag in the generated OpenAPI documentation.
router = APIRouter(prefix="/rag-chat", tags=["rag"])


def _resolve_conversation(conversation_handler, conversation_id, logger):
    """Return ``(conversation_id, is_new_conversation)`` for an incoming chat.

    - No ID supplied: the handler creates a fresh conversation.
    - ID supplied but the handler reports no history for it: a conversation is
      created with that ID, and the ID returned by the handler is used.
    - ID supplied with existing history: that conversation is continued.
    """
    if not conversation_id:
        new_id = conversation_handler.create_conversation()
        logger.info(f"Created new conversation with ID: {new_id[:8]}")
        return new_id, True

    # An empty history is treated as "conversation does not exist".
    # NOTE(review): a conversation that exists but has no messages yet cannot be
    # told apart from a missing one here — confirm against the handler.
    history = conversation_handler.get_conversation_history(conversation_id)
    if not history:
        new_id = conversation_handler.create_conversation(conversation_id)
        logger.info(f"Provided conversation ID not found, created new one: {new_id[:8]}")
        return new_id, True

    logger.info(f"Continuing conversation with ID: {conversation_id[:8]}")
    return conversation_id, False


@router.post("", response_model=RAGChatResponse)
async def rag_chat_endpoint(
    chat: RAGChatRequest,
    rag_crew=Depends(get_rag_crew),
    task_manager=Depends(get_task_manager),
    conversation_handler=Depends(get_conversation_handler),
    logger=Depends(get_logger)
):
    """Accept a chat message and schedule RAG agent processing in the background.

    Resolves (or creates) the conversation and records the user's message in its
    history. It then submits ``rag_crew.process_query`` to the task manager.
    The response does not contain the answer. It carries a ``task_id`` for the
    client to poll via ``GET /rag-chat/status/{task_id}``, plus the
    ``conversation_id`` to send with follow-up messages.

    Raises:
        HTTPException: re-raised unchanged if one occurs while handling the
            request; any other exception is logged and converted into a 500
            whose detail is the exception message.
    """
    try:
        message = chat.message
        # Truncate long messages in the log, but always keep the prefix.
        preview = f"{message[:50]}..." if len(message) > 50 else message
        logger.info(f"Received chat request: {preview}")

        conversation_id, is_new_conversation = _resolve_conversation(
            conversation_handler, chat.conversation_id, logger
        )

        # Record the user's turn before the agents run.
        conversation_handler.add_user_message(
            conversation_id=conversation_id,
            content=message,
        )

        # Agent names are passed along so the task manager can track per-agent progress.
        agent_names = rag_crew.agent_names

        task_id = await task_manager.create_task(
            rag_crew.process_query(message, conversation_id, is_new_conversation),
            agent_names=agent_names,
        )

        logger.info(f"Created task {task_id[:8]} for processing query in conversation {conversation_id[:8]}")

        # The actual reply is obtained later by polling the status endpoint.
        return RAGChatResponse(
            reply="Your request is being processed by our AI agents. You can check the status using the task ID.",
            task_id=task_id,
            sources=[],
            conversation_id=conversation_id,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (status code + detail) intact instead of
        # flattening them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in rag_chat_endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/status/{task_id}", response_model=RAGChatStatusResponse)
async def check_rag_chat_status(
    task_id: str,
    task_manager=Depends(get_task_manager),
    logger=Depends(get_logger)
):
    """Report the current state of a RAG chat task.

    Looks ``task_id`` up in the task manager and maps its status record onto a
    ``RAGChatStatusResponse``. When missing from the record, ``progress``
    defaults to 0, ``result`` to None and ``agent_statuses`` to an empty list.

    Raises:
        HTTPException: 404 when the task manager has no status for ``task_id``;
            500 (detail = exception message) for any other failure, including
            a record that lacks a ``"status"`` entry.
    """
    try:
        snapshot = task_manager.get_task_status(task_id)
        if not snapshot:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        return RAGChatStatusResponse(
            status=snapshot["status"],
            progress=snapshot.get("progress", 0),
            result=snapshot.get("result"),
            agent_statuses=snapshot.get("agent_statuses", []),
        )
    except HTTPException:
        # The 404 above (or any other HTTP error) passes through untouched.
        raise
    except Exception as exc:
        logger.error(f"Error checking task status: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
