import asyncio
import uuid
from typing import Dict, Any, Optional, Callable, Awaitable, List
import time
from utils.task_status import TaskStatus, AgentStatus, calculate_progress
from utils.logger import app_logger

class AsyncTaskManager:
    """
    Manages asynchronous tasks for the application.
    Allows for background processing of agent tasks without blocking the main thread.
    Tracks progress of tasks and provides status updates.

    Every task is an ``asyncio.Task`` keyed by a generated UUID string. The
    final outcome of a task (completed, failed or cancelled) is stored in
    ``self.results``; per-agent progress is stored in ``self.agent_statuses``.
    """
    
    def __init__(self):
        # task_id -> {"task": asyncio.Task, "status": ..., "created_at": float}
        self.tasks: Dict[str, Dict[str, Any]] = {}
        # task_id -> final outcome dict (status / result / error / progress)
        self.results: Dict[str, Any] = {}
        # Not referenced by any method of this class; kept as a public attribute.
        self.callbacks: Dict[str, Callable] = {}
        # task_id -> AgentStatus objects of the agents taking part in the task
        self.agent_statuses: Dict[str, List[AgentStatus]] = {}
        self.logger = app_logger
    
    async def create_task(self, coroutine, callback: Optional[Callable] = None, agent_names: Optional[List[str]] = None) -> str:
        """
        Create a new async task and return its ID
        
        The coroutine is wrapped so that its outcome is always recorded in
        ``self.results``. Cancellation and exceptions raised by the coroutine
        are recorded there and logged instead of propagating out of the task.
        
        Args:
            coroutine: The coroutine to execute
            callback: Optional callback executed with the result when the task
                completes successfully. It may return an awaitable, which is
                then awaited. A failing callback is logged and does not change
                the recorded task result.
            agent_names: Optional list of agent names involved in this task
            
        Returns:
            str: The task ID
        """
        task_id = str(uuid.uuid4())
        
        # Initialize agent statuses if provided
        if agent_names:
            total_agents = len(agent_names)
            self.agent_statuses[task_id] = [
                AgentStatus(agent_name, total_agents) for agent_name in agent_names
            ]
        
        # Create a wrapper coroutine that will store the result
        async def task_wrapper():
            # Monotonic clock: the duration is unaffected by wall-clock changes
            started = time.monotonic()
            self.logger.info(f"Task {task_id[:8]} started")
            
            try:
                result = await coroutine
            except asyncio.CancelledError:
                # Swallowed on purpose: the outcome is recorded and the
                # asyncio task itself finishes normally.
                self.results[task_id] = {
                    "status": TaskStatus.CANCELLED,
                    "result": None,
                    "error": "Task was cancelled",
                    "progress": 0
                }
                self.logger.warning(f"Task {task_id[:8]} was cancelled")
                return
            except Exception as e:
                self.results[task_id] = {
                    "status": TaskStatus.FAILED,
                    "result": None,
                    "error": str(e),
                    "progress": 0
                }
                self.logger.error(f"Task {task_id[:8]} failed: {str(e)}", exc_info=True)
                return
            
            execution_time = time.monotonic() - started
            self.results[task_id] = {
                "status": TaskStatus.COMPLETED,
                "result": result,
                "error": None,
                "execution_time": execution_time,
                "progress": 100  # Task completed, progress is 100%
            }
            self.logger.info(f"Task {task_id[:8]} completed in {execution_time:.2f}s")
            
            # Run the callback outside the try block above so that a callback
            # error cannot overwrite the successful result with FAILED.
            if callback:
                await self._run_callback(task_id, callback, result)
        
        def release_unstarted(finished_task):
            # A task cancelled before its first step never enters
            # task_wrapper, so the wrapped coroutine would stay un-awaited.
            # Closing it avoids the "never awaited" RuntimeWarning; closing an
            # already finished coroutine is a no-op.
            if finished_task.cancelled() and asyncio.iscoroutine(coroutine):
                coroutine.close()
        
        # Create and store the task
        task = asyncio.create_task(task_wrapper())
        task.add_done_callback(release_unstarted)
        self.tasks[task_id] = {
            "task": task,
            "status": TaskStatus.RUNNING,
            "created_at": time.time()
        }
        
        return task_id
    
    async def _run_callback(self, task_id: str, callback: Callable, result: Any) -> None:
        """
        Invoke a completion callback without letting its errors escape
        
        Args:
            task_id: The ID of the task the callback belongs to (for logging)
            callback: The callback to invoke with the task result
            result: The value returned by the task's coroutine
        """
        try:
            outcome = callback(result)
            # Support callbacks that are coroutine functions or return futures
            if isinstance(outcome, Awaitable):
                await outcome
        except Exception as e:
            self.logger.error(f"Callback for task {task_id[:8]} failed: {str(e)}", exc_info=True)
    
    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """
        Get the status of a task
        
        Args:
            task_id: The ID of the task
            
        Returns:
            dict: The task status including progress information. Unknown IDs
            yield a dict whose status is TaskStatus.NOT_FOUND.
        """
        agents = self.agent_statuses.get(task_id)
        
        # If we have final results, return them
        if task_id in self.results:
            result = self.results[task_id].copy()
            
            # Add agent statuses if available
            if agents is not None:
                result["agent_statuses"] = [agent.to_dict() for agent in agents]
                
            return result
        
        # Task is known but no final result has been recorded (yet)
        if task_id in self.tasks:
            task = self.tasks[task_id]["task"]
            error = None
            if not task.done():
                status = TaskStatus.RUNNING
            elif task.cancelled():
                # Cancelled before task_wrapper could record an outcome
                status = TaskStatus.CANCELLED
                error = "Task was cancelled"
            else:
                status = TaskStatus.COMPLETED
            
            # Calculate progress if agent statuses are available
            progress = calculate_progress(agents) if agents is not None else 0
                
            return {
                "status": status,
                "result": None,
                "error": error,
                "progress": progress,
                "agent_statuses": [agent.to_dict() for agent in agents] if agents is not None else []
            }
        
        # Task not found
        return {
            "status": TaskStatus.NOT_FOUND,
            "result": None,
            "error": "Task not found",
            "progress": 0
        }
    
    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """
        Wait for a task to complete
        
        The task keeps running if the timeout expires. A task that ended up
        cancelled is reported through the returned status dict rather than by
        raising ``asyncio.CancelledError`` in the caller.
        
        Args:
            task_id: The ID of the task
            timeout: Optional timeout in seconds
            
        Returns:
            dict: The task result
        """
        if task_id not in self.tasks:
            return {
                "status": TaskStatus.NOT_FOUND,
                "result": None,
                "error": "Task not found",
                "progress": 0
            }
        
        task = self.tasks[task_id]["task"]
        
        # asyncio.wait neither cancels the task on timeout nor re-raises the
        # task's own cancellation/exception into this coroutine.
        _, pending = await asyncio.wait({task}, timeout=timeout)
        
        if pending:
            self.logger.warning(f"Task {task_id[:8]} timed out after {timeout} seconds")
            return {
                "status": TaskStatus.TIMEOUT,
                "result": None,
                "error": f"Task timed out after {timeout} seconds",
                "progress": self.get_task_status(task_id).get("progress", 0)
            }
        
        return self.get_task_status(task_id)
    
    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running task
        
        Args:
            task_id: The ID of the task
            
        Returns:
            bool: True if the task was cancelled, False otherwise
            (unknown ID or task already finished)
        """
        task_info = self.tasks.get(task_id)
        if task_info is None or task_info["task"].done():
            return False
        
        task_info["task"].cancel()
        # Record the outcome immediately so status queries reflect the
        # cancellation before the event loop has processed it.
        self.results[task_id] = {
            "status": TaskStatus.CANCELLED,
            "result": None,
            "error": "Task was cancelled",
            "progress": 0
        }
        self.logger.info(f"Task {task_id[:8]} cancelled")
        return True
    
    def cleanup_old_tasks(self, max_age: float = 3600.0) -> int:
        """
        Clean up old completed tasks
        
        A task is removed when it is done and more than ``max_age`` seconds
        have passed since it was created.
        
        Args:
            max_age: Maximum age of tasks in seconds (default: 1 hour)
            
        Returns:
            int: Number of tasks cleaned up
        """
        current_time = time.time()
        
        # Collect IDs first: the dict must not change size while iterating it
        task_ids_to_remove = [
            task_id
            for task_id, task_info in self.tasks.items()
            if task_info["task"].done() and (current_time - task_info["created_at"]) > max_age
        ]
        
        for task_id in task_ids_to_remove:
            del self.tasks[task_id]
            self.results.pop(task_id, None)
            self.agent_statuses.pop(task_id, None)
        
        if task_ids_to_remove:
            self.logger.info(f"Cleaned up {len(task_ids_to_remove)} old tasks")
        
        return len(task_ids_to_remove)

    def update_agent_status(self, task_id: str, agent_name: str, status: TaskStatus, error: Optional[str] = None) -> bool:
        """
        Update the status of an agent in a task
        
        Only the first agent whose name matches is updated. The start time is
        set on the first transition to RUNNING, and the end time on the first
        transition to COMPLETED, FAILED or CANCELLED.
        
        Args:
            task_id: The ID of the task
            agent_name: The name of the agent
            status: The new status
            error: Optional error message
            
        Returns:
            bool: True if the status was updated, False otherwise
        """
        if task_id not in self.agent_statuses:
            self.logger.warning(f"Cannot update agent status: Task {task_id[:8]} not found")
            return False
        
        terminal_statuses = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)
            
        for agent in self.agent_statuses[task_id]:
            if agent.name != agent_name:
                continue
            
            # Update status
            agent.status = status
            
            # Update timing information
            current_time = time.time()
            if status == TaskStatus.RUNNING and agent.start_time is None:
                agent.start_time = current_time
            elif status in terminal_statuses and agent.end_time is None:
                agent.end_time = current_time
                # The duration is only known if the agent was seen RUNNING
                if agent.start_time is not None:
                    agent.execution_time = agent.end_time - agent.start_time
            
            # Update error information
            if error:
                agent.error = error
                
            self.logger.info(f"Agent {agent_name} in task {task_id[:8]} status updated to {status}")
            return True
                
        self.logger.warning(f"Agent {agent_name} not found in task {task_id[:8]}")
        return False
