"""Persistence helpers for the sector records (``SchemeSectorRecord``) of a scheme version.

Every public coroutine opens and commits its own ``AsyncSessionLocal`` session,
except ``clone_scheme_version_sectors_in_session``, which only adds objects to a
session supplied (and committed) by the caller. Error messages refer to the
"current draft version"; presumably callers only pass draft version ids — confirm.
"""

from collections.abc import Iterable
from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import asc, select

from app.db.session import AsyncSessionLocal
from app.models.scheme_sector import SchemeSectorRecord
from app.models.scheme_seat import SchemeSeatRecord

_SECTOR_NOT_FOUND_DETAIL = "Sector record not found in current draft version"


def _conflict(message: str) -> HTTPException:
    """Build the error raised for sector_id / element_id uniqueness violations.

    NOTE(review): despite the helper's name the status is 422, not 409; kept
    unchanged because API clients may depend on it.
    """
    return HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail={
            "code": "sector_uniqueness_violation",
            "message": message,
        },
    )


def _ensure_no_duplicate_keys(
    keys: Iterable[tuple[str | None, str | None]],
    *,
    context: str,
) -> None:
    """Raise if a non-empty sector_id or element_id repeats within *keys*.

    Args:
        keys: ``(sector_id, element_id)`` pairs; falsy values are ignored.
        context: Text appended to the error message (e.g. "in replacement payload").

    Raises:
        HTTPException: 422 (via ``_conflict``) for the first duplicate found.
    """
    seen_sector_ids: set[str] = set()
    seen_element_ids: set[str] = set()
    for sector_id, element_id in keys:
        if sector_id:
            if sector_id in seen_sector_ids:
                raise _conflict(f"Duplicate sector_id='{sector_id}' {context}")
            seen_sector_ids.add(sector_id)
        if element_id:
            if element_id in seen_element_ids:
                raise _conflict(f"Duplicate element_id='{element_id}' {context}")
            seen_element_ids.add(element_id)


async def _ensure_sector_uniqueness(
    *,
    session,
    scheme_version_id: str,
    sector_id: str | None,
    element_id: str | None,
    exclude_sector_record_id: str | None = None,
) -> None:
    """Raise if another sector in the version already uses *sector_id* or *element_id*.

    Falsy ``sector_id`` / ``element_id`` values are not checked. The record
    identified by ``exclude_sector_record_id`` (the one being updated) is ignored.

    Raises:
        HTTPException: 422 (via ``_conflict``) on the first clash found.
    """
    checks = (
        (SchemeSectorRecord.sector_id, sector_id, "sector_id"),
        (SchemeSectorRecord.element_id, element_id, "element_id"),
    )
    for column, value, label in checks:
        if not value:
            continue
        stmt = select(SchemeSectorRecord.sector_record_id).where(
            SchemeSectorRecord.scheme_version_id == scheme_version_id,
            column == value,
        )
        if exclude_sector_record_id:
            stmt = stmt.where(SchemeSectorRecord.sector_record_id != exclude_sector_record_id)
        # LIMIT 1 + first(): several pre-existing clashes must still yield the
        # 422 below, not a MultipleResultsFound (500) from scalar_one_or_none().
        clash = (await session.execute(stmt.limit(1))).first()
        if clash is not None:
            raise _conflict(f"Sector with {label}='{value}' already exists in current draft version")


async def _get_sector_or_404(
    session,
    *,
    scheme_version_id: str,
    sector_record_id: str,
) -> SchemeSectorRecord:
    """Load one sector record of a version or raise HTTP 404."""
    result = await session.execute(
        select(SchemeSectorRecord).where(
            SchemeSectorRecord.scheme_version_id == scheme_version_id,
            SchemeSectorRecord.sector_record_id == sector_record_id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=_SECTOR_NOT_FOUND_DETAIL,
        )
    return row


async def replace_scheme_version_sectors(
    *,
    scheme_id: str,
    scheme_version_id: str,
    sectors: list[dict],
) -> None:
    """Replace all sector records of a version with the ones in *sectors*.

    Each item may carry ``sector_id``, ``id`` (stored as ``element_id``),
    ``classes`` and an optional ``sector_record_id`` (a new hex uuid is
    generated when absent or empty).

    Raises:
        HTTPException: 422 if the payload repeats a sector_id or element_id;
            nothing is written in that case.
    """
    # Validate the whole payload before touching the database.
    _ensure_no_duplicate_keys(
        ((item.get("sector_id"), item.get("id")) for item in sectors),
        context="in replacement payload",
    )
    async with AsyncSessionLocal() as session:
        existing_result = await session.execute(
            select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
        )
        for row in existing_result.scalars().all():
            await session.delete(row)
        # Emit the DELETEs now: within one flush the unit of work may issue
        # INSERTs before DELETEs for the same table, which would trip unique
        # constraints when replacement rows reuse existing keys.
        await session.flush()
        for item in sectors:
            sector_id = item.get("sector_id")
            session.add(
                SchemeSectorRecord(
                    sector_record_id=item.get("sector_record_id") or uuid4().hex,
                    scheme_id=scheme_id,
                    scheme_version_id=scheme_version_id,
                    element_id=item.get("id"),
                    sector_id=sector_id,
                    # NOTE(review): name mirrors sector_id, not item["name"] — confirm intended.
                    name=sector_id,
                    # NOTE(review): stores the text "None" when "classes" is absent,
                    # and the Python repr for lists/dicts — confirm readers expect this.
                    classes_raw=str(item.get("classes")),
                )
            )
        await session.commit()


async def clone_scheme_version_sectors(
    *,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Copy every sector record of the source version into the target version and commit."""
    async with AsyncSessionLocal() as session:
        await clone_scheme_version_sectors_in_session(
            session=session,
            source_scheme_version_id=source_scheme_version_id,
            target_scheme_version_id=target_scheme_version_id,
        )
        await session.commit()


async def clone_scheme_version_sectors_in_session(
    *,
    session,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Add copies of the source version's sectors to *session* (no commit).

    Each copy gets a fresh ``sector_record_id`` and the target version id.
    All other copied columns are kept.

    Raises:
        HTTPException: 422 if the source already contains duplicate sector_id
            or element_id values. Nothing is added to the session in that case.
    """
    result = await session.execute(
        select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == source_scheme_version_id)
    )
    rows = list(result.scalars().all())
    # Validate first so a failure never leaves partial clones pending in the
    # caller-owned session.
    _ensure_no_duplicate_keys(
        ((row.sector_id, row.element_id) for row in rows),
        context="while cloning draft",
    )
    session.add_all(
        [
            SchemeSectorRecord(
                sector_record_id=uuid4().hex,
                scheme_id=row.scheme_id,
                scheme_version_id=target_scheme_version_id,
                element_id=row.element_id,
                sector_id=row.sector_id,
                name=row.name,
                classes_raw=row.classes_raw,
            )
            for row in rows
        ]
    )


async def list_scheme_version_sectors(scheme_version_id: str) -> list[SchemeSectorRecord]:
    """Return a version's sector records ordered by ``created_at`` then ``id``.

    The rows are returned after their session has closed.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSectorRecord)
            .where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
            .order_by(asc(SchemeSectorRecord.created_at), asc(SchemeSectorRecord.id))
        )
        return list(result.scalars().all())


async def update_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
    **update_data,
) -> tuple[SchemeSectorRecord, str | None]:
    """Apply ``sector_id`` and/or ``name`` from *update_data* to one sector record.

    Other keys in *update_data* are ignored.

    Returns:
        The refreshed record and its ``sector_id`` before the update.

    Raises:
        HTTPException: 404 if the record is missing; 422 if the new sector_id
            clashes with another record in the version.
    """
    async with AsyncSessionLocal() as session:
        row = await _get_sector_or_404(
            session,
            scheme_version_id=scheme_version_id,
            sector_record_id=sector_record_id,
        )
        old_sector_id = row.sector_id
        if "sector_id" in update_data:
            # Check before assigning, so autoflush cannot write the new value early.
            await _ensure_sector_uniqueness(
                session=session,
                scheme_version_id=scheme_version_id,
                sector_id=update_data["sector_id"],
                element_id=row.element_id,
                exclude_sector_record_id=sector_record_id,
            )
            row.sector_id = update_data["sector_id"]
        if "name" in update_data:
            row.name = update_data["name"]
        await session.commit()
        await session.refresh(row)
        return row, old_sector_id


async def create_scheme_version_sector(
    *,
    scheme_id: str,
    scheme_version_id: str,
    element_id: str | None,
    sector_id: str,
    name: str | None,
    classes_raw: str | None,
) -> SchemeSectorRecord:
    """Insert one sector record after checking sector_id/element_id uniqueness.

    Raises:
        HTTPException: 422 if sector_id or element_id is already used in the version.
    """
    async with AsyncSessionLocal() as session:
        await _ensure_sector_uniqueness(
            session=session,
            scheme_version_id=scheme_version_id,
            sector_id=sector_id,
            element_id=element_id,
        )
        row = SchemeSectorRecord(
            sector_record_id=uuid4().hex,
            scheme_id=scheme_id,
            scheme_version_id=scheme_version_id,
            element_id=element_id,
            sector_id=sector_id,
            name=name,
            classes_raw=classes_raw,
        )
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return row


async def delete_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
) -> None:
    """Delete one sector record unless seats of the same version reference its sector_id.

    Raises:
        HTTPException: 404 if the record is missing; 409 if seats still
            reference it.
    """
    async with AsyncSessionLocal() as session:
        sector = await _get_sector_or_404(
            session,
            scheme_version_id=scheme_version_id,
            sector_record_id=sector_record_id,
        )
        if sector.sector_id:
            # Only existence matters; fetch at most one referencing seat.
            seat_result = await session.execute(
                select(SchemeSeatRecord)
                .where(
                    SchemeSeatRecord.scheme_version_id == scheme_version_id,
                    SchemeSeatRecord.sector_id == sector.sector_id,
                )
                .limit(1)
            )
            if seat_result.scalars().first() is not None:
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail="Cannot delete sector while seats still reference it",
                )
        await session.delete(sector)
        await session.commit()


async def get_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
) -> SchemeSectorRecord:
    """Return one sector record of a version.

    Raises:
        HTTPException: 404 if the record does not exist.
    """
    async with AsyncSessionLocal() as session:
        return await _get_sector_or_404(
            session,
            scheme_version_id=scheme_version_id,
            sector_record_id=sector_record_id,
        )