import asyncio
import json
import os
import time
from typing import Dict, Any, List, Optional
from dotenv import load_dotenv
import pprint
from colorama import Fore, Style, init

# Import our RAG components
from agents.query_understanding_agent import QueryUnderstandingAgent
from agents.rag_retrieval_agent import RAGRetrievalAgent
from agents.response_generation_agent import ResponseGenerationAgent
from utils.chroma_db import ChromaDBManager  # Fixed import name
from utils.llm_query_analyzer import LLMQueryAnalyzer
from utils.logger import app_logger

# Initialize colorama so the Fore/Style color codes used by the print calls
# in this script are handled for terminal output
init()

# Load environment variables (python-dotenv's load_dotenv)
load_dotenv()

# Default ChromaDB collection name; used by SimpleRAGDebugger when the caller
# does not pass a collection_name
COLLECTION_NAME = "mangoit_docs_miniLM"

class SimpleRAGDebugger:
    """Simple debug tool to inspect the RAG pipeline process without external dependencies

    A query is pushed through three stages in order (query understanding,
    RAG retrieval, response generation). Each stage prints what it produced
    and how long it took; ``debug_query`` returns everything it collected.
    """

    # Maximum number of characters shown when previewing a retrieved document
    # or the response-generation prompt.
    _DOC_PREVIEW_CHARS = 200
    _PROMPT_PREVIEW_CHARS = 500

    def __init__(self, collection_name: str = COLLECTION_NAME):
        """Initialize the RAG pipeline debugger

        Args:
            collection_name: Name of the ChromaDB collection to use

        Raises:
            Exception: Whatever ``check_collection`` re-raises when the
                collection cannot be accessed.
        """
        self.collection_name = collection_name

        # Initialize components
        self.query_understanding_agent = QueryUnderstandingAgent(logger=app_logger)
        self.rag_retrieval_agent = RAGRetrievalAgent(logger=app_logger)
        self.response_generation_agent = ResponseGenerationAgent(logger=app_logger)
        self.llm_query_analyzer = LLMQueryAnalyzer(logger=app_logger)

        # Initialize ChromaDB
        self.chroma_db = ChromaDBManager()  # Fixed class name

        # Fail fast if the collection cannot be accessed
        self.check_collection()

    @staticmethod
    def _preview(text: Optional[str], limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters, adding "..." only when cut.

        A ``None`` value is treated as empty text.
        """
        if text is None:
            return ""
        if len(text) > limit:
            return text[:limit] + "..."
        return text

    @staticmethod
    def _relevance(item: Dict[str, Any]) -> float:
        """Return ``item['relevance']``, treating a missing or ``None`` score as 0."""
        return item.get('relevance') or 0

    def check_collection(self):
        """Check that the collection is accessible and print its document count

        NOTE(review): per the original inline comment, ``create_collection``
        both creates and gets collections, so a count of 0 may mean the
        collection did not exist before this call - confirm against
        ChromaDBManager.

        Raises:
            Exception: Any error raised while accessing the collection is
                printed and then re-raised unchanged.
        """
        try:
            # Use create_collection which handles both creating and getting collections
            collection = self.chroma_db.create_collection(self.collection_name)
            count = collection.count()
            print(f"{Fore.GREEN}✓ Collection '{self.collection_name}' found with {count} documents{Style.RESET_ALL}")
        except Exception as e:
            print(f"{Fore.RED}✗ Error accessing collection '{self.collection_name}': {e}{Style.RESET_ALL}")
            raise

    async def debug_query(self, query: str, verbose: bool = True) -> Dict[str, Any]:
        """Debug a query through the entire RAG pipeline

        Args:
            query: The query to debug
            verbose: Whether to print verbose output

        Returns:
            Dict[str, Any]: Keys ``query_analysis``, ``retrieval_results``,
            ``response`` and ``timing`` (per-step and total seconds).
        """
        banner = '=' * 80
        print(f"\n{Fore.CYAN}{banner}{Style.RESET_ALL}")
        print(f"{Fore.CYAN}DEBUGGING QUERY: {query}{Style.RESET_ALL}")
        print(f"{Fore.CYAN}{banner}{Style.RESET_ALL}\n")

        # Durations use perf_counter(): unlike time.time(), it is not affected
        # by system clock adjustments while a step is running.

        # Step 1: Query Understanding
        print(f"{Fore.YELLOW}STEP 1: QUERY UNDERSTANDING{Style.RESET_ALL}")
        query_analysis_start = time.perf_counter()
        query_analysis = await self._debug_query_understanding(query, verbose)
        query_analysis_time = time.perf_counter() - query_analysis_start
        print(f"{Fore.YELLOW}Query understanding completed in {query_analysis_time:.2f} seconds{Style.RESET_ALL}\n")

        # Step 2: RAG Retrieval
        print(f"{Fore.YELLOW}STEP 2: RAG RETRIEVAL{Style.RESET_ALL}")
        retrieval_start = time.perf_counter()
        retrieval_results = await self._debug_rag_retrieval(query, query_analysis, verbose)
        retrieval_time = time.perf_counter() - retrieval_start
        print(f"{Fore.YELLOW}RAG retrieval completed in {retrieval_time:.2f} seconds{Style.RESET_ALL}\n")

        # Step 3: Response Generation
        print(f"{Fore.YELLOW}STEP 3: RESPONSE GENERATION{Style.RESET_ALL}")
        response_start = time.perf_counter()
        response = await self._debug_response_generation(query, retrieval_results, verbose)
        response_time = time.perf_counter() - response_start
        print(f"{Fore.YELLOW}Response generation completed in {response_time:.2f} seconds{Style.RESET_ALL}\n")

        # Print final response
        print(f"{Fore.GREEN}{banner}{Style.RESET_ALL}")
        print(f"{Fore.GREEN}FINAL RESPONSE:{Style.RESET_ALL}")
        print(f"{Fore.WHITE}{response.get('response', 'No response generated')}{Style.RESET_ALL}")
        print(f"{Fore.GREEN}{banner}{Style.RESET_ALL}\n")

        # Print sources (skipped when the key is missing or empty)
        sources = response.get("sources")
        if sources:
            print(f"{Fore.GREEN}SOURCES:{Style.RESET_ALL}")
            for i, source in enumerate(sources, start=1):
                print(f"{i}. {source.get('source', 'Unknown')} (relevance: {self._relevance(source):.2f})")

        # Total is the sum of the three measured steps (printing time between
        # steps is deliberately not included)
        total_time = query_analysis_time + retrieval_time + response_time
        print(f"\n{Fore.CYAN}Total processing time: {total_time:.2f} seconds{Style.RESET_ALL}")

        return {
            "query_analysis": query_analysis,
            "retrieval_results": retrieval_results,
            "response": response,
            "timing": {
                "query_analysis": query_analysis_time,
                "retrieval": retrieval_time,
                "response_generation": response_time,
                "total": total_time
            }
        }

    async def _debug_query_understanding(self, query: str, verbose: bool = True) -> Dict[str, Any]:
        """Debug the query understanding step

        Runs both the LLM query analyzer and the query understanding agent.
        The analyzer's output is only printed (when ``verbose``); the agent's
        output is what gets returned.

        Args:
            query: The query to debug
            verbose: Whether to print verbose output

        Returns:
            Dict[str, Any]: The query analysis
        """
        # Get LLM query analysis
        print(f"{Fore.BLUE}Running LLM query analyzer...{Style.RESET_ALL}")
        llm_analysis = self.llm_query_analyzer.analyze_query(query)

        if verbose:
            print(f"{Fore.WHITE}LLM Query Analysis:{Style.RESET_ALL}")
            pprint.pprint(llm_analysis)

        # Run query understanding agent
        print(f"{Fore.BLUE}Running query understanding agent...{Style.RESET_ALL}")
        query_analysis = self.query_understanding_agent.analyze_query(query)

        if verbose:
            print(f"{Fore.WHITE}Query Understanding Agent Analysis:{Style.RESET_ALL}")
            pprint.pprint(query_analysis)

        return query_analysis

    async def _debug_rag_retrieval(self, query: str, query_analysis: Dict[str, Any], verbose: bool = True) -> List[Dict[str, Any]]:
        """
        Debug the RAG retrieval step

        Args:
            query: The original query
            query_analysis: The query analysis from the previous step
            verbose: Whether to print verbose output

        Returns:
            List[Dict[str, Any]]: The retrieved documents (empty when the
            agent's result carries no ``retrieved_info`` / ``results``).
        """
        # Prepare context for RAG retrieval
        context = {
            "query": query,
            "query_analysis": query_analysis,
            "collection_name": self.collection_name
        }

        # Run RAG retrieval (not async)
        print(f"{Fore.BLUE}Running RAG retrieval agent...{Style.RESET_ALL}")
        retrieval_result = self.rag_retrieval_agent.run(context)

        # Extract results. `or` fallbacks tolerate a key that is present but
        # None, which a chained .get(key, default) would not.
        retrieved_info = (retrieval_result or {}).get("retrieved_info") or {}
        results = retrieved_info.get("results") or []

        # Print retrieval stats
        print(f"{Fore.WHITE}Retrieved {len(results)} documents{Style.RESET_ALL}")

        if verbose and results:
            print(f"{Fore.WHITE}Top 3 retrieved documents:{Style.RESET_ALL}")
            for i, result in enumerate(results[:3], start=1):
                # Ellipsis is appended only when the content was actually
                # truncated, matching how the prompt preview behaves.
                content_preview = self._preview(result.get('content'), self._DOC_PREVIEW_CHARS)
                print(f"\n{Fore.WHITE}Document {i}:{Style.RESET_ALL}")
                print(f"Source: {result.get('source', 'Unknown')}")
                print(f"Relevance: {self._relevance(result):.2f}")
                print(f"Content preview: {content_preview}")

        return results

    async def _debug_response_generation(self, query: str, retrieval_results: List[Dict[str, Any]], verbose: bool = True) -> Dict[str, Any]:
        """Debug the response generation step

        Args:
            query: The original query
            retrieval_results: The retrieval results from the previous step
            verbose: Whether to print verbose output

        Returns:
            Dict[str, Any]: The generated response
        """
        # Generate response
        print(f"{Fore.BLUE}Running response generation agent...{Style.RESET_ALL}")
        # Convert list to dict format expected by response_generation_agent
        retrieval_info = {"results": retrieval_results}
        # The generate_response method is not async, so don't use await
        response = self.response_generation_agent.generate_response(query, retrieval_info)

        # The agent may not expose `last_prompt`, or may expose it as None;
        # either way there is nothing to preview.
        last_prompt = getattr(self.response_generation_agent, 'last_prompt', None)
        if verbose and last_prompt is not None:
            print(f"{Fore.WHITE}Response generation prompt preview:{Style.RESET_ALL}")
            print(self._preview(last_prompt, self._PROMPT_PREVIEW_CHARS))

        return response

async def main():
    """Run the RAG pipeline debugger for a single query

    The query is taken from the command-line arguments (joined with spaces);
    when there are none it is read interactively from standard input. A blank
    query is reported and skipped instead of being sent through the pipeline.
    """
    import sys

    # Create the debugger
    debugger = SimpleRAGDebugger()

    # Get query from command line, otherwise prompt for it
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        try:
            query = input("Enter your query: ")
        except EOFError:
            # stdin is closed / not interactive: same as giving no query
            query = ""

    query = query.strip()
    if not query:
        print(f"{Fore.RED}✗ No query provided; nothing to debug.{Style.RESET_ALL}")
        return

    # Debug the query
    await debugger.debug_query(query)

# Script entry point: only runs when the file is executed directly
if __name__ == "__main__":
    debug_session = main()
    # asyncio.run() drives the coroutine to completion on a new event loop
    asyncio.run(debug_session)
