from datetime import datetime, timezone

from sqlmodel import Session, or_, select

from app.models.user_session import UserSession

class UserSessionRepository:
    """Persistence helpers for ``UserSession`` rows.

    Every method takes the caller's SQLModel ``Session``. Methods that write
    (``create``, ``revoke``, ``revoke_all_for_user``,
    ``cleanup_user_sessions``) commit that session before returning.

    Comparisons such as ``UserSession.is_revoked == False`` are intentional:
    on mapped columns ``==`` builds a SQL expression, whereas ``is`` would be
    evaluated by Python and never reach the database.
    """

    @staticmethod
    def _active_conditions(now: datetime) -> tuple:
        """Return the WHERE conditions matching sessions that are active at *now*.

        A session is active when it is not revoked and ``expires_at`` is
        strictly later than *now*. ``cleanup_user_sessions`` uses the exact
        complement of this rule.
        """
        # NOTE(review): *now* is timezone-aware UTC; confirm expires_at is
        # stored in UTC so the database comparison is meaningful.
        return (
            UserSession.is_revoked == False,  # noqa: E712 - SQL expression
            UserSession.expires_at > now,
        )

    @staticmethod
    def create(session: Session, user_session: UserSession) -> UserSession:
        """Insert *user_session*, commit, and return it refreshed from the database.

        The refresh reloads the instance's attributes after the commit, so
        values populated by the database are visible on the returned object.
        """
        session.add(user_session)
        session.commit()
        session.refresh(user_session)
        return user_session

    @staticmethod
    def get_by_refresh_token(session: Session, token: str) -> UserSession | None:
        """Return the active session whose ``refresh_token`` equals *token*, or None.

        Revoked or expired sessions are never returned.
        """
        stmt = select(UserSession).where(
            UserSession.refresh_token == token,
            *UserSessionRepository._active_conditions(datetime.now(timezone.utc)),
        )
        return session.exec(stmt).first()

    @staticmethod
    def get_active_by_session_id(session: Session, session_id: str) -> UserSession | None:
        """Return the active session whose ``session_id`` equals *session_id*, or None.

        Revoked or expired sessions are never returned.
        """
        stmt = select(UserSession).where(
            UserSession.session_id == session_id,
            *UserSessionRepository._active_conditions(datetime.now(timezone.utc)),
        )
        return session.exec(stmt).first()

    @staticmethod
    def revoke(
        session: Session,
        user_session: UserSession,
        replaced_by_token: str | None = None
    ) -> None:
        """Mark *user_session* as revoked and commit.

        Args:
            session: Database session used to persist the change.
            user_session: The session row to revoke.
            replaced_by_token: Value stored in ``replaced_by_token``
                (presumably the token that supersedes this one on rotation).
                The field is always overwritten, so the default ``None``
                clears any previously stored value.
        """
        user_session.is_revoked = True
        user_session.replaced_by_token = replaced_by_token
        session.add(user_session)
        session.commit()

    @staticmethod
    def revoke_all_for_user(session: Session, user_id: int) -> None:
        """Revoke every not-yet-revoked session belonging to *user_id*, then commit.

        Expired-but-unrevoked sessions are revoked too. ``replaced_by_token``
        is left untouched.
        """
        stmt = select(UserSession).where(
            UserSession.user_id == user_id,
            UserSession.is_revoked == False,  # noqa: E712 - SQL expression
        )
        # Rows loaded via this session are already tracked by it, so setting
        # the attribute is enough for the commit to persist the change.
        for user_session in session.exec(stmt).all():
            user_session.is_revoked = True
        session.commit()

    @staticmethod
    def cleanup_user_sessions(session: Session, user_id: int) -> None:
        """Delete every session of *user_id* that is revoked or expired, then commit.

        "Expired" means ``expires_at <= now``, the exact complement of the
        ``expires_at > now`` check in the lookup methods. Any session those
        methods would reject at the same instant is therefore eligible for
        deletion here.
        """
        now = datetime.now(timezone.utc)
        # A single query covers both criteria, so a row that is both revoked
        # and expired is fetched once rather than twice.
        stmt = select(UserSession).where(
            UserSession.user_id == user_id,
            or_(
                UserSession.is_revoked == True,  # noqa: E712 - SQL expression
                UserSession.expires_at <= now,
            ),
        )
        # Delete through the ORM (one session.delete per row), as before,
        # rather than issuing a bulk DELETE statement.
        for stale in session.exec(stmt).all():
            session.delete(stale)
        session.commit()
