from crewai import Crew, Task, Process
import time
import asyncio
import traceback
from typing import Dict, Any, List, Optional
from agents.query_understanding_agent import QueryUnderstandingAgent
from agents.rag_retrieval_agent import RAGRetrievalAgent
from agents.response_generation_agent import ResponseGenerationAgent
from utils.vector_db_interface import VectorDBInterface
from utils.chroma_db import ChromaDBManager
from utils.logger import Logger, app_logger
from utils.task_status import TaskStatus
from utils.async_task_manager import AsyncTaskManager
from utils.optimized_conversation_memory import optimized_conversation_memory as conversation_memory
from utils.conversation_handler import conversation_handler
from utils.enhanced_response_generator import enhanced_response_generator

class RAGCrew:
    """
    A crew of agents that work together to provide RAG-based question answering.

    ``process_query`` handles a query along one of these paths:

    * fast path: a message that is nothing but a greeting ("hi", "Good morning!")
      is answered immediately, without creating any agent;
    * "smalltalk" / "abuse" routes (taken from ``query_analysis["route"]`` as
      produced by the query-understanding agent): retrieval is skipped and
      ``enhanced_response_generator`` answers directly;
    * default RAG route: query understanding -> retrieval -> response generation.

    Agents are created lazily, on the first query that actually needs them.
    """

    # Positions in ``agent_names`` (and of the matching agent attributes) used
    # when reporting progress to the task manager.
    _QUERY_UNDERSTANDING = 0
    _RAG_RETRIEVAL = 1
    _RESPONSE_GENERATION = 2

    # Messages that consist *only* of a greeting. The query is lower-cased,
    # stripped of ASCII punctuation and whitespace-collapsed before comparison,
    # so "Hello!" matches but "hi, what services do you offer?" does not.
    _SIMPLE_GREETINGS = frozenset({
        "hi", "hello", "hey", "good morning", "good afternoon",
        "good evening", "hi there", "hello there",
    })

    def __init__(self, vector_db: Optional[VectorDBInterface] = None, task_manager: Optional[AsyncTaskManager] = None, logger: Optional[Logger] = None):
        """
        Initialize the RAGCrew without creating any agents yet.

        Args:
            vector_db: Optional vector database interface; a ``ChromaDBManager``
                is created when omitted.
            task_manager: Optional task manager used to report per-agent progress.
            logger: Optional logger instance; defaults to ``app_logger``.
        """
        # Use the provided vector database or create a default ChromaDB instance
        self.vector_db = vector_db if vector_db else ChromaDBManager()
        self.logger = logger or app_logger
        self.task_manager = task_manager

        # Agent names for progress tracking. Used as the reported name when an
        # agent has not been instantiated yet (e.g. on the greeting fast path).
        # NOTE(review): assumes these mirror each agent class's ``.name`` — confirm.
        self.agent_names = [
            "QueryUnderstandingAgent",
            "RAGRetrievalAgent",
            "ResponseGenerationAgent"
        ]

        # Lazy initialization flag
        self._agents_initialized = False

        # Public properties that will be lazily initialized by _initialize_agents
        self.query_understanding = None
        self.rag_retrieval = None
        self.response_generation = None
        self.query_understanding_agent = None
        self.rag_retrieval_agent = None
        self.response_generation_agent = None
        # crewAI Crew; built during initialization but not used by process_query,
        # which drives the agent wrappers directly.
        self.crew = None

        self.logger.info("RAGCrew initialized with lazy loading enabled")

    def _initialize_agents(self):
        """
        Lazily create the agent wrappers, their crewAI agents, tasks and crew.

        Returns immediately once initialization has succeeded. If construction
        raises, the initialized flag is reset before re-raising so the next call
        retries instead of leaving the crew permanently half-built.
        """
        if self._agents_initialized:
            return

        # Set the flag first so any re-entrant call during construction is a no-op.
        self._agents_initialized = True

        try:
            self.logger.info("Initializing agents for the first time")

            # Initialize agent classes with shared logger
            self.query_understanding = QueryUnderstandingAgent(logger=self.logger)
            self.rag_retrieval = RAGRetrievalAgent(vector_db=self.vector_db, logger=self.logger)
            self.response_generation = ResponseGenerationAgent(logger=self.logger)

            # Initialize crewAI agents
            self.query_understanding_agent = self.query_understanding.create_agent()
            self.rag_retrieval_agent = self.rag_retrieval.create_agent()
            self.response_generation_agent = self.response_generation.create_agent()

            # Tasks are built inline here; each later task receives the earlier
            # ones as context.
            understand_query_task = Task(
                description="Analyze the user's query to extract key information for retrieval",
                expected_output="A detailed analysis of the query including keywords, entities, and topics",
                agent=self.query_understanding_agent
            )

            retrieve_information_task = Task(
                description="Retrieve relevant information from the knowledge base based on the query analysis",
                expected_output="A set of relevant documents and passages from the knowledge base",
                agent=self.rag_retrieval_agent,
                context=[understand_query_task]
            )

            generate_response_task = Task(
                description="Generate a comprehensive and accurate response based on the retrieved information",
                expected_output="A well-crafted response that addresses the user's query",
                agent=self.response_generation_agent,
                context=[understand_query_task, retrieve_information_task]
            )

            # Create the crew; tasks will be executed in sequence
            self.crew = Crew(
                agents=[
                    self.query_understanding_agent,
                    self.rag_retrieval_agent,
                    self.response_generation_agent
                ],
                tasks=[
                    understand_query_task,
                    retrieve_information_task,
                    generate_response_task
                ],
                verbose=True,
                process=Process.sequential
            )
        except Exception:
            # Allow a retry on the next query instead of using None agents forever.
            self._agents_initialized = False
            raise

        self.logger.info("Agents initialized successfully")

    @classmethod
    def _is_simple_greeting(cls, query: str) -> bool:
        """Return True when *query* is nothing but a greeting from ``_SIMPLE_GREETINGS``."""
        import string

        without_punctuation = query.lower().translate(str.maketrans("", "", string.punctuation))
        normalized = " ".join(without_punctuation.split())
        return normalized in cls._SIMPLE_GREETINGS

    def _agent_name(self, index: int) -> str:
        """
        Return the name reported to the task manager for the agent at *index*.

        Prefers the instantiated agent's ``name`` and falls back to
        ``self.agent_names`` while agents are not yet created.
        """
        agent = (self.query_understanding, self.rag_retrieval, self.response_generation)[index]
        name = getattr(agent, "name", None)
        return name if name else self.agent_names[index]

    def _set_agent_status(self, task_id: Optional[str], index: int, status: Any, error: Optional[str] = None) -> None:
        """Report *status* for the agent at *index*; a no-op without a task manager or task id."""
        if not (self.task_manager and task_id):
            return
        name = self._agent_name(index)
        if error is None:
            self.task_manager.update_agent_status(task_id, name, status)
        else:
            self.task_manager.update_agent_status(task_id, name, status, error)

    def _record_assistant_message(self, conversation_id: str, content: str) -> None:
        """
        Append an assistant message to the conversation history (best effort).

        Tries ``conversation_handler`` first and falls back to the optimized
        conversation memory; failures are logged, never raised.
        """
        agent_name = self._agent_name(self._RESPONSE_GENERATION)
        try:
            conversation_handler.add_assistant_message(
                conversation_id=conversation_id,
                content=content
            )
            self.logger.info(f"{agent_name}: Added response to conversation history using handler")
        except Exception as history_error:
            self.logger.error(f"Error adding to conversation history: {str(history_error)}")
            self.logger.error(f"Error details: {traceback.format_exc()}")

            # Fallback to direct conversation memory as a backup
            try:
                conversation_memory.add_message(
                    conversation_id=conversation_id,
                    content=content,
                    role="assistant"
                )
                self.logger.info(f"{agent_name}: Added response using fallback method")
            except Exception as fallback_error:
                self.logger.error(f"Fallback also failed: {str(fallback_error)}")
                self.logger.error(f"Fallback error details: {traceback.format_exc()}")

    def _greeting_response(self, conversation_id: Optional[str], task_id: Optional[str], start_time: float) -> Dict[str, Any]:
        """
        Build the fast-path reply to a bare greeting without creating any agent.

        The salutation depends on the server's local hour. The reply is
        recorded in the conversation history, and agents are reported as
        skipped/completed, when a conversation id / task id are provided.
        """
        import datetime

        self.logger.info("Detected simple greeting, using fast path response")

        current_hour = datetime.datetime.now().hour
        if 5 <= current_hour < 12:
            salutation = "Good morning"
        elif 12 <= current_hour < 17:
            salutation = "Good afternoon"
        else:
            salutation = "Good evening"

        response_text = f"{salutation}! Welcome to MangoIT Solutions. How can I assist you today?"

        if conversation_id:
            self._record_assistant_message(conversation_id, response_text)

        # Agents may not exist yet here; _agent_name falls back to agent_names.
        if self.task_manager and task_id:
            self._set_agent_status(task_id, self._QUERY_UNDERSTANDING, TaskStatus.SKIPPED)
            self._set_agent_status(task_id, self._RAG_RETRIEVAL, TaskStatus.SKIPPED)
            self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.COMPLETED)

        execution_time = time.time() - start_time
        self.logger.info(f"Greeting processed in {execution_time:.2f}s using fast path")

        return {
            "response": response_text,
            "sources": [],
            "execution_time": execution_time
        }

    async def process_query(self, query: str, conversation_id: Optional[str] = None, is_new_conversation: bool = False, task_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a user query through the agent crew using async methods.

        Args:
            query: The user's question or query
            conversation_id: Optional conversation ID for history tracking
            is_new_conversation: Whether this is a new conversation (first message)
            task_id: Optional task ID for progress tracking

        Returns:
            Dict[str, Any]: ``{"response", "sources", "execution_time"}`` on
            success, or the dict from ``_create_error_response`` (which also
            carries an ``"error"`` key) on failure. Exceptions are not raised.
        """
        preview = f"{query[:50]}..." if len(query) > 50 else query
        self.logger.info(f"Processing query: {preview}")
        start_time = time.time()

        # Index of the agent currently reported as RUNNING, so an unexpected
        # exception can mark it FAILED instead of leaving it stuck.
        running_agent = None

        try:
            # Fast path for bare greetings to improve initial response time
            if self._is_simple_greeting(query):
                return self._greeting_response(conversation_id, task_id, start_time)

            # Initialize context with the query, conversation ID, and new conversation flag
            context = {
                "query": query,
                "conversation_id": conversation_id,
                "is_new_conversation": is_new_conversation
            }

            # Initialize agents only if needed for non-greeting queries
            if not self._agents_initialized:
                self.logger.info("Lazy initializing agents for non-greeting query")
                self._initialize_agents()

            # Step 1: Query Understanding (async)
            running_agent = self._QUERY_UNDERSTANDING
            self._set_agent_status(task_id, running_agent, TaskStatus.RUNNING)
            self.logger.info(f"Running {self.query_understanding.name}")
            context = await self.query_understanding.arun(context)

            if context.get("error"):
                self.logger.error(f"Error in {self.query_understanding.name}: {context['error']}")
                self._set_agent_status(task_id, self._QUERY_UNDERSTANDING, TaskStatus.FAILED, context["error"])
                return self._create_error_response(f"Error understanding query: {context['error']}")

            self._set_agent_status(task_id, self._QUERY_UNDERSTANDING, TaskStatus.COMPLETED)
            running_agent = None

            # Get query analysis and check for routing
            query_analysis = context.get("query_analysis", {})
            route = query_analysis.get("route", "rag")

            if route == "smalltalk":
                # Skip retrieval for smalltalk
                self.logger.info("Detected smalltalk, skipping retrieval")
                running_agent = self._RESPONSE_GENERATION
                if self.task_manager and task_id:
                    self._set_agent_status(task_id, self._RAG_RETRIEVAL, TaskStatus.SKIPPED)
                    self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.RUNNING)

                self.logger.info(f"Running enhanced response generator for smalltalk")
                conversation_history = self._get_conversation_history(conversation_id)
                response = await enhanced_response_generator.generate_response(
                    query=query, results={"results": [], "route": "smalltalk"}, conversation_history=conversation_history
                )
                response = response or {}

                # Record the reply like the greeting and RAG paths do.
                if conversation_id and "response" in response:
                    self._record_assistant_message(conversation_id, response["response"])

                self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.COMPLETED)
                running_agent = None

                execution_time = time.time() - start_time
                self.logger.info(f"Smalltalk query processed in {execution_time:.2f}s")

                return {
                    "response": response.get("response", "Hello! How can I help you today?"),
                    "sources": [],
                    "execution_time": execution_time
                }

            if route == "abuse":
                # Handle abuse without retrieval and without prior conversation context.
                # NOTE(review): the reply is not recorded in history here — confirm intended.
                self.logger.info("Detected abuse, skipping retrieval")
                running_agent = self._RESPONSE_GENERATION
                if self.task_manager and task_id:
                    self._set_agent_status(task_id, self._RAG_RETRIEVAL, TaskStatus.SKIPPED)
                    self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.RUNNING)

                self.logger.info(f"Running enhanced response generator for abuse handling")
                response = await enhanced_response_generator.generate_response(
                    query=query, results={"results": [], "route": "abuse"}, conversation_history=""
                )
                response = response or {}

                self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.COMPLETED)
                running_agent = None

                execution_time = time.time() - start_time
                self.logger.info(f"Abuse query processed in {execution_time:.2f}s")

                return {
                    "response": response.get("response", "I'm here to help with your questions about MangoIT Solutions."),
                    "sources": [],
                    "execution_time": execution_time
                }

            # Standard RAG path
            # Step 2: RAG Retrieval (run() is awaited, so it is expected to be a coroutine)
            running_agent = self._RAG_RETRIEVAL
            self._set_agent_status(task_id, running_agent, TaskStatus.RUNNING)
            self.logger.info(f"Running {self.rag_retrieval.name}")
            context = await self.rag_retrieval.run(context)

            if context.get("error"):
                self.logger.error(f"Error in {self.rag_retrieval.name}: {context['error']}")
                self._set_agent_status(task_id, self._RAG_RETRIEVAL, TaskStatus.FAILED, context["error"])
                return self._create_error_response(f"Error retrieving information: {context['error']}")

            self._set_agent_status(task_id, self._RAG_RETRIEVAL, TaskStatus.COMPLETED)

            # Step 3: Response Generation (async)
            running_agent = self._RESPONSE_GENERATION
            self._set_agent_status(task_id, running_agent, TaskStatus.RUNNING)
            self.logger.info("Running enhanced response generator")
            conversation_history = self._get_conversation_history(conversation_id)

            # Attach the route to retrieved_info, wrapping non-dict results.
            retrieved_info = context.get("retrieved_info", {})
            if isinstance(retrieved_info, dict):
                retrieved_info["route"] = route
            else:
                retrieved_info = {"results": retrieved_info if retrieved_info else [], "route": route}

            self.logger.info(f"Using route '{route}' for response generation")
            is_new = context.get("is_new_conversation", False)
            self.logger.info(f"Is new conversation: {is_new}")
            response = await enhanced_response_generator.generate_response(
                query=query, results=retrieved_info, conversation_history=conversation_history, is_new_conversation=is_new
            )
            # Guard against a falsy result so the .get() calls below cannot fail.
            response = response or {}

            if conversation_id and "response" in response:
                self._record_assistant_message(conversation_id, response["response"])

            self._set_agent_status(task_id, self._RESPONSE_GENERATION, TaskStatus.COMPLETED)
            running_agent = None

            execution_time = time.time() - start_time
            self.logger.info(f"Query processed successfully in {execution_time:.2f}s")

            return {
                "response": response.get("response", "No response generated"),
                "sources": response.get("sources", []),
                "execution_time": execution_time
            }

        except Exception as e:
            self.logger.error(f"Error in process_query: {str(e)}", exc_info=True)
            if running_agent is not None:
                # Best effort: a status-reporting failure must not mask the original error.
                try:
                    self._set_agent_status(task_id, running_agent, TaskStatus.FAILED, str(e))
                except Exception:
                    self.logger.error("Failed to mark agent as FAILED after error", exc_info=True)
            return self._create_error_response(f"I encountered an issue while processing your query: {str(e)}")

    def _get_conversation_history(self, conversation_id: Optional[str] = None) -> str:
        """
        Get conversation history as a string.

        Tries ``conversation_handler`` first, then the optimized conversation
        memory. Returns ``""`` when there is no conversation id or both fail.
        """
        if not conversation_id:
            return ""

        try:
            history = conversation_handler.get_conversation_history(conversation_id)
            self.logger.info(f"Retrieved conversation history using handler for {conversation_id[:8]}")
            return history
        except Exception as e:
            self.logger.error(f"Error getting conversation history with handler: {str(e)}")
            self.logger.debug(traceback.format_exc())

            # Fallback to direct conversation memory
            try:
                history = conversation_memory.get_context_string(conversation_id)
                self.logger.info(f"Retrieved conversation history using fallback for {conversation_id[:8]}")
                return history
            except Exception as fallback_error:
                self.logger.error(f"Fallback history retrieval failed: {str(fallback_error)}")
                self.logger.debug(traceback.format_exc())
                return ""

    def _create_error_response(self, error_message: str) -> Dict[str, Any]:
        """Create a standardized error response (the message doubles as the reply text)."""
        return {
            "response": error_message,
            "sources": [],
            "error": error_message
        }
