"""Repository functions for creating, reading and updating scheme versions."""

from datetime import datetime
from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import asc, desc, func, select

from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.repositories.scheme_groups import clone_scheme_version_groups_in_session
from app.repositories.scheme_seats import clone_scheme_version_seats_in_session
from app.repositories.scheme_sectors import clone_scheme_version_sectors_in_session
from app.services.api_errors import raise_conflict


def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
    """Report that the scheme's current-version pointer has no matching version row.

    Delegates to ``raise_conflict`` with code ``current_version_inconsistent``.
    Callers treat this as not returning (presumably ``raise_conflict`` raises —
    confirm in ``app.services.api_errors``).
    """
    raise_conflict(
        code="current_version_inconsistent",
        message="Scheme current version pointer is inconsistent with scheme_versions state.",
        details={
            "scheme_id": scheme_id,
            "current_version_number": current_version_number,
        },
    )


def _raise_stale_current_version(*, expected_scheme_version_id: str, actual_scheme_version_id: str) -> None:
    """Report that the caller's view of the current version is out of date.

    Delegates to ``raise_conflict`` with code ``stale_current_version``.
    Callers treat this as not returning (presumably ``raise_conflict`` raises —
    confirm in ``app.services.api_errors``).
    """
    raise_conflict(
        code="stale_current_version",
        message="Current scheme version changed. Reload scheme state before creating a new version.",
        details={
            "expected_scheme_version_id": expected_scheme_version_id,
            "actual_scheme_version_id": actual_scheme_version_id,
        },
    )


def _ensure_expected_current_version(
    *,
    expected_scheme_version_id: str | None,
    actual_scheme_version_id: str,
) -> None:
    """Report a stale-version conflict when the caller's expected id does not match.

    A falsy ``expected_scheme_version_id`` (``None`` or ``""``) disables the check.
    """
    if expected_scheme_version_id and expected_scheme_version_id != actual_scheme_version_id:
        _raise_stale_current_version(
            expected_scheme_version_id=expected_scheme_version_id,
            actual_scheme_version_id=actual_scheme_version_id,
        )


async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
    """Load the scheme row with ``SELECT ... FOR UPDATE`` inside *session*.

    Raises:
        HTTPException: 404 if no scheme with *scheme_id* exists.
    """
    scheme_result = await session.execute(
        select(SchemeRecord)
        .where(SchemeRecord.scheme_id == scheme_id)
        .with_for_update()
    )
    scheme = scheme_result.scalar_one_or_none()
    if scheme is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scheme not found",
        )
    return scheme


async def _get_current_scheme_version_for_update(
    session,
    *,
    scheme_id: str,
    current_version_number: int,
) -> SchemeVersionRecord:
    """Load and lock (``FOR UPDATE``) the version row the scheme points at.

    Reports ``current_version_inconsistent`` if no row matches
    (*scheme_id*, *current_version_number*).
    """
    current_result = await session.execute(
        select(SchemeVersionRecord)
        .where(
            SchemeVersionRecord.scheme_id == scheme_id,
            SchemeVersionRecord.version_number == current_version_number,
        )
        .with_for_update()
    )
    current_version = current_result.scalar_one_or_none()
    if current_version is None:
        _raise_current_version_inconsistent(
            scheme_id=scheme_id,
            current_version_number=current_version_number,
        )
    return current_version


async def _build_next_draft_version(
    session,
    *,
    scheme: SchemeRecord,
    source_version: SchemeVersionRecord,
) -> SchemeVersionRecord:
    """Create a new draft version cloned from *source_version* and make it current.

    The new version number is ``max(version_number) + 1`` for the scheme (1 if
    none exist). The new row copies the source's normalized and display-SVG
    fields, and the sectors, groups and seats are cloned via the sibling
    repositories. The scheme is then pointed at the new version and reset to
    draft (``published_at`` cleared, counts copied from the source).

    Does not commit; the caller owns the transaction. The caller is expected to
    hold the scheme row lock (see ``_get_scheme_for_update``) so the max+1
    computation is not raced by a concurrent writer.
    """
    max_version_result = await session.execute(
        select(func.coalesce(func.max(SchemeVersionRecord.version_number), 0)).where(
            SchemeVersionRecord.scheme_id == scheme.scheme_id
        )
    )
    next_version_number = int(max_version_result.scalar_one()) + 1

    new_version = SchemeVersionRecord(
        scheme_version_id=uuid4().hex,
        scheme_id=scheme.scheme_id,
        version_number=next_version_number,
        status="draft",
        normalized_storage_path=source_version.normalized_storage_path,
        normalized_elements_count=source_version.normalized_elements_count,
        normalized_seats_count=source_version.normalized_seats_count,
        normalized_groups_count=source_version.normalized_groups_count,
        normalized_sectors_count=source_version.normalized_sectors_count,
        display_svg_storage_path=source_version.display_svg_storage_path,
        display_svg_status=source_version.display_svg_status,
        display_svg_generated_at=source_version.display_svg_generated_at,
    )
    session.add(new_version)
    # Emit the INSERT before cloning children. Presumably the cloned rows
    # reference target_scheme_version_id (FK) — confirm in the clone helpers.
    await session.flush()

    await clone_scheme_version_sectors_in_session(
        session=session,
        source_scheme_version_id=source_version.scheme_version_id,
        target_scheme_version_id=new_version.scheme_version_id,
    )
    await clone_scheme_version_groups_in_session(
        session=session,
        source_scheme_version_id=source_version.scheme_version_id,
        target_scheme_version_id=new_version.scheme_version_id,
    )
    await clone_scheme_version_seats_in_session(
        session=session,
        source_scheme_version_id=source_version.scheme_version_id,
        target_scheme_version_id=new_version.scheme_version_id,
    )

    scheme.current_version_number = new_version.version_number
    scheme.status = "draft"
    scheme.published_at = None
    scheme.normalized_elements_count = source_version.normalized_elements_count
    scheme.normalized_seats_count = source_version.normalized_seats_count
    scheme.normalized_groups_count = source_version.normalized_groups_count
    scheme.normalized_sectors_count = source_version.normalized_sectors_count
    return new_version


async def create_initial_scheme_version(
    *,
    scheme_id: str,
    normalized_storage_path: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
    display_svg_storage_path: str | None = None,
    display_svg_status: str = "pending",
    display_svg_generated_at: datetime | None = None,
) -> str:
    """Insert version 1 (status ``draft``) for *scheme_id* and commit.

    Only the version row is written; the ``SchemeRecord`` itself is not touched.

    Returns:
        The generated ``scheme_version_id`` (a uuid4 hex string).
    """
    scheme_version_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        row = SchemeVersionRecord(
            scheme_version_id=scheme_version_id,
            scheme_id=scheme_id,
            version_number=1,
            status="draft",
            normalized_storage_path=normalized_storage_path,
            normalized_elements_count=normalized_elements_count,
            normalized_seats_count=normalized_seats_count,
            normalized_groups_count=normalized_groups_count,
            normalized_sectors_count=normalized_sectors_count,
            display_svg_storage_path=display_svg_storage_path,
            display_svg_status=display_svg_status,
            display_svg_generated_at=display_svg_generated_at,
        )
        session.add(row)
        await session.commit()
    return scheme_version_id


async def list_scheme_versions(scheme_id: str, limit: int = 100, offset: int = 0) -> list[SchemeVersionRecord]:
    """Return one page of a scheme's versions.

    Ordered by ``version_number`` ascending, then ``id`` descending as a
    tie-breaker.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeVersionRecord)
            .where(SchemeVersionRecord.scheme_id == scheme_id)
            .order_by(asc(SchemeVersionRecord.version_number), desc(SchemeVersionRecord.id))
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())


async def count_scheme_versions(scheme_id: str) -> int:
    """Return the number of version rows stored for *scheme_id*."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(func.count()).select_from(SchemeVersionRecord).where(SchemeVersionRecord.scheme_id == scheme_id)
        )
        return int(result.scalar_one())


async def get_current_scheme_version(scheme_id: str, current_version_number: int) -> SchemeVersionRecord:
    """Read (without locking) the version row numbered *current_version_number*.

    Reports ``current_version_inconsistent`` if the row does not exist.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeVersionRecord).where(
                SchemeVersionRecord.scheme_id == scheme_id,
                SchemeVersionRecord.version_number == current_version_number,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            _raise_current_version_inconsistent(
                scheme_id=scheme_id,
                current_version_number=current_version_number,
            )
        return row


async def update_scheme_version_display_artifact(
    *,
    scheme_version_id: str,
    display_svg_storage_path: str,
    display_svg_status: str,
    display_svg_generated_at: datetime,
) -> None:
    """Overwrite the display-SVG fields of one version and commit.

    Raises:
        HTTPException: 404 if *scheme_version_id* does not exist.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeVersionRecord).where(
                SchemeVersionRecord.scheme_version_id == scheme_version_id
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Scheme version not found",
            )
        row.display_svg_storage_path = display_svg_storage_path
        row.display_svg_status = display_svg_status
        row.display_svg_generated_at = display_svg_generated_at
        await session.commit()


async def create_next_scheme_version_from_current(scheme_id: str) -> SchemeVersionRecord:
    """Clone the scheme's current version into a new current draft and commit.

    Returns:
        The newly created version, refreshed after commit.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            new_version = await _build_next_draft_version(
                session,
                scheme=scheme,
                source_version=current_version,
            )
        # Refresh after commit so the returned instance is loaded even if the
        # session expires objects on commit.
        await session.refresh(new_version)
        return new_version


async def create_next_scheme_version_from_current_checked(
    *,
    scheme_id: str,
    expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, SchemeVersionRecord]:
    """Like ``create_next_scheme_version_from_current``, with an optimistic check.

    If *expected_current_scheme_version_id* is given and differs from the locked
    current version, a ``stale_current_version`` conflict is reported before any
    write.

    Returns:
        ``(previous_current_version, new_version)``, both refreshed after commit.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            _ensure_expected_current_version(
                expected_scheme_version_id=expected_current_scheme_version_id,
                actual_scheme_version_id=current_version.scheme_version_id,
            )
            new_version = await _build_next_draft_version(
                session,
                scheme=scheme,
                source_version=current_version,
            )
        await session.refresh(current_version)
        await session.refresh(new_version)
        return current_version, new_version


async def ensure_draft_scheme_version_consistent(
    *,
    scheme_id: str,
    expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, bool, str | None]:
    """Ensure the scheme's current version is an editable draft.

    If both the scheme and its current version already have status ``draft``,
    nothing is written. Otherwise a new draft is cloned from the current version
    and made current. The same optimistic stale check as
    ``create_next_scheme_version_from_current_checked`` applies.

    Returns:
        ``(draft_version, created, source_scheme_version_id)`` where:
        - ``created`` is True only if a new version was built.
        - ``source_scheme_version_id`` is the id it was cloned from, or None
          when no new version was built.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            _ensure_expected_current_version(
                expected_scheme_version_id=expected_current_scheme_version_id,
                actual_scheme_version_id=current_version.scheme_version_id,
            )
            if scheme.status == "draft" and current_version.status == "draft":
                draft_version = current_version
                created = False
                source_scheme_version_id = None
            else:
                draft_version = await _build_next_draft_version(
                    session,
                    scheme=scheme,
                    source_version=current_version,
                )
                created = True
                # Capture before commit: attributes may be expired afterwards.
                source_scheme_version_id = current_version.scheme_version_id
        # Refresh after commit for BOTH paths. Previously the already-draft path
        # refreshed inside the transaction, so a commit that expires objects
        # would have returned an expired, detached instance.
        await session.refresh(draft_version)
        return draft_version, created, source_scheme_version_id