import asyncio
import json
import os
import time
from typing import Dict, Any, List, Optional
from dotenv import load_dotenv
import pprint
import argparse
import matplotlib.pyplot as plt
import numpy as np
from colorama import Fore, Style, init
from tabulate import tabulate

# Import our RAG components
from agents.query_understanding_agent import QueryUnderstandingAgent
from agents.rag_retrieval_agent import RAGRetrievalAgent
from agents.response_generation_agent import ResponseGenerationAgent
from utils.chroma_db import ChromaDB
from utils.llm_query_analyzer import LLMQueryAnalyzer
from utils.logger import app_logger

# Pull variables from a .env file (when one exists) into the process
# environment so the components constructed later can read their settings.
load_dotenv()

# Set up colorama so the ANSI color codes printed by this script render
# correctly, including on Windows consoles.
init()

class RAGVisualizer:
    """Visualization tool for RAG pipeline results.

    Runs a query through the query-understanding, retrieval and
    response-generation agents, then renders each stage to the terminal and
    saves a bar chart of the retrieval relevance scores as an image file.
    """

    def __init__(self, collection_name: str = "mangoit_docs_miniLM"):
        """Initialize the RAG visualizer.

        Args:
            collection_name: Name of the ChromaDB collection to use.

        Raises:
            Exception: Re-raised from ``check_collection`` when the collection
                cannot be accessed.
        """
        self.collection_name = collection_name

        # Pipeline components, all sharing the application logger.
        self.query_understanding_agent = QueryUnderstandingAgent(logger=app_logger)
        self.rag_retrieval_agent = RAGRetrievalAgent(logger=app_logger)
        self.response_generation_agent = ResponseGenerationAgent(logger=app_logger)
        # NOTE(review): not used by any method of this class; kept because
        # external callers may access it as an attribute.
        self.llm_query_analyzer = LLMQueryAnalyzer(logger=app_logger)

        # Initialize ChromaDB
        self.chroma_db = ChromaDB()

        # Fail fast if the configured collection is missing or unreachable.
        self.check_collection()

    def check_collection(self):
        """Check that the collection exists and print its document count.

        Raises:
            Exception: Whatever ``ChromaDB.get_collection`` or ``count`` raised,
                after printing an error line.
        """
        try:
            collection = self.chroma_db.get_collection(self.collection_name)
            count = collection.count()
            print(f"{Fore.GREEN}✓ Collection '{self.collection_name}' found with {count} documents{Style.RESET_ALL}")
        except Exception as e:
            print(f"{Fore.RED}✗ Error accessing collection '{self.collection_name}': {e}{Style.RESET_ALL}")
            raise

    async def process_query(self, query: str) -> Dict[str, Any]:
        """Process a query through the RAG pipeline.

        Args:
            query: The query to process.

        Returns:
            Dict[str, Any]: ``query``, ``query_analysis``, ``retrieval_results``
            (the retrieval agent's ``"results"`` list, or ``[]`` if absent) and
            ``response`` (the response generation agent's output).
        """
        print(f"\n{Fore.CYAN}Processing query: {query}{Style.RESET_ALL}")

        # Step 1: Query Understanding
        query_analysis = self.query_understanding_agent.analyze_query(query)

        # Step 2: RAG Retrieval
        context = {
            "query": query,
            "query_analysis": query_analysis,
            "collection_name": self.collection_name,
        }
        retrieval_result = await self.rag_retrieval_agent.run(context)
        results = retrieval_result.get("results", [])

        # Step 3: Response Generation
        response = await self.response_generation_agent.generate_response(query, results)

        return {
            "query": query,
            "query_analysis": query_analysis,
            "retrieval_results": results,
            "response": response,
        }

    def visualize_query_analysis(self, query_analysis: Dict[str, Any]):
        """Print the query analysis.

        Args:
            query_analysis: Mapping read via the ``intent``, ``confidence``,
                ``keywords``, ``search_queries``, ``multi_queries`` and
                ``search_plan`` keys; missing or ``None`` values fall back to
                defaults.
        """
        self._print_header("QUERY ANALYSIS")

        # Intent and confidence; coerce confidence so ':.2f' cannot crash on
        # None or a non-numeric value.
        intent = query_analysis.get("intent", "unknown")
        confidence = self._as_float(query_analysis.get("confidence", 0.0))
        print(f"{Fore.WHITE}Intent: {intent} (confidence: {confidence:.2f}){Style.RESET_ALL}")

        # Keywords (str() so non-string entries do not break join)
        keywords = query_analysis.get("keywords") or []
        print(f"{Fore.WHITE}Keywords: {', '.join(str(k) for k in keywords)}{Style.RESET_ALL}")

        # Search queries
        search_queries = query_analysis.get("search_queries") or []
        if search_queries:
            print(f"{Fore.WHITE}Search Queries:{Style.RESET_ALL}")
            for rank, text in enumerate(search_queries, start=1):
                print(f"{Fore.WHITE}{rank}. {text}{Style.RESET_ALL}")

        # Multi-queries, only when they add something beyond the search queries
        multi_queries = query_analysis.get("multi_queries") or []
        if multi_queries and multi_queries != search_queries:
            print(f"{Fore.WHITE}Multi-Queries:{Style.RESET_ALL}")
            for rank, text in enumerate(multi_queries, start=1):
                print(f"{Fore.WHITE}{rank}. {text}{Style.RESET_ALL}")

        # Search plan. This is display-only, so stringify anything JSON cannot
        # encode instead of crashing the visualizer.
        search_plan = query_analysis.get("search_plan") or {}
        if search_plan:
            print(f"{Fore.WHITE}Search Plan:{Style.RESET_ALL}")
            print(f"{Fore.WHITE}{json.dumps(search_plan, indent=2, default=str)}{Style.RESET_ALL}")

    def visualize_retrieval_results(
        self,
        results: List[Dict[str, Any]],
        top_k: int = 10,
        plot_path: str = "relevance_scores.png",
    ):
        """Print a table of retrieval results and save a relevance bar chart.

        Args:
            results: Retrieval results; each item is read via the ``source``,
                ``relevance`` and ``content`` keys (missing values fall back to
                defaults).
            top_k: Maximum number of results shown in the table and the plot.
            plot_path: File the relevance bar chart is written to.
        """
        self._print_header("RETRIEVAL RESULTS")

        print(f"{Fore.WHITE}Retrieved {len(results)} documents{Style.RESET_ALL}")

        top_results = results[:top_k]

        # Table of the top results
        table_data = [
            [
                rank,
                result.get("source", "Unknown"),
                f"{self._as_float(result.get('relevance', 0)):.2f}",
                self._preview(result.get("content", "")),
            ]
            for rank, result in enumerate(top_results, start=1)
        ]
        headers = ["#", "Source", "Relevance", "Content Preview"]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        if not top_results:
            return

        # Plot relevance scores
        relevance_scores = [self._as_float(r.get("relevance", 0)) for r in top_results]
        # basename() raises on None, so normalize the source to a string first.
        sources = [os.path.basename(str(r.get("source") or "Unknown")) for r in top_results]

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.bar(range(len(relevance_scores)), relevance_scores, color='skyblue')
            ax.set_xlabel('Document')
            ax.set_ylabel('Relevance Score')
            ax.set_title('Relevance Scores of Retrieved Documents')
            ax.set_xticks(range(len(sources)))
            ax.set_xticklabels(sources, rotation=45, ha='right')
            fig.tight_layout()
            fig.savefig(plot_path)
        finally:
            # pyplot keeps every figure alive until explicitly closed; without
            # this, repeated calls accumulate figures in memory.
            plt.close(fig)
        print(f"{Fore.GREEN}Relevance scores plot saved to {plot_path}{Style.RESET_ALL}")

    def visualize_response(self, response: Dict[str, Any]):
        """Print the generated response and its sources.

        Args:
            response: Mapping read via the ``response`` and ``sources`` keys;
                each source is read via ``source`` and ``relevance``.
        """
        self._print_header("RESPONSE")

        # Print the response
        print(f"{Fore.WHITE}{response.get('response', 'No response generated')}{Style.RESET_ALL}")

        # Print sources
        sources = response.get("sources") or []
        if sources:
            print(f"\n{Fore.WHITE}Sources:{Style.RESET_ALL}")
            for rank, source in enumerate(sources, start=1):
                relevance = self._as_float(source.get('relevance', 0))
                print(f"{Fore.WHITE}{rank}. {source.get('source', 'Unknown')} (relevance: {relevance:.2f}){Style.RESET_ALL}")

    @staticmethod
    def _print_header(title: str) -> None:
        """Print a yellow section banner: an 80-char rule, ``title``, another rule."""
        rule = f"{Fore.YELLOW}{'=' * 80}{Style.RESET_ALL}"
        print(f"\n{rule}")
        print(f"{Fore.YELLOW}{title}{Style.RESET_ALL}")
        print(rule)

    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Return ``value`` as a float, or ``default`` if it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _preview(text: Any, limit: int = 100) -> str:
        """Return ``text`` cut to ``limit`` chars, with '...' appended when truncated.

        ``None`` is treated as an empty string; other non-strings are str()-ed.
        """
        text = "" if text is None else str(text)
        return text[:limit] + "..." if len(text) > limit else text

async def main():
    """Run the RAG visualizer from the command line.

    The query is taken from the positional CLI argument, or prompted for
    interactively when omitted. A blank query (empty, whitespace-only, or
    EOF at the prompt) is rejected via ``argparse``'s error handling
    (usage message, exit status 2) instead of being sent through the pipeline.
    """
    parser = argparse.ArgumentParser(description="Visualize RAG pipeline results")
    parser.add_argument("query", nargs="?", default=None, help="Query to process")
    parser.add_argument("--collection", default="mangoit_docs_miniLM", help="ChromaDB collection name")
    args = parser.parse_args()

    # Create the visualizer (its constructor checks the collection exists)
    visualizer = RAGVisualizer(collection_name=args.collection)

    # Get query from command line or prompt for it
    query = args.query
    if not query:
        try:
            query = input("Enter your query: ")
        except EOFError:
            # stdin closed (e.g. empty pipe): treat as "no query given".
            query = ""
    query = query.strip()
    if not query:
        parser.error("a non-empty query is required")

    # Process the query
    results = await visualizer.process_query(query)

    # Visualize the results
    visualizer.visualize_query_analysis(results["query_analysis"])
    visualizer.visualize_retrieval_results(results["retrieval_results"])
    visualizer.visualize_response(results["response"])

if __name__ == "__main__":
    asyncio.run(main())
