from typing import NoReturn
from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import desc, func, select

from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.services.api_errors import raise_conflict


def _raise_scheme_not_found() -> NoReturn:
    """Raise the HTTP 404 used whenever a scheme_id matches no scheme row."""
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Scheme not found",
    )


def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
    """Report that a scheme's current_version_number has no scheme_versions row.

    Delegates to ``raise_conflict``. Callers rely on it raising and do not
    handle a normal return. NOTE(review): presumably an HTTP 409; confirm in
    app.services.api_errors.
    """
    raise_conflict(
        code="current_version_inconsistent",
        message="Scheme current version pointer is inconsistent with scheme_versions state.",
        details={
            "scheme_id": scheme_id,
            "current_version_number": current_version_number,
        },
    )


def _new_scheme_record(
    *,
    scheme_id: str,
    source_upload_id: str,
    name: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
) -> SchemeRecord:
    """Build an unsaved SchemeRecord in its initial state: draft, pointing at version 1.

    Shared by both creation paths so their initial state cannot drift apart.
    """
    return SchemeRecord(
        scheme_id=scheme_id,
        source_upload_id=source_upload_id,
        name=name,
        status="draft",
        current_version_number=1,
        normalized_elements_count=normalized_elements_count,
        normalized_seats_count=normalized_seats_count,
        normalized_groups_count=normalized_groups_count,
        normalized_sectors_count=normalized_sectors_count,
    )


def _locked_version_select(scheme_id: str, version_number: int):
    """Return a SELECT ... FOR UPDATE for one version of one scheme.

    Used for every version row this module mutates, so all of them are
    row-locked for the rest of the enclosing transaction.
    """
    return (
        select(SchemeVersionRecord)
        .where(
            SchemeVersionRecord.scheme_id == scheme_id,
            SchemeVersionRecord.version_number == version_number,
        )
        .with_for_update()
    )


async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
    """Load and row-lock the scheme identified by ``scheme_id``.

    Raises:
        HTTPException: 404 if no scheme row has this scheme_id.
    """
    result = await session.execute(
        select(SchemeRecord)
        .where(SchemeRecord.scheme_id == scheme_id)
        .with_for_update()
    )
    scheme = result.scalar_one_or_none()
    if scheme is None:
        _raise_scheme_not_found()
    return scheme


async def _get_current_version_for_scheme(session, scheme: SchemeRecord) -> SchemeVersionRecord:
    """Load and row-lock the version row that ``scheme.current_version_number`` points at.

    Raises:
        Whatever ``raise_conflict`` raises (code "current_version_inconsistent")
        when the pointer has no matching scheme_versions row.
    """
    result = await session.execute(
        _locked_version_select(scheme.scheme_id, scheme.current_version_number)
    )
    version = result.scalar_one_or_none()
    if version is None:
        _raise_current_version_inconsistent(
            scheme_id=scheme.scheme_id,
            current_version_number=scheme.current_version_number,
        )
    return version


async def create_scheme_from_upload(
    *,
    source_upload_id: str,
    name: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
) -> str:
    """Insert a draft scheme row and return its generated hex scheme_id.

    Only the scheme row is created; no scheme_versions row is written. Until a
    version 1 row exists, publish/unpublish/rollback on this scheme will take
    the "current_version_inconsistent" conflict path.
    """
    scheme_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        session.add(
            _new_scheme_record(
                scheme_id=scheme_id,
                source_upload_id=source_upload_id,
                name=name,
                normalized_elements_count=normalized_elements_count,
                normalized_seats_count=normalized_seats_count,
                normalized_groups_count=normalized_groups_count,
                normalized_sectors_count=normalized_sectors_count,
            )
        )
        await session.commit()
    return scheme_id


async def create_scheme_from_upload_with_initial_version(
    *,
    source_upload_id: str,
    name: str,
    normalized_storage_path: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
    display_svg_storage_path: str | None = None,
    display_svg_status: str = "pending",
    display_svg_generated_at=None,
) -> tuple[str, str]:
    """Insert a draft scheme together with its version 1 row in a single commit.

    Both rows receive the same normalized counts. The display_svg_* arguments
    are stored only on the version row.

    Returns:
        ``(scheme_id, scheme_version_id)``, both freshly generated hex UUIDs.
    """
    scheme_id = uuid4().hex
    scheme_version_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        scheme = _new_scheme_record(
            scheme_id=scheme_id,
            source_upload_id=source_upload_id,
            name=name,
            normalized_elements_count=normalized_elements_count,
            normalized_seats_count=normalized_seats_count,
            normalized_groups_count=normalized_groups_count,
            normalized_sectors_count=normalized_sectors_count,
        )
        version = SchemeVersionRecord(
            scheme_version_id=scheme_version_id,
            scheme_id=scheme_id,
            version_number=1,
            status="draft",
            normalized_storage_path=normalized_storage_path,
            normalized_elements_count=normalized_elements_count,
            normalized_seats_count=normalized_seats_count,
            normalized_groups_count=normalized_groups_count,
            normalized_sectors_count=normalized_sectors_count,
            display_svg_storage_path=display_svg_storage_path,
            display_svg_status=display_svg_status,
            display_svg_generated_at=display_svg_generated_at,
        )
        session.add_all([scheme, version])
        await session.commit()
    return scheme_id, scheme_version_id


async def list_scheme_records(limit: int = 50, offset: int = 0) -> list[SchemeRecord]:
    """Return one page of schemes, newest first.

    Rows are ordered by created_at descending, with id descending as a
    tiebreaker so pagination is stable across equal timestamps.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeRecord)
            .order_by(desc(SchemeRecord.created_at), desc(SchemeRecord.id))
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())


async def count_scheme_records() -> int:
    """Return the total number of scheme rows."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(func.count()).select_from(SchemeRecord))
        return int(result.scalar_one())


async def get_scheme_record_by_scheme_id(scheme_id: str) -> SchemeRecord:
    """Fetch a scheme by scheme_id without locking it.

    Raises:
        HTTPException: 404 if no scheme row has this scheme_id.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
        )
        row = result.scalar_one_or_none()
        if row is None:
            _raise_scheme_not_found()
        return row


async def _set_publication_state(scheme_id: str, *, new_status: str, published_at) -> SchemeRecord:
    """Set ``new_status`` on a scheme and its current version in one locked transaction.

    ``published_at`` is written to the scheme as given. It may be a SQL
    expression such as ``func.now()``, or None.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            scheme = await _get_scheme_for_update(session, scheme_id)
            version = await _get_current_version_for_scheme(session, scheme)
            scheme.status = new_status
            scheme.published_at = published_at
            version.status = new_status
        # Exiting begin() commits. Reload so that values computed by the
        # database (e.g. published_at from func.now()) are on the returned object.
        await session.refresh(scheme)
        return scheme


async def publish_scheme(scheme_id: str) -> SchemeRecord:
    """Mark a scheme and its current version "published".

    The scheme's published_at is stamped with the database's now().

    Raises:
        HTTPException: 404 if the scheme does not exist.
        Conflict from ``raise_conflict`` if the current version row is missing.
    """
    return await _set_publication_state(
        scheme_id, new_status="published", published_at=func.now()
    )


async def unpublish_scheme(scheme_id: str) -> SchemeRecord:
    """Return a scheme and its current version to "draft" and clear published_at.

    Raises:
        HTTPException: 404 if the scheme does not exist.
        Conflict from ``raise_conflict`` if the current version row is missing.
    """
    return await _set_publication_state(
        scheme_id, new_status="draft", published_at=None
    )


async def rollback_scheme_to_version(scheme_id: str, target_version_number: int) -> SchemeRecord:
    """Point a scheme at an existing version and reset it to draft.

    Both the previously current version and the target version are set to
    "draft". The scheme's published_at is cleared, and its normalized counts
    are copied from the target version.

    Raises:
        HTTPException: 404 if the scheme or the target version does not exist.
        Conflict from ``raise_conflict`` if the current version row is missing.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_version_for_scheme(session, scheme)
            # Lock the target like the current version: its status is written below.
            target_result = await session.execute(
                _locked_version_select(scheme.scheme_id, target_version_number)
            )
            target_version = target_result.scalar_one_or_none()
            if target_version is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Target scheme version not found",
                )
            current_version.status = "draft"
            target_version.status = "draft"
            scheme.current_version_number = target_version.version_number
            scheme.status = "draft"
            scheme.published_at = None
            # The scheme row mirrors the counts of whichever version is current.
            scheme.normalized_elements_count = target_version.normalized_elements_count
            scheme.normalized_seats_count = target_version.normalized_seats_count
            scheme.normalized_groups_count = target_version.normalized_groups_count
            scheme.normalized_sectors_count = target_version.normalized_sectors_count
        await session.refresh(scheme)
        return scheme