99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session as DBSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.config import settings
|
|
from app.core.security import verify_password
|
|
from app.db.session import get_db
|
|
from app.db.models import User, Session
|
|
|
|
router = APIRouter()

# Name of the cookie that carries the session id.
COOKIE_NAME = os.getenv("SESSION_COOKIE_NAME", "ai_chat_session")

# Session lifetime in hours (default 168 = 7 days). The login handler uses it
# for both the stored ``expires_at`` and the cookie ``max_age``.
SESSION_TTL_HOURS = int(os.getenv("SESSION_TTL_HOURS", "168"))
if SESSION_TTL_HOURS <= 0:
    # A non-positive TTL would create sessions that are already expired and a
    # cookie with max_age <= 0 that the browser drops immediately: logins would
    # appear to succeed while every /me call returns 401. Fail fast instead.
    raise ValueError(
        f"SESSION_TTL_HOURS must be a positive integer, got {SESSION_TTL_HOURS}"
    )
|
|
|
|
class LoginRequest(BaseModel):
    """JSON body accepted by ``POST /login``."""

    # Account name, matched against ``User.login``.
    login: str
    # Plain-text password, checked against the stored hash via ``verify_password``.
    password: str
|
|
|
|
class LoginResponse(BaseModel):
    """Minimal status payload shared by ``POST /login`` and ``POST /logout``."""

    # Both handlers in this module return the literal "ok".
    status: str
|
|
|
|
class MeResponse(BaseModel):
    """Payload returned by ``GET /me`` for an authenticated session."""

    # Login name of the user owning the current session.
    login: str
|
|
|
|
@router.post("/login", response_model=LoginResponse)
def login(login_data: LoginRequest, response: Response, db: DBSession = Depends(get_db)):
    """Authenticate a user and start a cookie-backed session.

    Looks the user up by ``login`` and checks the password. It then stores a
    new ``Session`` row expiring ``SESSION_TTL_HOURS`` from now and sends that
    row's id to the browser in an HttpOnly cookie named ``COOKIE_NAME``.

    Returns:
        ``{"status": "ok"}`` on success.

    Raises:
        HTTPException: 401 with a single generic message when the login is
            unknown, the password is wrong, or the account is inactive.
        RuntimeError: when ``SESSION_COOKIE_SAMESITE`` is not ``strict``,
            ``lax`` or ``none``. Raised before any session row is written.
    """
    user = db.scalar(select(User).where(User.login == login_data.login))

    # Verify the password *before* looking at is_active. To a caller who does
    # not know the password, an inactive account then looks exactly like a
    # wrong password: same message, and verify_password runs in both cases.
    # NOTE(review): an unknown login still skips verify_password and answers
    # faster; verifying against a dummy hash would close that timing gap.
    if (
        user is None
        or not verify_password(login_data.password, user.hashed_password)
        or not user.is_active
    ):
        raise HTTPException(status_code=401, detail="Invalid credentials or inactive user")

    # Resolve and validate cookie settings before writing to the database.
    # A bad SESSION_COOKIE_SAMESITE then fails with a clear message instead of
    # erroring inside set_cookie after an orphaned session row was committed.
    is_secure = os.getenv("SESSION_COOKIE_SECURE", "false").lower() == "true"
    samesite = os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower()
    if samesite not in ("strict", "lax", "none"):
        raise RuntimeError(
            f"Invalid SESSION_COOKIE_SAMESITE={samesite!r}; "
            "expected 'strict', 'lax' or 'none'"
        )

    # Expiry is stored as naive UTC (tzinfo stripped).
    # NOTE(review): this assumes a timezone-naive DB column — confirm schema.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=SESSION_TTL_HOURS)
    db_session = Session(user_id=user.id, expires_at=expires_at.replace(tzinfo=None))
    db.add(db_session)
    db.commit()
    # Reload the committed row so db_session.id (used as the cookie value) is loaded.
    db.refresh(db_session)

    # NOTE(review): the cookie value is the Session primary key itself. That is
    # only safe if the id is an unguessable random token (e.g. UUID4 / secrets),
    # not a sequential integer — confirm in app.db.models.
    response.set_cookie(
        key=COOKIE_NAME,
        value=db_session.id,
        httponly=True,
        secure=is_secure,
        samesite=samesite,
        max_age=SESSION_TTL_HOURS * 3600,
    )
    return {"status": "ok"}
|
|
|
|
@router.post("/logout", response_model=LoginResponse)
def logout(request: Request, response: Response, db: DBSession = Depends(get_db)):
    """End the current session, if any, and clear the session cookie.

    If the ``COOKIE_NAME`` cookie refers to an existing ``Session`` row, that
    row is deleted. The browser is always told to drop the cookie, and the
    handler always returns ``{"status": "ok"}``, so repeated logouts are harmless.
    """
    session_id = request.cookies.get(COOKIE_NAME)
    if session_id:
        db_session = db.get(Session, session_id)
        if db_session:
            db.delete(db_session)
            db.commit()

    # Issue the deletion with the same attributes the login cookie was set
    # with. Otherwise the framework defaults (not Secure, SameSite=Lax) are
    # used, and in a cross-site SameSite=None deployment the browser may
    # refuse that Set-Cookie, leaving the stale cookie in place.
    is_secure = os.getenv("SESSION_COOKIE_SECURE", "false").lower() == "true"
    samesite = os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower()
    if samesite not in ("strict", "lax", "none"):
        # No cookie can have been issued with an invalid SameSite value, so
        # fall back to "lax" rather than failing the logout request.
        samesite = "lax"
    response.delete_cookie(
        key=COOKIE_NAME,
        secure=is_secure,
        httponly=True,
        samesite=samesite,
    )
    return {"status": "ok"}
|
|
|
|
@router.get("/me", response_model=MeResponse)
def me(request: Request, db: DBSession = Depends(get_db)):
    """Return the login of the user who owns the current session.

    Returns:
        ``{"login": <user login>}`` for a valid, unexpired session of an
        active user.

    Raises:
        HTTPException: 401 when the cookie is missing, refers to no session,
            the session has expired (the expired row is deleted first), or
            the owning user is missing or inactive.
    """
    session_id = request.cookies.get(COOKIE_NAME)
    if not session_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    db_session = db.get(Session, session_id)
    if not db_session:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid session")

    expires_at = db_session.expires_at
    now = datetime.now(timezone.utc)
    # The login handler writes expires_at as naive UTC. A timezone-aware
    # column or driver would hand back an aware datetime, and comparing naive
    # with aware raises TypeError (a 500). Compare like with like.
    if expires_at.tzinfo is None:
        now = now.replace(tzinfo=None)
    if expires_at < now:
        db.delete(db_session)
        db.commit()
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Session expired")

    user = db_session.user
    if not user or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User inactive")

    return {"login": user.login}
|