Files
svg-backend/backend/app/repositories/scheme_sectors.py

293 lines
9.6 KiB
Python

from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import asc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_sector import SchemeSectorRecord
from app.models.scheme_seat import SchemeSeatRecord
def _conflict(message: str) -> HTTPException:
    """Build the HTTP error used to report a sector uniqueness violation.

    The exception is returned rather than raised, so call sites write
    ``raise _conflict(...)``.

    Args:
        message: Human-readable description placed in the response detail.

    Returns:
        An ``HTTPException`` with status 422 whose ``detail`` is a dict
        holding the ``sector_uniqueness_violation`` code and ``message``.
    """
    detail = {
        "code": "sector_uniqueness_violation",
        "message": message,
    }
    return HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=detail,
    )
async def _ensure_sector_uniqueness(
    *,
    session,
    scheme_version_id: str,
    sector_id: str | None,
    element_id: str | None,
    exclude_sector_record_id: str | None = None,
) -> None:
    """Reject a sector whose ``sector_id`` or ``element_id`` is already taken.

    Each non-empty key is looked up among the sector records of
    ``scheme_version_id``. ``sector_id`` is checked first, then
    ``element_id``; the first clash found is reported.

    Args:
        session: Async session used to run the lookup queries.
        scheme_version_id: Version whose sector records are searched.
        sector_id: Candidate sector id; skipped when empty or ``None``.
        element_id: Candidate element id; skipped when empty or ``None``.
        exclude_sector_record_id: Record to ignore in the lookup, so a row
            being updated does not clash with itself.

    Raises:
        HTTPException: 422 ``sector_uniqueness_violation`` when another
            record in the version already uses one of the keys.
    """
    checks = (
        ("sector_id", SchemeSectorRecord.sector_id, sector_id),
        ("element_id", SchemeSectorRecord.element_id, element_id),
    )
    for label, column, value in checks:
        if not value:
            continue
        stmt = select(SchemeSectorRecord).where(
            SchemeSectorRecord.scheme_version_id == scheme_version_id,
            column == value,
        )
        if exclude_sector_record_id:
            stmt = stmt.where(
                SchemeSectorRecord.sector_record_id != exclude_sector_record_id
            )
        # One match is enough to prove a clash. Using limit(1) + first()
        # rather than scalar_one_or_none() keeps this a clean 422 even when
        # the version already holds several rows with the same key;
        # scalar_one_or_none() would raise MultipleResultsFound there.
        existing = (await session.execute(stmt.limit(1))).scalars().first()
        if existing is not None:
            raise _conflict(
                f"Sector with {label}='{value}' already exists in current draft version"
            )
async def replace_scheme_version_sectors(
    *,
    scheme_id: str,
    scheme_version_id: str,
    sectors: list[dict],
) -> None:
    """Replace all sector records of a scheme version with ``sectors``.

    The payload is validated for duplicate keys before any database work.
    Existing rows of the version are then deleted, the deletes are flushed,
    and the replacement rows are inserted and committed in one transaction.

    Args:
        scheme_id: Scheme id stored on every new record.
        scheme_version_id: Version whose sector records are replaced.
        sectors: Payload items. The keys read are ``sector_id``, ``id``
            (stored as ``element_id``), ``classes`` and the optional
            ``sector_record_id``.

    Raises:
        HTTPException: 422 ``sector_uniqueness_violation`` when the payload
            repeats a non-empty ``sector_id`` or ``id``.
    """
    seen_sector_ids: set[str] = set()
    seen_element_ids: set[str] = set()
    new_rows: list[SchemeSectorRecord] = []

    for item in sectors:
        sector_id = item.get("sector_id")
        element_id = item.get("id")

        if sector_id:
            if sector_id in seen_sector_ids:
                raise _conflict(f"Duplicate sector_id='{sector_id}' in replacement payload")
            seen_sector_ids.add(sector_id)

        if element_id:
            if element_id in seen_element_ids:
                raise _conflict(f"Duplicate element_id='{element_id}' in replacement payload")
            seen_element_ids.add(element_id)

        new_rows.append(
            SchemeSectorRecord(
                # Reuse the caller-supplied record id when present and
                # non-empty; otherwise mint a fresh one.
                sector_record_id=item.get("sector_record_id") or uuid4().hex,
                scheme_id=scheme_id,
                scheme_version_id=scheme_version_id,
                element_id=element_id,
                sector_id=sector_id,
                # The display name is seeded from the sector id.
                name=item.get("sector_id"),
                # NOTE(review): a missing "classes" key is stored as the
                # literal string "None" -- confirm that is intended.
                classes_raw=str(item.get("classes")),
            )
        )

    async with AsyncSessionLocal() as session:
        existing_result = await session.execute(
            select(SchemeSectorRecord).where(
                SchemeSectorRecord.scheme_version_id == scheme_version_id
            )
        )
        for row in existing_result.scalars().all():
            await session.delete(row)

        # Emit the DELETEs now. Within a single flush SQLAlchemy issues
        # INSERTs before DELETEs, so replacement rows that reuse keys of the
        # rows being removed could otherwise collide with them if the table
        # enforces uniqueness on those columns.
        await session.flush()

        session.add_all(new_rows)
        await session.commit()
async def clone_scheme_version_sectors(
    *,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Copy the sectors of one scheme version into another and commit.

    Opens its own session, delegates the copy to
    ``clone_scheme_version_sectors_in_session`` and commits the result.

    Args:
        source_scheme_version_id: Version whose sector records are read.
        target_scheme_version_id: Version the copies are attached to.
    """
    async with AsyncSessionLocal() as db:
        await clone_scheme_version_sectors_in_session(
            session=db,
            source_scheme_version_id=source_scheme_version_id,
            target_scheme_version_id=target_scheme_version_id,
        )
        await db.commit()
async def clone_scheme_version_sectors_in_session(
    *,
    session,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Add copies of a version's sector records to ``session``.

    Every source record is duplicated with a fresh ``sector_record_id`` and
    the target version id. The copies are only added to the session;
    committing is left to the caller.

    Args:
        session: Async session used to read the source rows and to which
            the copies are added.
        source_scheme_version_id: Version whose sector records are read.
        target_scheme_version_id: Version id stamped on each copy.

    Raises:
        HTTPException: 422 ``sector_uniqueness_violation`` when the source
            rows repeat a non-empty ``sector_id`` or ``element_id``.
    """
    query = select(SchemeSectorRecord).where(
        SchemeSectorRecord.scheme_version_id == source_scheme_version_id
    )
    source_rows = list((await session.execute(query)).scalars().all())

    used_sector_ids: set[str] = set()
    used_element_ids: set[str] = set()

    for source in source_rows:
        sector_key = source.sector_id
        if sector_key:
            if sector_key in used_sector_ids:
                raise _conflict(f"Duplicate sector_id='{sector_key}' while cloning draft")
            used_sector_ids.add(sector_key)

        element_key = source.element_id
        if element_key:
            if element_key in used_element_ids:
                raise _conflict(f"Duplicate element_id='{element_key}' while cloning draft")
            used_element_ids.add(element_key)

        # The copy gets its own record id; all other fields carry over.
        session.add(
            SchemeSectorRecord(
                sector_record_id=uuid4().hex,
                scheme_id=source.scheme_id,
                scheme_version_id=target_scheme_version_id,
                element_id=source.element_id,
                sector_id=source.sector_id,
                name=source.name,
                classes_raw=source.classes_raw,
            )
        )
async def list_scheme_version_sectors(scheme_version_id: str) -> list[SchemeSectorRecord]:
    """Return the sector records of a scheme version in creation order.

    Args:
        scheme_version_id: Version whose sector records are listed.

    Returns:
        The matching records, sorted ascending by ``created_at`` and then
        by ``id``.
    """
    query = (
        select(SchemeSectorRecord)
        .where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
        .order_by(asc(SchemeSectorRecord.created_at), asc(SchemeSectorRecord.id))
    )
    async with AsyncSessionLocal() as db:
        rows = (await db.execute(query)).scalars().all()
        return list(rows)
async def update_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
    **update_data,
) -> tuple[SchemeSectorRecord, str | None]:
    """Update ``sector_id`` and/or ``name`` of one sector record.

    Only the keys actually present in ``update_data`` are applied; other
    keys are ignored. A new ``sector_id`` is checked for uniqueness within
    the version before it is written.

    Args:
        scheme_version_id: Version the record must belong to.
        sector_record_id: Identifier of the record to update.
        **update_data: Optional ``sector_id`` and ``name`` values.

    Returns:
        A tuple of the refreshed record and the ``sector_id`` it had before
        the update.

    Raises:
        HTTPException: 404 when the record is not found in the version;
            422 ``sector_uniqueness_violation`` when the new ``sector_id``
            (or the record's ``element_id``) clashes with another record.
    """
    lookup = select(SchemeSectorRecord).where(
        SchemeSectorRecord.scheme_version_id == scheme_version_id,
        SchemeSectorRecord.sector_record_id == sector_record_id,
    )
    async with AsyncSessionLocal() as db:
        record = (await db.execute(lookup)).scalar_one_or_none()
        if record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Sector record not found in current draft version",
            )

        # Remember the previous value so it can be returned to the caller.
        previous_sector_id = record.sector_id

        if "sector_id" in update_data:
            new_sector_id = update_data["sector_id"]
            await _ensure_sector_uniqueness(
                session=db,
                scheme_version_id=scheme_version_id,
                sector_id=new_sector_id,
                element_id=record.element_id,
                exclude_sector_record_id=sector_record_id,
            )
            record.sector_id = new_sector_id

        if "name" in update_data:
            record.name = update_data["name"]

        await db.commit()
        await db.refresh(record)
        return record, previous_sector_id
async def create_scheme_version_sector(
    *,
    scheme_id: str,
    scheme_version_id: str,
    element_id: str | None,
    sector_id: str,
    name: str | None,
    classes_raw: str | None,
) -> SchemeSectorRecord:
    """Insert a new sector record into a scheme version.

    The ``sector_id`` / ``element_id`` pair is checked for uniqueness within
    the version first. The record is then stored under a newly generated
    ``sector_record_id``.

    Args:
        scheme_id: Scheme id stored on the record.
        scheme_version_id: Version the record is created in.
        element_id: Optional element id stored on the record.
        sector_id: Sector id stored on the record.
        name: Optional name stored on the record.
        classes_raw: Optional raw classes value stored on the record.

    Returns:
        The committed and refreshed record.

    Raises:
        HTTPException: 422 ``sector_uniqueness_violation`` when the
            ``sector_id`` or ``element_id`` is already used in the version.
    """
    async with AsyncSessionLocal() as db:
        await _ensure_sector_uniqueness(
            session=db,
            scheme_version_id=scheme_version_id,
            sector_id=sector_id,
            element_id=element_id,
        )

        fields = {
            "sector_record_id": uuid4().hex,
            "scheme_id": scheme_id,
            "scheme_version_id": scheme_version_id,
            "element_id": element_id,
            "sector_id": sector_id,
            "name": name,
            "classes_raw": classes_raw,
        }
        record = SchemeSectorRecord(**fields)

        db.add(record)
        await db.commit()
        await db.refresh(record)
        return record
async def delete_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
) -> None:
    """Delete one sector record unless seats still reference its sector id.

    Args:
        scheme_version_id: Version the record must belong to.
        sector_record_id: Identifier of the record to delete.

    Raises:
        HTTPException: 404 when the record is not found in the version;
            409 when the record has a ``sector_id`` and at least one seat
            in the same version carries that ``sector_id``.
    """
    async with AsyncSessionLocal() as session:
        sector_result = await session.execute(
            select(SchemeSectorRecord).where(
                SchemeSectorRecord.scheme_version_id == scheme_version_id,
                SchemeSectorRecord.sector_record_id == sector_record_id,
            )
        )
        sector = sector_result.scalar_one_or_none()
        if sector is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Sector record not found in current draft version",
            )

        # A sector without a sector_id cannot be referenced by seats, so the
        # reference check only runs when one is set.
        if sector.sector_id:
            # Only existence matters here: fetch at most one seat instead of
            # loading every seat that references the sector.
            seat_probe = (
                select(SchemeSeatRecord)
                .where(
                    SchemeSeatRecord.scheme_version_id == scheme_version_id,
                    SchemeSeatRecord.sector_id == sector.sector_id,
                )
                .limit(1)
            )
            referencing_seat = (await session.execute(seat_probe)).scalars().first()
            if referencing_seat is not None:
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail="Cannot delete sector while seats still reference it",
                )

        await session.delete(sector)
        await session.commit()
async def get_scheme_version_sector_by_record_id(
    *,
    scheme_version_id: str,
    sector_record_id: str,
) -> SchemeSectorRecord:
    """Fetch one sector record of a scheme version by its record id.

    Args:
        scheme_version_id: Version the record must belong to.
        sector_record_id: Identifier of the record to fetch.

    Returns:
        The matching sector record.

    Raises:
        HTTPException: 404 when no such record exists in the version.
    """
    lookup = select(SchemeSectorRecord).where(
        SchemeSectorRecord.scheme_version_id == scheme_version_id,
        SchemeSectorRecord.sector_record_id == sector_record_id,
    )
    async with AsyncSessionLocal() as db:
        record = (await db.execute(lookup)).scalar_one_or_none()
        if record is not None:
            return record
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Sector record not found in current draft version",
        )