"""RAG retrieval agent: queries the vector store and returns relevance-filtered documents."""
from crewai import Agent
from utils.gemini_llm import GeminiLLM
import os
import time
import asyncio
from typing import Dict, Any, List, Optional
from utils.chroma_db import ChromaDBManager
from utils.vector_db_interface import VectorDBInterface
from utils.logger import app_logger
from utils.llm_query_analyzer import LLMQueryAnalyzer
from dotenv import load_dotenv
from logging import Logger

load_dotenv()

# Get the collection name from environment variable or use default
DEFAULT_COLLECTION = "mangoit_docs"
RAG_COLLECTION = os.getenv("RAG_COLLECTION", DEFAULT_COLLECTION)
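# Example .env entry (illustrative): RAG_COLLECTION=my_docs_collection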


class RAGRetrievalAgent:
    """Agent responsible for retrieving relevant information from the vector database."""
    
    def __init__(self, model_name: str = "gemini-2.0-flash", vector_db: Optional[VectorDBInterface] = None, logger: Optional[Logger] = None):
        """Initialize the RAGRetrievalAgent
        
        Args:
            model_name: Name of the LLM model to use
            vector_db: Optional vector database interface
            logger: Optional logger instance
        """
        self.llm = GeminiLLM(
            model_name=model_name,
            temperature=0.2
        )
        # Use the provided vector database or create a default ChromaDB instance
        self.vector_db = vector_db if vector_db else ChromaDBManager()
        self.logger = logger or app_logger
        self.name = "RAGRetrievalAgent"
        
        # Initialize query analyzer
        self.query_analyzer = LLMQueryAnalyzer()
    
    def create_agent(self):
        return Agent(
            role="Information Retriever",
            goal="Retrieve the most relevant information from the knowledge base",
            backstory="""You are an expert at retrieving relevant information from 
            the knowledge base. You understand how to formulate effective queries 
            and can identify the most relevant information for a given question.""",
            verbose=True,
            llm=self.llm
        )
    
    async def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Run the agent to retrieve information
        
        Args:
            context: The context containing the query and query analysis
            
        Returns:
            Dict[str, Any]: The retrieved information and updated context
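
        Example (illustrative; the keys shown are the ones this method reads and writes):
            context = {
                "query": "Which web technologies do you work with?",
                "query_analysis": {"route": "rag", "search_query": "web technologies"},
            }
            result = await agent.run(context)
            documents = result["retrieved_info"]["results"]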
        """
        query = context.get("query", "")
        query_analysis = context.get("query_analysis", {})
        
        if not query:
            self.logger.error(f"{self.name}: No query provided in context")
            return {**context, "error": "No query provided"}
            
        if not query_analysis:
            self.logger.warning(f"{self.name}: No query analysis provided in context")
        
        preview = f"{query[:50]}..." if len(query) > 50 else query
        self.logger.log_agent_start(self.name, f"Retrieving information for: {preview}")
        
        start_time = time.time()
        try:
            # Check for routing information
            route = query_analysis.get("route", "rag")
            query_type = query_analysis.get("query_type", "")
            
            # Skip retrieval for smalltalk or abuse routes
            if route in ["smalltalk", "abuse"] or query_type == "greeting":
                route_type = route if route in ["smalltalk", "abuse"] else "greeting"
                self.logger.info(f"{self.name}: Detected {route_type}, skipping retrieval")
                retrieved_info = {"results": [], "route": route_type}
                execution_time = time.time() - start_time
                self.logger.log_agent_complete(self.name, execution_time)
                return {**context, "retrieved_info": retrieved_info, "error": None}
            
            # Get the search query from the analysis or fall back to the original query
            search_query = query_analysis.get("search_query", query)
            
            # If search query is empty (for general conversation), skip retrieval
            if not search_query:
                self.logger.info(f"{self.name}: Empty search query, skipping retrieval")
                retrieved_info = {"results": []}
                execution_time = time.time() - start_time
                self.logger.log_agent_complete(self.name, execution_time)
                return {**context, "retrieved_info": retrieved_info, "error": None}
            
            # Use the search query from the query analysis
            # This may already be expanded by the QueryUnderstandingAgent
            self.logger.info(f"{self.name}: Using search query: {search_query}")
            
            # Check if this is a technology listing request
            is_tech_list_request = query_analysis.get("is_tech_list_request", False)
            
            # For technology listing requests, use a higher number of results
            n_results = 15 if is_tech_list_request else 5
            self.logger.info(f"{self.name}: Using {n_results} results for {'technology listing' if is_tech_list_request else 'standard'} query")
            
            # Prepare a metadata filter if needed (the exact filter syntax is
            # interpreted by the VectorDBInterface implementation)
            metadata_filter = None
            if is_tech_list_request:
                # For technology queries, filter by technology categories
                metadata_filter = {"categories": ["web_development", "mobile", "ecommerce", "ai_ml"]}

                # Try the main search query first; if it doesn't yield enough
                # results, fall back to specific technology keywords below
                retrieved_info = await self.retrieve_information(search_query, n_results=n_results, metadata_filter=metadata_filter)
                
                # If we don't have enough results, try with specific technology keywords
                if len(retrieved_info.get("results", [])) < 5:
                    tech_keywords = "wordpress php javascript react laravel python ai development"
                    self.logger.info(f"{self.name}: Not enough results, trying with specific technology keywords")
                    tech_results = await self.retrieve_information(tech_keywords, n_results=10, metadata_filter=metadata_filter)
                    
                    # Merge the results
                    all_results = retrieved_info.get("results", []) + tech_results.get("results", [])
                    
                    # Remove duplicates based on content
                    unique_results = []
                    seen_content = set()
                    for result in all_results:
                        content_hash = hash(result.get("content", "")[:100])  # Hash the first 100 chars to spot duplicates
                        if content_hash not in seen_content:
                            seen_content.add(content_hash)
                            unique_results.append(result)
                    
                    retrieved_info["results"] = unique_results[:n_results]
                    self.logger.info(f"{self.name}: Combined results count: {len(retrieved_info['results'])}")
                
                # Mark this as a technology listing for the response generator
                retrieved_info["is_tech_list_request"] = True
                
                execution_time = time.time() - start_time
                self.logger.log_agent_complete(self.name, execution_time)
                return {**context, "retrieved_info": retrieved_info, "error": None}
            
            # Check for multi-queries in the query analysis
            multi_queries = query_analysis.get("multi_queries", [])
            
            if multi_queries and len(multi_queries) > 1:
                # Use multi-query retrieval for better coverage
                self.logger.info(f"{self.name}: Using multi-query retrieval with {len(multi_queries)} queries")
                # Call the multi-query retrieval method
                retrieved_info = await self._multi_query_retrieval_async(multi_queries, n_results=n_results, metadata_filter=metadata_filter)
            else:
                # Standard retrieval for single query
                self.logger.info(f"{self.name}: Using standard retrieval")
                retrieved_info = await self.retrieve_information(search_query, metadata_filter=metadata_filter)
            
            execution_time = time.time() - start_time
            self.logger.log_agent_complete(self.name, execution_time)
            collection_name = getattr(self.vector_db, 'collection_name', 'default_collection')
            self.logger.info(f"{self.name}: Retrieved {len(retrieved_info.get('results', []))} documents from {collection_name}")
            
            # Update context with retrieved information
            return {**context, "retrieved_info": retrieved_info, "error": None}
            
        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.log_agent_error(self.name, e)
            return {**context, "error": str(e)}
    
    async def _multi_query_retrieval_async(self, queries: List[str], n_results: int = 8, metadata_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Retrieve information using multiple queries for better coverage (async version)
        
        Args:
            queries: List of search queries
            n_results: Number of results to return per query
            metadata_filter: Optional metadata filter
            
        Returns:
            Dict[str, Any]: Combined retrieved information
        """
        try:
            all_results = []
            
            # Use up to 6 queries to avoid excessive retrieval
            for i, query in enumerate(queries[:6]):
                self.logger.info(f"Multi-query [{i+1}/{min(6, len(queries))}]: {query}")
                
                # Retrieve results for this query with error handling
                try:
                    query_results = await self.retrieve_information(query, n_results=n_results, metadata_filter=metadata_filter)
                    if query_results and "results" in query_results:
                        all_results.extend(query_results["results"])
                except Exception as e:
                    self.logger.error(f"Error retrieving results for query '{query}': {str(e)}")
            
            # Remove duplicates based on content hash
            unique_results = []
            seen_content = set()
            
            for result in all_results:
                if "content" in result:
                    # Create a hash of the first 100 chars of content
                    content_hash = hash(result["content"][:100])
                    
                    if content_hash not in seen_content:
                        seen_content.add(content_hash)
                        unique_results.append(result)
            
            # Sort by relevance and limit to reasonable number
            unique_results = sorted(unique_results, key=lambda x: x.get("relevance", 0), reverse=True)
            final_results = unique_results[:min(15, len(unique_results))]
            
            # Create the final result structure
            collection_name = getattr(self.vector_db, 'collection_name', 'default_collection')
            return {
                "results": final_results,
                "query": ", ".join(queries[:3]) + ("..." if len(queries) > 3 else ""),
                "collection": collection_name
            }
            
        except Exception as e:
            self.logger.error(f"Error in multi-query retrieval: {str(e)}")
            return {"results": [], "query": "", "collection": ""}
    
    def _multi_query_retrieval(self, queries: List[str], n_results: int = 8, collection_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Retrieve information using multiple queries for better coverage.

        Synchronous counterpart of _multi_query_retrieval_async: it drives the
        async retrieve_information() with asyncio.run, so it must not be called
        from inside a running event loop.

        Args:
            queries: List of search queries
            n_results: Number of results to return per query
            collection_name: Name of the collection to query

        Returns:
            Dict[str, Any]: Combined retrieved information
        """
        try:
            all_results = []
            
            # Use up to 6 queries to avoid excessive retrieval
            for i, query in enumerate(queries[:6]):
                self.logger.info(f"Multi-query [{i+1}/{min(6, len(queries))}]: {query}")
                
                # Retrieve results for this query with error handling;
                # retrieve_information is async, so drive it with asyncio.run
                try:
                    query_results = asyncio.run(self.retrieve_information(query, n_results=n_results, collection_name=collection_name))
                except Exception as e:
                    self.logger.error(f"Error retrieving results for query '{query}': {str(e)}")
                    query_results = {"results": []}
                
                # Add to combined results
                if query_results and "results" in query_results:
                    all_results.extend(query_results["results"])
            
            # Remove duplicates based on content hash
            unique_results = []
            seen_content = set()
            
            for result in all_results:
                if "content" in result:
                    # Create a hash of the first 100 chars of content
                    content_hash = hash(result["content"][:100])
                    
                    if content_hash not in seen_content:
                        seen_content.add(content_hash)
                        unique_results.append(result)
            
            # Sort by relevance and limit to reasonable number
            unique_results = sorted(unique_results, key=lambda x: x.get("relevance", 0), reverse=True)
            final_results = unique_results[:min(15, len(unique_results))]
            
            self.logger.info(f"Multi-query retrieval: {len(all_results)} total results, {len(unique_results)} unique, {len(final_results)} final")
            
            # Return in the same format as retrieve_information
            return {
                "results": final_results,
                "query": ", ".join(queries[:3]) + ("..." if len(queries) > 3 else ""),
                "collection": collection_name or RAG_COLLECTION
            }
            
        except Exception as e:
            self.logger.error(f"Error in multi-query retrieval: {str(e)}")
            # Fall back to single-query retrieval
            return asyncio.run(self.retrieve_information(queries[0] if queries else "", n_results=n_results, collection_name=collection_name))
    
    async def retrieve_information(self, query: str, n_results: int = 5, collection_name: Optional[str] = None, metadata_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Retrieve relevant information from the vector database with improved quality
        
        Args:
            query: The search query
            n_results: Number of results to return
            collection_name: Name of the collection to query
            metadata_filter: Optional metadata filter to apply to the query
            
        Returns:
            Dict[str, Any]: Retrieved information from the vector database
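
        Example return value (illustrative; "source" and "chunk" come from the stored metadata):
            {
                "results": [
                    {"content": "...", "source": "docs/services.md", "chunk": 0, "relevance": 0.82}
                ],
                "query": "web development services",
                "collection": "mangoit_docs"
            }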
        """
        try:
            # Use provided collection_name or the one from environment variable
            if collection_name is None:
                collection_name = RAG_COLLECTION
                self.logger.info(f"Using collection from environment: {collection_name}")
            
            # Constants for retrieval quality
            RELEVANCE_FLOOR = 0.35  # Minimum relevance score to include a result
            N_CANDIDATES = max(n_results * 3, 12)  # Get more candidates than needed
            
            # Try enhanced retrieval first
            try:
                # Query the vector database with enhanced retrieval
                results = self.vector_db.query_collection(
                    query_text=query,
                    collection_name=collection_name,
                    n_results=N_CANDIDATES,
                    use_hybrid=True,  # Use hybrid search (vector + keyword)
                    use_mmr=True,     # Use MMR reranking for diversity
                    mmr_lambda=0.7    # Balance between relevance (0.7) and diversity (0.3)
                )
            except Exception as e:
                # If enhanced retrieval fails, fall back to standard retrieval
                self.logger.warning(f"Enhanced retrieval failed: {str(e)}. Falling back to standard retrieval.")
                results = self.vector_db.query_collection(
                    query_text=query,
                    collection_name=collection_name,
                    n_results=N_CANDIDATES
                )
            
            # Process the results
            documents = results.get("documents", [[]])[0]
            metadatas = results.get("metadatas", [[]])[0]
            distances = results.get("distances", [[]])[0]
            
            self.logger.info(f"Retrieved {len(documents)} candidate documents from {collection_name}")
            
            # Format the results with relevance scores
            formatted_results = []
            for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
                # Convert distance to a relevance score (assumes a normalized
                # distance in [0, 1], e.g. cosine distance)
                relevance = 1.0 - distance if distance is not None else 0.8
                
                # Apply relevance floor - only include results above the threshold
                if relevance >= RELEVANCE_FLOOR:
                    formatted_results.append({
                        "content": doc,
                        "source": metadata.get("source", "Unknown"),
                        "chunk": metadata.get("chunk", i),
                        "relevance": relevance
                    })
            
            # Sort by relevance and keep the top results; note that results are
            # capped at 6 per call even when a larger n_results is requested
            formatted_results = sorted(formatted_results, key=lambda x: x["relevance"], reverse=True)
            formatted_results = formatted_results[:min(6, n_results)]
            
            self.logger.info(f"After filtering by relevance floor ({RELEVANCE_FLOOR}), {len(formatted_results)} documents remain")
            
            # If no results meet the relevance threshold, include a note
            if not formatted_results:
                self.logger.warning(f"No results met the relevance threshold of {RELEVANCE_FLOOR}")
                # Include the top result anyway if available, but mark it as low relevance
                if documents:
                    formatted_results.append({
                        "content": documents[0],
                        "source": metadatas[0].get("source", "Unknown") if metadatas else "Unknown",
                        "chunk": metadatas[0].get("chunk", 0) if metadatas else 0,
                        "relevance": 0.3,  # Mark as low relevance
                        "low_confidence": True
                    })
                    self.logger.info("Added top result with low confidence flag")
            
            # Log the sources being returned
            sources = [r["source"] for r in formatted_results]
            self.logger.info(f"Returning sources: {sources}")
            
            # Log the relevance scores
            relevance_scores = [f"{r['relevance']:.2f}" for r in formatted_results]
            self.logger.info(f"Relevance scores: {relevance_scores}")

            return {
                "results": formatted_results,
                "query": query,
                "collection": collection_name
            }
            
        except Exception as e:
            self.logger.error(f"Error retrieving information: {str(e)}", exc_info=True)
            return {
                "results": [],
                "query": query,
                "collection": collection_name,
                "error": str(e)
            }
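

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumes the environment provides
# valid Gemini credentials, that ChromaDBManager can reach its store, and that
# the RAG_COLLECTION collection has already been populated with documents.
# The context keys match the ones RAGRetrievalAgent.run() reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        agent = RAGRetrievalAgent()
        context = {
            "query": "What e-commerce platforms do you support?",
            "query_analysis": {"route": "rag", "search_query": "e-commerce platforms"},
        }
        result = await agent.run(context)
        for doc in result.get("retrieved_info", {}).get("results", []):
            print(f"[{doc['relevance']:.2f}] {doc['source']}: {doc['content'][:80]}")

    asyncio.run(_demo())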
