import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, Optional

class RAGEvaluator:
    """
    Utility for evaluating RAG system performance and collecting feedback
    """
    
    def __init__(self, metrics_file: str = "rag_metrics.json"):
        """
        Initialize the RAG evaluator
        
        Args:
            metrics_file: Name of the JSON file (created under logs/metrics/) used to persist metrics data
        """
        self.metrics_file = metrics_file
        self.metrics_dir = os.path.join("logs", "metrics")
        os.makedirs(self.metrics_dir, exist_ok=True)
        self.metrics_path = os.path.join(self.metrics_dir, metrics_file)
        
    def log_query(self, query_data: Dict[str, Any]) -> str:
        """
        Log a query and its metrics
        
        Args:
            query_data: Data about the query and response
            
        Returns:
            str: ID of the logged query
        """
        # Generate a unique ID for this query
        query_id = str(uuid.uuid4())
        
        # Add timestamp and ID
        query_data["timestamp"] = datetime.now().isoformat()
        query_data["query_id"] = query_id
        
        # Load existing metrics
        metrics = self._load_metrics()
        
        # Add new query data
        if "queries" not in metrics:
            metrics["queries"] = []
        
        metrics["queries"].append(query_data)
        
        # Save updated metrics
        self._save_metrics(metrics)
        
        return query_id
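
    # On disk, the metrics file ends up shaped roughly like this (a sketch; the
    # per-query fields beyond "query_id", "timestamp", and "feedback" depend
    # entirely on what callers pass to log_query):
    #
    # {
    #   "queries": [
    #     {
    #       "query_id": "<uuid4>",
    #       "timestamp": "2024-01-01T12:00:00",
    #       "execution_time": 1.37,
    #       "feedback": {"relevance": 5, "correctness": 4, "helpfulness": 5, "timestamp": "..."}
    #     }
    #   ]
    # }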
    
    def log_feedback(self, query_id: str, feedback: Dict[str, Any]) -> bool:
        """
        Log user feedback for a specific query
        
        Args:
            query_id: ID of the query
            feedback: Feedback data
            
        Returns:
            bool: True if feedback was logged successfully
        """
        # Load existing metrics
        metrics = self._load_metrics()
        
        # Find the query and add feedback
        if "queries" in metrics:
            for query in metrics["queries"]:
                if query.get("query_id") == query_id:
                    # Add feedback with timestamp
                    feedback["timestamp"] = datetime.now().isoformat()
                    query["feedback"] = feedback
                    
                    # Save updated metrics
                    self._save_metrics(metrics)
                    return True
        
        return False
    
    def get_query_metrics(self, query_id: str) -> Optional[Dict[str, Any]]:
        """
        Get metrics for a specific query
        
        Args:
            query_id: ID of the query
            
        Returns:
            Optional[Dict[str, Any]]: Query metrics or None if not found
        """
        # Load metrics
        metrics = self._load_metrics()
        
        # Find the query
        if "queries" in metrics:
            for query in metrics["queries"]:
                if query.get("query_id") == query_id:
                    return query
        
        return None
    
    def get_aggregate_metrics(self, days: int = 7) -> Dict[str, Any]:
        """
        Get aggregate metrics for the specified number of days
        
        Args:
            days: Number of days to include in metrics
            
        Returns:
            Dict[str, Any]: Aggregate metrics
        """
        # Load metrics
        metrics = self._load_metrics()
        
        # Calculate cutoff date
        cutoff = datetime.now().timestamp() - (days * 24 * 60 * 60)
        
        # Filter queries by date, skipping records with missing or malformed timestamps
        recent_queries = []
        if "queries" in metrics:
            for query in metrics["queries"]:
                try:
                    query_time = datetime.fromisoformat(query["timestamp"]).timestamp()
                except (KeyError, ValueError):
                    continue
                if query_time >= cutoff:
                    recent_queries.append(query)
        
        # Calculate aggregate metrics
        total_queries = len(recent_queries)
        total_feedback = sum(1 for q in recent_queries if "feedback" in q)
        
        # Calculate average ratings if feedback exists
        avg_relevance = 0
        avg_correctness = 0
        avg_helpfulness = 0
        
        if total_feedback > 0:
            feedback_entries = [q["feedback"] for q in recent_queries if "feedback" in q]
            avg_relevance = sum(fb.get("relevance", 0) for fb in feedback_entries) / total_feedback
            avg_correctness = sum(fb.get("correctness", 0) for fb in feedback_entries) / total_feedback
            avg_helpfulness = sum(fb.get("helpfulness", 0) for fb in feedback_entries) / total_feedback
        
        # Calculate average execution time over queries that actually recorded one
        avg_execution_time = 0
        execution_times = [q["execution_time"] for q in recent_queries if "execution_time" in q]
        if execution_times:
            avg_execution_time = sum(execution_times) / len(execution_times)
        
        return {
            "total_queries": total_queries,
            "total_feedback": total_feedback,
            "feedback_rate": total_feedback / total_queries if total_queries > 0 else 0,
            "avg_relevance": avg_relevance,
            "avg_correctness": avg_correctness,
            "avg_helpfulness": avg_helpfulness,
            "avg_execution_time": avg_execution_time,
            "days": days
        }
    
    def _load_metrics(self) -> Dict[str, Any]:
        """Load metrics from file, returning an empty dict if the file is missing or unreadable"""
        if os.path.exists(self.metrics_path):
            try:
                with open(self.metrics_path, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                print(f"Error loading metrics: {e}")
        
        return {}
    
    def _save_metrics(self, metrics: Dict[str, Any]) -> None:
        """Save metrics to file"""
        try:
            with open(self.metrics_path, 'w') as f:
                json.dump(metrics, f, indent=2)
        except (OSError, TypeError) as e:
            print(f"Error saving metrics: {e}")
