"""Data-access helpers for ``SchemeSeatRecord`` rows scoped to a scheme version."""

import uuid

from fastapi import HTTPException, status
from sqlalchemy import asc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.session import AsyncSessionLocal
from app.models.scheme_seat import SchemeSeatRecord

# Seat attributes the single/bulk "update by record id" operations may change,
# in the order they are applied.
_EDITABLE_SEAT_FIELDS = ("seat_id", "sector_id", "group_id", "row_label", "seat_number")

# Columns copied verbatim when one version's seats are cloned into another.
_CLONED_SEAT_FIELDS = (
    "element_id",
    "seat_id",
    "sector_id",
    "group_id",
    "row_label",
    "seat_number",
    "tag",
    "classes_raw",
    "x",
    "y",
    "cx",
    "cy",
    "width",
    "height",
)

_RECORD_NOT_FOUND_DETAIL = "Seat record not found in current draft version"


def _new_seat_record_id() -> str:
    """Return a fresh random hex identifier for a seat record."""
    return uuid.uuid4().hex


async def _fetch_seat_by_record_id(
    session: AsyncSession,
    *,
    scheme_version_id: str,
    seat_record_id: str,
    not_found_detail: str,
) -> SchemeSeatRecord:
    """Load one seat of ``scheme_version_id`` by ``seat_record_id`` inside ``session``.

    Raises:
        HTTPException: 404 with ``not_found_detail`` when no such row exists.
    """
    result = await session.execute(
        select(SchemeSeatRecord).where(
            SchemeSeatRecord.scheme_version_id == scheme_version_id,
            SchemeSeatRecord.seat_record_id == seat_record_id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=not_found_detail,
        )
    return row


def _apply_seat_updates(row: SchemeSeatRecord, data: dict) -> None:
    """Copy every editable seat field present in ``data`` onto ``row``.

    Keys absent from ``data`` are left untouched; other keys are ignored.
    """
    for field in _EDITABLE_SEAT_FIELDS:
        if field in data:
            setattr(row, field, data[field])


async def replace_scheme_version_seats(
    *,
    scheme_id: str,
    scheme_version_id: str,
    seats: list[dict],
) -> None:
    """Delete all seats of ``scheme_version_id`` and insert ``seats`` in their place.

    Each item in ``seats`` is a dict. A truthy ``"seat_record_id"`` is reused;
    otherwise a new id is generated. Keys are mapped onto columns as follows
    (``"id"`` -> ``element_id``, ``"row"`` -> ``row_label``,
    ``"classes"`` -> ``classes_raw``, remaining keys by name).
    Everything is committed in one transaction.
    """
    async with AsyncSessionLocal() as session:
        existing_result = await session.execute(
            select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
        )
        for existing in existing_result.scalars().all():
            await session.delete(existing)
        # SQLAlchemy's unit of work emits INSERTs before DELETEs for the same
        # mapper within one flush. Flushing the deletes first keeps re-submitted
        # seat_record_id values from colliding with the rows being removed.
        await session.flush()

        for item in seats:
            session.add(
                SchemeSeatRecord(
                    seat_record_id=item.get("seat_record_id") or _new_seat_record_id(),
                    scheme_id=scheme_id,
                    scheme_version_id=scheme_version_id,
                    element_id=item.get("id"),
                    seat_id=item.get("seat_id"),
                    sector_id=item.get("sector_id"),
                    group_id=item.get("group_id"),
                    row_label=item.get("row"),
                    seat_number=item.get("seat_number"),
                    tag=item.get("tag"),
                    # NOTE(review): str() stores a missing "classes" as the
                    # literal "None" and a list as its Python repr. Kept as-is
                    # because column nullability/consumers aren't visible here.
                    classes_raw=str(item.get("classes")),
                    x=item.get("x"),
                    y=item.get("y"),
                    cx=item.get("cx"),
                    cy=item.get("cy"),
                    width=item.get("width"),
                    height=item.get("height"),
                )
            )
        await session.commit()


async def clone_scheme_version_seats(
    *,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Copy all seats of the source version into the target version and commit."""
    async with AsyncSessionLocal() as session:
        await clone_scheme_version_seats_in_session(
            session=session,
            source_scheme_version_id=source_scheme_version_id,
            target_scheme_version_id=target_scheme_version_id,
        )
        await session.commit()


async def clone_scheme_version_seats_in_session(
    *,
    session: AsyncSession,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Stage copies of the source version's seats under the target version.

    Each copy gets a fresh ``seat_record_id`` but keeps ``scheme_id`` and all
    seat columns. Nothing is committed; the caller owns the transaction.
    """
    result = await session.execute(
        select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == source_scheme_version_id)
    )
    for source_row in result.scalars().all():
        session.add(
            SchemeSeatRecord(
                seat_record_id=_new_seat_record_id(),
                scheme_id=source_row.scheme_id,
                scheme_version_id=target_scheme_version_id,
                **{field: getattr(source_row, field) for field in _CLONED_SEAT_FIELDS},
            )
        )


async def list_scheme_version_seats(scheme_version_id: str) -> list[SchemeSeatRecord]:
    """Return all seats of a version ordered by ``created_at`` then ``id``."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord)
            .where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
            .order_by(asc(SchemeSeatRecord.created_at), asc(SchemeSeatRecord.id))
        )
        return list(result.scalars().all())


async def get_scheme_version_seat_by_seat_id(
    *,
    scheme_version_id: str,
    seat_id: str,
) -> SchemeSeatRecord:
    """Return the seat of a version with the given ``seat_id``.

    Raises:
        HTTPException: 404 when no seat matches.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id,
                SchemeSeatRecord.seat_id == seat_id,
            )
        )
        # NOTE(review): scalar_one_or_none raises if several seats in the
        # version share this seat_id; confirm uniqueness is enforced upstream.
        row = result.scalar_one_or_none()
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Seat not found in current scheme version",
            )
        return row


async def get_scheme_version_seat_by_record_id(
    *,
    scheme_version_id: str,
    seat_record_id: str,
) -> SchemeSeatRecord:
    """Return the seat of a version with the given ``seat_record_id``.

    Raises:
        HTTPException: 404 when no seat matches.
    """
    async with AsyncSessionLocal() as session:
        return await _fetch_seat_by_record_id(
            session,
            scheme_version_id=scheme_version_id,
            seat_record_id=seat_record_id,
            not_found_detail=_RECORD_NOT_FOUND_DETAIL,
        )


async def update_scheme_version_seat_by_record_id(
    *,
    scheme_version_id: str,
    seat_record_id: str,
    **update_data,
) -> SchemeSeatRecord:
    """Update the editable fields given in ``update_data`` on one seat and commit.

    Only ``seat_id``, ``sector_id``, ``group_id``, ``row_label`` and
    ``seat_number`` are applied; other keyword arguments are ignored.

    Returns:
        The refreshed row.

    Raises:
        HTTPException: 404 when the seat record does not exist in the version.
    """
    async with AsyncSessionLocal() as session:
        row = await _fetch_seat_by_record_id(
            session,
            scheme_version_id=scheme_version_id,
            seat_record_id=seat_record_id,
            not_found_detail=_RECORD_NOT_FOUND_DETAIL,
        )
        _apply_seat_updates(row, update_data)
        await session.commit()
        await session.refresh(row)
        return row


async def bulk_update_scheme_version_seats_by_record_id(
    *,
    scheme_version_id: str,
    items: list[dict],
) -> list[SchemeSeatRecord]:
    """Apply per-seat field updates in a single transaction.

    Each item must contain ``"seat_record_id"``. Any editable fields present
    are applied to that seat. If any record is missing, a 404 is raised before
    commit, so no update is persisted.

    Returns:
        The refreshed rows in item order.
    """
    updated_rows: list[SchemeSeatRecord] = []
    async with AsyncSessionLocal() as session:
        for item in items:
            row = await _fetch_seat_by_record_id(
                session,
                scheme_version_id=scheme_version_id,
                seat_record_id=item["seat_record_id"],
                not_found_detail=f"{_RECORD_NOT_FOUND_DETAIL}: {item['seat_record_id']}",
            )
            _apply_seat_updates(row, item)
            updated_rows.append(row)
        await session.commit()
        for row in updated_rows:
            await session.refresh(row)
    return updated_rows


async def bulk_remap_scheme_version_seats(
    *,
    scheme_version_id: str,
    items: list[dict],
) -> None:
    """Reassign sector/group for many seats in one transaction.

    Each item needs ``"seat_record_id"``, ``"after_sector_id"`` and
    ``"after_group_id"``. A missing record raises 404 before commit.
    """
    async with AsyncSessionLocal() as session:
        for item in items:
            row = await _fetch_seat_by_record_id(
                session,
                scheme_version_id=scheme_version_id,
                seat_record_id=item["seat_record_id"],
                not_found_detail=f"{_RECORD_NOT_FOUND_DETAIL}: {item['seat_record_id']}",
            )
            row.sector_id = item["after_sector_id"]
            row.group_id = item["after_group_id"]
        await session.commit()


async def _cascade_reference(
    *,
    scheme_version_id: str,
    field: str,
    old_value: str | None,
    new_value: str | None,
) -> int:
    """Rewrite ``field`` from ``old_value`` to ``new_value`` on a version's seats.

    Returns the number of rows changed. This is 0 without touching the
    database when ``old_value`` is falsy or equal to ``new_value``.
    """
    if not old_value or old_value == new_value:
        return 0
    column = getattr(SchemeSeatRecord, field)
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id,
                column == old_value,
            )
        )
        rows = list(result.scalars().all())
        for row in rows:
            setattr(row, field, new_value)
        await session.commit()
    return len(rows)


async def cascade_update_seat_sector_reference(
    *,
    scheme_version_id: str,
    old_sector_id: str | None,
    new_sector_id: str | None,
) -> int:
    """Repoint seats of a version from ``old_sector_id`` to ``new_sector_id``.

    Returns the number of seats updated.
    """
    return await _cascade_reference(
        scheme_version_id=scheme_version_id,
        field="sector_id",
        old_value=old_sector_id,
        new_value=new_sector_id,
    )


async def cascade_update_seat_group_reference(
    *,
    scheme_version_id: str,
    old_group_id: str | None,
    new_group_id: str | None,
) -> int:
    """Repoint seats of a version from ``old_group_id`` to ``new_group_id``.

    Returns the number of seats updated.
    """
    return await _cascade_reference(
        scheme_version_id=scheme_version_id,
        field="group_id",
        old_value=old_group_id,
        new_value=new_group_id,
    )


async def _repair_orphan_refs(
    *,
    scheme_version_id: str,
    field: str,
    new_value: str,
    orphan_values: list[str],
) -> int:
    """Set ``field`` to ``new_value`` on every seat whose value is in ``orphan_values``.

    Returns the number of rows changed.
    """
    orphans = set(orphan_values)
    if not orphans:
        return 0
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id
            )
        )
        # Filtering stays in Python: it preserves exact string equality
        # (independent of DB collation) and matches None orphans, which a SQL
        # IN (...) would not.
        changed = 0
        for row in result.scalars().all():
            if getattr(row, field) in orphans:
                setattr(row, field, new_value)
                changed += 1
        await session.commit()
    return changed


async def repair_orphan_sector_refs(
    *,
    scheme_version_id: str,
    new_sector_id: str,
    orphan_values: list[str],
) -> int:
    """Move seats whose ``sector_id`` is in ``orphan_values`` to ``new_sector_id``.

    Returns the number of seats changed.
    """
    return await _repair_orphan_refs(
        scheme_version_id=scheme_version_id,
        field="sector_id",
        new_value=new_sector_id,
        orphan_values=orphan_values,
    )


async def repair_orphan_group_refs(
    *,
    scheme_version_id: str,
    new_group_id: str,
    orphan_values: list[str],
) -> int:
    """Move seats whose ``group_id`` is in ``orphan_values`` to ``new_group_id``.

    Returns the number of seats changed.
    """
    return await _repair_orphan_refs(
        scheme_version_id=scheme_version_id,
        field="group_id",
        new_value=new_group_id,
        orphan_values=orphan_values,
    )