import os
import glob
import time
import json
import hashlib
import re
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from utils.vector_db_interface import VectorDBInterface
from typing import Dict, Any, List, Optional, Tuple
from dotenv import load_dotenv

load_dotenv()

class ChromaDBManager(VectorDBInterface):
    def __init__(self, persist_directory="chroma_db", collection_name=None):
        """
        Initialize ChromaDB Manager
        
        Args:
            persist_directory: Directory to persist ChromaDB
            collection_name: Default collection name to use
        """
        self.persist_directory = persist_directory
        
        # Create the directory if it doesn't exist
        os.makedirs(self.persist_directory, exist_ok=True)
        
        # Initialize ChromaDB client
        self.client = chromadb.PersistentClient(path=self.persist_directory)
        
        # Set default collection name from environment variable or use provided value
        self.collection_name = collection_name or os.getenv("RAG_COLLECTION", "mangoit_docs")
        
        # Read embedding backend and Gemini model name from environment variables
        self.embed_backend = os.getenv("EMBED_BACKEND", "sentence-transformers")
        self.gemini_model = os.getenv("GEMINI_MODEL", "gemini-embedding-001")
        
        # Initialize embedding function based on backend
        if self.embed_backend == "gemini":
            self.embedding_function = self._create_gemini_embedding_function()
        else:
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="all-MiniLM-L6-v2"
            )
        
        # State file for incremental processing
        self._state_path = os.path.join(self.persist_directory, ".md_index_state.json")
        
    def _create_gemini_embedding_function(self):
        """Create a Gemini embedding function adapter"""
        try:
            import google.generativeai as genai
            
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                print("Warning: GEMINI_API_KEY not found, falling back to default embeddings")
                return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
            
            genai.configure(api_key=api_key)
            
            class GeminiEmbeddingFunction:
                def __init__(self, model_name):
                    self.model_name = model_name
                
                # Newer ChromaDB versions validate that embedding functions take a
                # parameter named "input"; positional calls from older versions
                # still work with this signature.
                def __call__(self, input):
                    """Generate embeddings for a list of texts"""
                    results = []
                    for text in input:
                        try:
                            # genai.embed_content returns a dict with an "embedding" key
                            response = genai.embed_content(
                                model=self.model_name,
                                content=text,
                                task_type="retrieval_document"
                            )
                            vector = response.get("embedding") if isinstance(response, dict) else getattr(response, "embedding", None)
                            if vector:
                                results.append(vector)
                            else:
                                # Fallback to a zero vector if embedding fails.
                                # NOTE: the length must match the model's output
                                # dimension or ChromaDB will reject the batch.
                                results.append([0.0] * 768)
                        except Exception as e:
                            print(f"Error generating Gemini embedding: {str(e)}")
                            # Fallback to zeros if embedding fails
                            results.append([0.0] * 768)
                    return results
            
            return GeminiEmbeddingFunction(self.gemini_model)
        
        except ImportError:
            print("Warning: google.generativeai not installed, falling back to default embeddings")
            return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
        
    def _sha256(self, text: str) -> str:
        """Generate SHA256 hash of text"""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _deterministic_id(self, source: str, chunk_idx: int, chunk_text: str) -> str:
        """Generate deterministic ID for a chunk"""
        chunk_hash = self._sha256(chunk_text)
        return f"{source}:::{chunk_idx}:::{chunk_hash}"

    def _delete_source(self, collection, source: str):
        """Delete all documents from a specific source"""
        try:
            collection.delete(where={"source": source})
            print(f"Deleted existing documents from source: {source}")
        except Exception as e:
            print(f"Error deleting documents from source {source}: {str(e)}")
            
    def _load_state(self) -> Dict[str, Any]:
        """Load the state file for incremental processing"""
        if os.path.exists(self._state_path):
            try:
                with open(self._state_path, 'r') as f:
                    return json.load(f)
            except Exception as e:
                print(f"Error loading state file: {str(e)}")
        return {"files": {}, "last_indexed": 0}
    
    def _save_state(self, state: Dict[str, Any]):
        """Save the state file for incremental processing"""
        try:
            with open(self._state_path, 'w') as f:
                json.dump(state, f)
        except Exception as e:
            print(f"Error saving state file: {str(e)}")
    
    def _split_markdown(self, md_text: str, chunk_size: int = 1200, chunk_overlap: int = 180) -> List[Tuple[str, Dict[str, Any]]]:
        """Split markdown text with awareness of headers and structure
        
        Returns:
            List of tuples (chunk_text, header_metadata)
        """
        # First split by headers
        headers_to_split_on = [
            ("#", "header1"),
            ("##", "header2"),
            ("###", "header3"),
        ]
        
        try:
            md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
            md_header_splits = md_splitter.split_text(md_text)
            
            # Then split by chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len
            )
            
            chunks_with_metadata = []
            for doc in md_header_splits:
                # Extract header info for metadata
                header_info = {}
                for key, value in doc.metadata.items():
                    if key.startswith("header"):
                        header_info[key] = value
                
                # Split the content into smaller chunks
                sub_chunks = text_splitter.split_text(doc.page_content)
                
                # Add each chunk with its header info
                for chunk in sub_chunks:
                    chunks_with_metadata.append((chunk, header_info))
            
            return chunks_with_metadata
        except Exception as e:
            print(f"Error in markdown splitting: {str(e)}")
            # Fallback to simple chunking
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len
            )
            chunks = text_splitter.split_text(md_text)
            return [(chunk, {}) for chunk in chunks]
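    # Each returned tuple pairs a chunk with its header trail, e.g. (illustrative):
    #   ("Our team builds...", {"header1": "About Us", "header2": "Team"})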
    
    def _extract_metadata(self, file_path: str, rel_path: str, chunk_text: str, chunk_idx: int, headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """Extract rich metadata from file and chunk"""
        # Basic metadata
        metadata = {
            "source": rel_path,
            "chunk": chunk_idx,
            "filename": os.path.basename(file_path),
            "extension": os.path.splitext(file_path)[1].lower(),
            "directory": os.path.dirname(rel_path),
        }
        
        # Add header information if available
        if headers:
            for key, value in headers.items():
                metadata[key] = value
        
        # Extract content type
        if "post" in rel_path.lower():
            metadata["content_type"] = "post"
        elif "page" in rel_path.lower():
            metadata["content_type"] = "page"
        else:
            metadata["content_type"] = "document"
        
        # Extract date if present in the relative path (common for blog posts)
        date_match = re.search(r'(\d{4}[-/]\d{2}[-/]\d{2})', rel_path)
        if date_match:
            metadata["date"] = date_match.group(1)
        
        return metadata
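    # Example output for a hypothetical file "posts/2023-05-10-intro.md", chunk 0:
    # {"source": "posts/2023-05-10-intro.md", "chunk": 0,
    #  "filename": "2023-05-10-intro.md", "extension": ".md", "directory": "posts",
    #  "content_type": "post", "date": "2023-05-10"}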
    
    def _mmr_pick(self, query_embedding, embeddings, documents, k: int, lambda_mult: float = 0.5):
        """Apply Maximum Marginal Relevance to select diverse results"""
        try:
            import numpy as np
            
            if len(documents) <= k:
                return list(range(len(documents)))
            
            # Convert embeddings to numpy arrays
            query_embedding = np.array(query_embedding)
            embeddings = np.array(embeddings)
            
            # Calculate similarity between query and documents
            query_sim = np.dot(embeddings, query_embedding)
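            # NOTE: the dot product equals cosine similarity only for
            # unit-normalized embeddings (true for all-MiniLM-L6-v2, which
            # normalizes its outputs; assumed here for other backends)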
            
            # Initialize
            selected_indices = []
            remaining_indices = list(range(len(embeddings)))
            
            # Select first document with highest similarity to query
            first_idx = np.argmax(query_sim)
            selected_indices.append(first_idx)
            remaining_indices.remove(first_idx)
            
            # Select remaining documents
            for _ in range(min(k - 1, len(documents) - 1)):
                # Calculate similarity to query for remaining documents
                query_similarity = query_sim[remaining_indices]
                
                # Calculate similarity to already selected documents
                doc_similarity = np.max(
                    [np.dot(embeddings[remaining_indices], embeddings[idx]) 
                     for idx in selected_indices],
                    axis=0
                )
                
                # Apply MMR formula
                mmr_scores = lambda_mult * query_similarity - (1 - lambda_mult) * doc_similarity
                mmr_idx = np.argmax(mmr_scores)
                
                # Add to selected
                selected_idx = remaining_indices[mmr_idx]
                selected_indices.append(selected_idx)
                remaining_indices.remove(selected_idx)
            
            return selected_indices
        except Exception as e:
            print(f"Error in MMR: {str(e)}")
            # Fallback to first k documents
            return list(range(min(k, len(documents))))
    
    def create_collection(self, collection_name: str = "mangoit_docs") -> Any:
        """
        Create or get a collection
        
        Args:
            collection_name: Name of the collection
            
        Returns:
            The ChromaDB collection
        """
        try:
            # Try to get the collection if it exists
            collection = self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
            print(f"Collection '{collection_name}' already exists with {collection.count()} documents")
        except Exception:
            # Create a new collection if it doesn't exist
            collection = self.client.create_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
            print(f"Created new collection '{collection_name}'")
        
        return collection
    
    def embed_markdown_files(self, markdown_dir: str, collection_name: str = "mangoit_docs",
                             chunk_size: int = 1200, chunk_overlap: int = 180) -> int:
        """
        Embed markdown files into ChromaDB with improved chunking and incremental updates
        
        Args:
            markdown_dir: Directory containing markdown files
            collection_name: Name of the collection
            chunk_size: Size of text chunks
            chunk_overlap: Overlap between chunks
            
        Returns:
            Number of documents embedded
        """
        # Get or create collection
        collection = self.create_collection(collection_name)
        
        # Get all markdown files
        markdown_files = glob.glob(os.path.join(markdown_dir, "**/*.md"), recursive=True)
        
        if not markdown_files:
            print(f"No markdown files found in {markdown_dir}")
            return 0
        
        print(f"Found {len(markdown_files)} markdown files")
        
        # Load state for incremental processing
        state = self._load_state()
        file_states = state.get("files", {})
        
        # Process each markdown file
        total_chunks = 0
        new_chunks = 0
        skipped_files = 0
        
        for file_path in markdown_files:
            try:
                # Get relative path for metadata
                rel_path = os.path.relpath(file_path, markdown_dir)
                
                # Check file modification time
                mtime = os.path.getmtime(file_path)
                
                # Load the file content
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                
                # Calculate file hash
                file_hash = self._sha256(content)
                
                # Check if file has changed
                if rel_path in file_states and file_states[rel_path].get("hash") == file_hash:
                    skipped_files += 1
                    total_chunks += file_states[rel_path].get("chunks", 0)
                    print(f"Skipping unchanged file: {rel_path}")
                    continue
                
                # Delete existing documents for this source
                self._delete_source(collection, rel_path)
                
                # Split text into chunks using markdown-aware splitting
                chunks_with_metadata = self._split_markdown(content, chunk_size, chunk_overlap)
                
                # Prepare data for ChromaDB
                documents = []
                ids = []
                metadatas = []
                
                for i, (chunk, headers) in enumerate(chunks_with_metadata):
                    # Generate deterministic ID
                    chunk_id = self._deterministic_id(rel_path, i, chunk)
                    
                    # Extract metadata
                    metadata = self._extract_metadata(file_path, rel_path, chunk, i, headers)
                    
                    documents.append(chunk)
                    ids.append(chunk_id)
                    metadatas.append(metadata)
                
                # Add chunks to collection
                if documents:
                    collection.upsert(
                        documents=documents,
                        ids=ids,
                        metadatas=metadatas
                    )
                
                # Update state
                file_states[rel_path] = {
                    "hash": file_hash,
                    "mtime": mtime,
                    "chunks": len(documents),
                    "last_indexed": time.time()
                }
                
                new_chunks += len(documents)
                total_chunks += len(documents)
                print(f"Embedded {len(documents)} chunks from {rel_path}")
                
            except Exception as e:
                print(f"Error processing {file_path}: {str(e)}")
        
        # Save updated state
        state["files"] = file_states
        state["last_indexed"] = time.time()
        self._save_state(state)
        
        print(f"Total chunks in database: {total_chunks}")
        print(f"New chunks added: {new_chunks}")
        print(f"Files skipped (unchanged): {skipped_files}")
        
        return total_chunks
    
    def query_collection_mmr(self, query_text: str, collection_name: str = "mangoit_docs", n_results: int = 5, use_mmr: bool = True, lambda_mult: float = 0.5) -> Dict[str, Any]:
        """
        Query the collection with improved diversity using MMR. Text-query
        variant of `query_collection` below, kept under a distinct name so the
        two definitions do not shadow each other.
        
        Args:
            query_text: Query text
            collection_name: Name of the collection
            n_results: Number of results to return
            use_mmr: Whether to use MMR for diverse results
            lambda_mult: Diversity parameter for MMR (0 = max diversity, 1 = max relevance)
            
        Returns:
            Query results
        """
        collection = self.client.get_collection(
            name=collection_name,
            embedding_function=self.embedding_function
        )
        
        # Get more candidates than needed if using MMR
        n_candidates = n_results * 3 if use_mmr else n_results
        n_candidates = min(n_candidates, 20)  # Limit to avoid excessive processing
        
        # Query the collection
        try:
            # Request embeddings via the `include` parameter when using MMR
            # (ChromaDB's query() takes `include`, not an `include_embeddings` flag)
            include_fields = ["documents", "metadatas", "distances"]
            if use_mmr:
                include_fields.append("embeddings")
            results = collection.query(
                query_texts=[query_text],
                n_results=n_candidates,
                include=include_fields
            )
        except TypeError:
            # Fallback for ChromaDB versions with a different query() signature
            results = collection.query(
                query_texts=[query_text],
                n_results=n_candidates
            )
            # MMR won't work without embeddings
            use_mmr = False
        
        if not use_mmr or n_results >= len(results.get("documents", [[]])[0]):
            # If not using MMR or we have fewer results than requested, return as is
            if not use_mmr and n_candidates > n_results:
                # Trim results to requested number
                return {
                    "ids": [results["ids"][0][:n_results]],
                    "distances": [results["distances"][0][:n_results]],
                    "metadatas": [results["metadatas"][0][:n_results]],
                    "documents": [results["documents"][0][:n_results]],
                }
            return results
        
        # Apply MMR to select diverse results
        try:
            if "embeddings" in results and results["embeddings"]:
                # Get document embeddings from results
                doc_embeddings = results["embeddings"][0]
                
                # Get query embedding
                query_embedding = self.embedding_function([query_text])[0]
                
                # Apply MMR
                selected_indices = self._mmr_pick(
                    query_embedding=query_embedding,
                    embeddings=doc_embeddings,
                    documents=results.get("documents", [[]])[0],
                    k=n_results,
                    lambda_mult=lambda_mult
                )
                
                # Filter results
                filtered_results = {
                    "ids": [[results["ids"][0][i] for i in selected_indices]],
                    "distances": [[results["distances"][0][i] for i in selected_indices]],
                    "metadatas": [[results["metadatas"][0][i] for i in selected_indices]],
                    "documents": [[results["documents"][0][i] for i in selected_indices]],
                }
                
                return filtered_results
            else:
                # Fallback if embeddings not available
                return {
                    "ids": [results["ids"][0][:n_results]],
                    "distances": [results["distances"][0][:n_results]],
                    "metadatas": [results["metadatas"][0][:n_results]],
                    "documents": [results["documents"][0][:n_results]],
                }
        
        except Exception as e:
            print(f"Error applying MMR, returning standard results: {str(e)}")
            # Fallback to standard results
            return {
                "ids": [results["ids"][0][:n_results]],
                "distances": [results["distances"][0][:n_results]],
                "metadatas": [results["metadatas"][0][:n_results]],
                "documents": [results["documents"][0][:n_results]],
            }
        
    def _get_collection(self, collection_name: Optional[str] = None) -> Any:
        """
        Get a collection by name
        
        Args:
            collection_name: Name of the collection to get (uses default if None)
            
        Returns:
            The collection object
        """
        if collection_name is None:
            collection_name = self.collection_name
        
        try:
            # Try to get the collection directly
            return self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
        except Exception as e:
            # If collection doesn't exist, create it
            if "does not exist" in str(e).lower():
                try:
                    collection = self.client.create_collection(
                        name=collection_name,
                        embedding_function=self.embedding_function
                    )
                    print(f"Created new collection: {collection_name}")
                    return collection
                except Exception as create_error:
                    # Handle case where collection was created in another thread
                    if "already exists" in str(create_error).lower():
                        return self.client.get_collection(
                            name=collection_name,
                            embedding_function=self.embedding_function
                        )
                    else:
                        print(f"Error creating collection: {str(create_error)}")
            else:
                print(f"Error getting collection: {str(e)}")
            
            # Return default collection as fallback
            try:
                return self.client.get_or_create_collection(
                    name=self.collection_name,
                    embedding_function=self.embedding_function
                )
            except Exception:
                # Last resort fallback
                return self.client.get_or_create_collection(
                    name="mangoit_docs",
                    embedding_function=self.embedding_function
                )
    
    def query_collection(self, query_text: str, collection_name: Optional[str] = None, n_results: int = 5, use_hybrid: bool = False, use_mmr: bool = False, mmr_lambda: float = 0.5) -> Dict[str, Any]:
        """
        Query a collection with a text query using hybrid search and MMR reranking
        
        Args:
            query_text: The text to search for
            collection_name: The name of the collection to query
            n_results: The number of results to return
            use_hybrid: Whether to use hybrid search (vector + keyword)
            use_mmr: Whether to use MMR reranking for diversity
            mmr_lambda: MMR lambda parameter (0-1), higher values favor relevance over diversity
            
        Returns:
            Dict[str, Any]: The query results
        """
        try:
            # Get the collection
            collection = self._get_collection(collection_name)
            
            # Generate embeddings for the query
            query_embedding = self.embedding_function([query_text])
            
            # Determine the number of candidates to retrieve
            # For MMR, we need more candidates than final results
            n_candidates = max(n_results * 3, 20) if use_mmr else n_results
            
            # ChromaDB's query() accepts either query_texts or query_embeddings,
            # not both. For the hybrid path we pass the raw text and let the
            # collection embed it; otherwise we use our precomputed embedding.
            if use_hybrid:
                results = collection.query(
                    query_texts=[query_text],
                    n_results=n_candidates,
                    include=["documents", "metadatas", "distances", "embeddings"]
                )
            else:
                results = collection.query(
                    query_embeddings=query_embedding,
                    n_results=n_candidates,
                    include=["documents", "metadatas", "distances", "embeddings"]
                )
            
            # Apply MMR reranking if requested and embeddings are available
            try:
                if use_mmr and "embeddings" in results and results["embeddings"] and results["embeddings"][0]:
                    # Get document embeddings
                    doc_embeddings = results["embeddings"][0]
                    
                    # Apply MMR to select diverse results
                    mmr_indices = self._mmr(query_embedding[0], doc_embeddings, k=n_results, lambda_param=mmr_lambda)
                    
                    # Reorder results based on MMR
                    for key in ["documents", "metadatas", "distances", "embeddings"]:
                        if key in results and results[key] and results[key][0]:
                            results[key][0] = [results[key][0][i] for i in mmr_indices]
            except Exception as e:
                print(f"MMR reranking failed, using standard results: {str(e)}")
                # If MMR fails, just use the top results as they are
            
            # Remove embeddings from results to save memory
            if "embeddings" in results:
                del results["embeddings"]
            
            return results
            
        except Exception as e:
            print(f"Error querying collection: {str(e)}")
            return {"documents": [[]], "metadatas": [[]], "distances": [[]]}
    
    def get_collection_info(self, collection_name: str = "mangoit_docs") -> Dict[str, Any]:
        """
        Get information about a collection
        
        Args:
            collection_name: Name of the collection
            
        Returns:
            Collection information
        """
        try:
            collection = self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
            
            return {
                "name": collection_name,
                "count": collection.count(),
                "exists": True
            }
        except Exception as e:
            return {
                "name": collection_name,
                "count": 0,
                "exists": False,
                "error": str(e)
            }
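
if __name__ == "__main__":
    # Minimal usage sketch (illustrative): "markdown_docs" and the query text
    # below are placeholder values, not part of this module's API.
    manager = ChromaDBManager(persist_directory="chroma_db")

    # Index a directory of markdown files; unchanged files are skipped
    manager.embed_markdown_files("markdown_docs", collection_name=manager.collection_name)

    # Query with MMR reranking for more diverse results
    results = manager.query_collection(
        "What services are offered?",
        collection_name=manager.collection_name,
        n_results=5,
        use_mmr=True,
        mmr_lambda=0.5,
    )
    for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
        print(meta.get("source"), "->", doc[:80])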
