354 lines
11 KiB
Python
354 lines
11 KiB
Python
import uuid

from fastapi import HTTPException, status
from sqlalchemy import asc, select

from app.db.session import AsyncSessionLocal
from app.models.scheme_seat import SchemeSeatRecord
|
|
|
|
|
|
async def replace_scheme_version_seats(
    *,
    scheme_id: str,
    scheme_version_id: str,
    seats: list[dict],
) -> None:
    """Replace every seat record of a scheme version with ``seats``.

    All existing ``SchemeSeatRecord`` rows for ``scheme_version_id`` are
    deleted and one new row is inserted per item of ``seats``; both steps
    are committed together at the end.

    Args:
        scheme_id: Scheme the new rows are attached to.
        scheme_version_id: Version whose seats are replaced.
        seats: Seat payload dicts. Keys read: ``seat_record_id`` (reused when
            present and truthy, otherwise a fresh uuid4 hex is generated),
            ``id`` (stored as ``element_id``), ``seat_id``, ``sector_id``,
            ``group_id``, ``row`` (stored as ``row_label``), ``seat_number``,
            ``tag``, ``classes`` (stored via ``str()`` as ``classes_raw``),
            ``x``, ``y``, ``cx``, ``cy``, ``width``, ``height``. Other missing
            keys are stored as ``None``.
    """
    async with AsyncSessionLocal() as session:
        existing_result = await session.execute(
            select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
        )
        for existing_row in existing_result.scalars().all():
            await session.delete(existing_row)

        # Emit the DELETEs now. Within a single flush SQLAlchemy's unit of
        # work issues INSERTs before DELETEs for the same mapper, so a payload
        # that reuses an existing seat_record_id would otherwise be inserted
        # while the old row still exists.
        await session.flush()

        for item in seats:
            session.add(
                SchemeSeatRecord(
                    seat_record_id=item.get("seat_record_id") or uuid.uuid4().hex,
                    scheme_id=scheme_id,
                    scheme_version_id=scheme_version_id,
                    element_id=item.get("id"),
                    seat_id=item.get("seat_id"),
                    sector_id=item.get("sector_id"),
                    group_id=item.get("group_id"),
                    row_label=item.get("row"),
                    seat_number=item.get("seat_number"),
                    tag=item.get("tag"),
                    # NOTE(review): str() stores the literal "None" when
                    # "classes" is absent, and a Python repr for lists;
                    # confirm that consumers of classes_raw expect this.
                    classes_raw=str(item.get("classes")),
                    x=item.get("x"),
                    y=item.get("y"),
                    cx=item.get("cx"),
                    cy=item.get("cy"),
                    width=item.get("width"),
                    height=item.get("height"),
                )
            )

        await session.commit()
|
|
|
|
|
|
async def clone_scheme_version_seats(
    *,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Copy all seats of one scheme version into another and commit.

    Opens its own session, delegates the copying to
    ``clone_scheme_version_seats_in_session`` and commits the staged rows.

    Args:
        source_scheme_version_id: Version whose seats are read.
        target_scheme_version_id: Version the copies are attached to.
    """
    async with AsyncSessionLocal() as db:
        # The in-session variant never commits; committing is our job here.
        await clone_scheme_version_seats_in_session(
            session=db,
            source_scheme_version_id=source_scheme_version_id,
            target_scheme_version_id=target_scheme_version_id,
        )
        await db.commit()
|
|
|
|
|
|
async def clone_scheme_version_seats_in_session(
    *,
    session,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Stage copies of every seat of ``source_scheme_version_id`` in ``session``.

    Each copy receives a fresh uuid4-hex ``seat_record_id`` and
    ``target_scheme_version_id``. ``scheme_id`` and all assignment/geometry
    columns are copied verbatim. The new rows are only added to ``session``;
    this function does NOT flush or commit, so the caller controls the
    transaction.

    Args:
        session: An open async SQLAlchemy session (presumably an
            ``AsyncSession``; it must support ``await execute`` and ``add``).
        source_scheme_version_id: Version whose seats are read.
        target_scheme_version_id: Version the copies are attached to.
    """
    result = await session.execute(
        select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == source_scheme_version_id)
    )

    for source_row in result.scalars().all():
        session.add(
            SchemeSeatRecord(
                seat_record_id=uuid.uuid4().hex,
                scheme_id=source_row.scheme_id,
                scheme_version_id=target_scheme_version_id,
                element_id=source_row.element_id,
                seat_id=source_row.seat_id,
                sector_id=source_row.sector_id,
                group_id=source_row.group_id,
                row_label=source_row.row_label,
                seat_number=source_row.seat_number,
                tag=source_row.tag,
                classes_raw=source_row.classes_raw,
                x=source_row.x,
                y=source_row.y,
                cx=source_row.cx,
                cy=source_row.cy,
                width=source_row.width,
                height=source_row.height,
            )
        )
|
|
|
|
|
|
async def list_scheme_version_seats(scheme_version_id: str) -> list[SchemeSeatRecord]:
    """Return every seat record belonging to ``scheme_version_id``.

    Rows are ordered by ``created_at`` ascending, with ``id`` ascending as a
    tie-breaker, so the ordering is stable between calls.
    """
    ordered_query = (
        select(SchemeSeatRecord)
        .where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
        .order_by(asc(SchemeSeatRecord.created_at), asc(SchemeSeatRecord.id))
    )

    async with AsyncSessionLocal() as session:
        seat_rows = (await session.execute(ordered_query)).scalars().all()

    return list(seat_rows)
|
|
|
|
|
|
async def get_scheme_version_seat_by_seat_id(
    *,
    scheme_version_id: str,
    seat_id: str,
) -> SchemeSeatRecord:
    """Look up a seat of a scheme version by its ``seat_id``.

    Raises:
        HTTPException: 404 when no row matches.

    NOTE(review): ``scalar_one_or_none`` raises ``MultipleResultsFound`` if
    several rows share this seat_id; confirm a DB constraint keeps it unique
    per version.
    """
    lookup = select(SchemeSeatRecord).where(
        SchemeSeatRecord.scheme_version_id == scheme_version_id,
        SchemeSeatRecord.seat_id == seat_id,
    )

    async with AsyncSessionLocal() as session:
        seat = (await session.execute(lookup)).scalar_one_or_none()

    if seat is not None:
        return seat

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Seat not found in current scheme version",
    )
|
|
|
|
|
|
async def get_scheme_version_seat_by_record_id(
    *,
    scheme_version_id: str,
    seat_record_id: str,
) -> SchemeSeatRecord:
    """Look up a seat of a scheme version by its ``seat_record_id``.

    Raises:
        HTTPException: 404 when no row with that record id exists in the
            given version.
    """
    lookup = select(SchemeSeatRecord).where(
        SchemeSeatRecord.scheme_version_id == scheme_version_id,
        SchemeSeatRecord.seat_record_id == seat_record_id,
    )

    async with AsyncSessionLocal() as session:
        seat = (await session.execute(lookup)).scalar_one_or_none()

    if seat is not None:
        return seat

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Seat record not found in current draft version",
    )
|
|
|
|
|
|
async def update_scheme_version_seat_by_record_id(
    *,
    scheme_version_id: str,
    seat_record_id: str,
    **update_data,
) -> SchemeSeatRecord:
    """Update assignment fields of one seat record and return the refreshed row.

    Only ``seat_id``, ``sector_id``, ``group_id``, ``row_label`` and
    ``seat_number`` are applied, and only when present as keys in
    ``update_data`` (an explicit ``None`` clears the column). Any other keys
    are ignored.

    Raises:
        HTTPException: 404 when the record is not part of ``scheme_version_id``.
    """
    editable_fields = ("seat_id", "sector_id", "group_id", "row_label", "seat_number")

    async with AsyncSessionLocal() as session:
        lookup = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id,
                SchemeSeatRecord.seat_record_id == seat_record_id,
            )
        )
        seat = lookup.scalar_one_or_none()

        if seat is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Seat record not found in current draft version",
            )

        for field_name in editable_fields:
            if field_name in update_data:
                setattr(seat, field_name, update_data[field_name])

        await session.commit()
        # Reload after commit so the returned object carries fresh column values.
        await session.refresh(seat)
        return seat
|
|
|
|
|
|
async def bulk_update_scheme_version_seats_by_record_id(
    *,
    scheme_version_id: str,
    items: list[dict],
) -> list[SchemeSeatRecord]:
    """Apply per-seat assignment updates in a single transaction.

    All referenced rows are fetched with one ``IN`` query instead of one
    SELECT per item. Items are then applied in order. For each item only the
    keys ``seat_id``, ``sector_id``, ``group_id``, ``row_label`` and
    ``seat_number`` that are present are written.

    Args:
        scheme_version_id: Version the seat records must belong to.
        items: Update dicts; each must contain ``seat_record_id``.

    Returns:
        The updated rows in ``items`` order, refreshed after commit. Duplicate
        record ids yield the same object more than once.

    Raises:
        HTTPException: 404 for the first item (in ``items`` order) whose
            record id is not in this version; nothing is committed then.
        KeyError: when the first offending item lacks ``seat_record_id``.
    """
    editable_fields = ("seat_id", "sector_id", "group_id", "row_label", "seat_number")
    updated_rows: list[SchemeSeatRecord] = []

    async with AsyncSessionLocal() as session:
        # .get() here so a missing key surfaces as KeyError at the same item
        # the per-item loop below would have reached it.
        requested_ids = {item.get("seat_record_id") for item in items} - {None}

        rows_by_record_id: dict = {}
        if requested_ids:
            result = await session.execute(
                select(SchemeSeatRecord).where(
                    SchemeSeatRecord.scheme_version_id == scheme_version_id,
                    SchemeSeatRecord.seat_record_id.in_(list(requested_ids)),
                )
            )
            rows_by_record_id = {row.seat_record_id: row for row in result.scalars().all()}

        for item in items:
            row = rows_by_record_id.get(item["seat_record_id"])

            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Seat record not found in current draft version: {item['seat_record_id']}",
                )

            for field_name in editable_fields:
                if field_name in item:
                    setattr(row, field_name, item[field_name])

            updated_rows.append(row)

        await session.commit()

        for row in updated_rows:
            await session.refresh(row)

    return updated_rows
|
|
|
|
|
|
async def bulk_remap_scheme_version_seats(
    *,
    scheme_version_id: str,
    items: list[dict],
) -> None:
    """Move seats to new sector/group ids in a single transaction.

    All referenced rows are fetched with one ``IN`` query instead of one
    SELECT per item. Then, in ``items`` order, each row's ``sector_id`` and
    ``group_id`` are overwritten with the item's ``after_sector_id`` and
    ``after_group_id``.

    Args:
        scheme_version_id: Version the seat records must belong to.
        items: Dicts with ``seat_record_id``, ``after_sector_id`` and
            ``after_group_id`` keys.

    Raises:
        HTTPException: 404 for the first item whose record id is not in this
            version; nothing is committed then.
        KeyError: when an item lacks one of the required keys.
    """
    async with AsyncSessionLocal() as session:
        # .get() so a missing key surfaces as KeyError in the ordered loop below.
        requested_ids = {item.get("seat_record_id") for item in items} - {None}

        rows_by_record_id: dict = {}
        if requested_ids:
            result = await session.execute(
                select(SchemeSeatRecord).where(
                    SchemeSeatRecord.scheme_version_id == scheme_version_id,
                    SchemeSeatRecord.seat_record_id.in_(list(requested_ids)),
                )
            )
            rows_by_record_id = {row.seat_record_id: row for row in result.scalars().all()}

        for item in items:
            row = rows_by_record_id.get(item["seat_record_id"])
            if row is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Seat record not found in current draft version: {item['seat_record_id']}",
                )

            row.sector_id = item["after_sector_id"]
            row.group_id = item["after_group_id"]

        await session.commit()
|
|
|
|
|
|
async def cascade_update_seat_sector_reference(
    *,
    scheme_version_id: str,
    old_sector_id: str | None,
    new_sector_id: str | None,
) -> int:
    """Re-point all seats of a version from ``old_sector_id`` to ``new_sector_id``.

    Returns:
        Number of seat rows changed. Returns 0 without touching the database
        when ``old_sector_id`` is falsy or equals ``new_sector_id``.
    """
    nothing_to_do = (not old_sector_id) or old_sector_id == new_sector_id
    if nothing_to_do:
        return 0

    matching = select(SchemeSeatRecord).where(
        SchemeSeatRecord.scheme_version_id == scheme_version_id,
        SchemeSeatRecord.sector_id == old_sector_id,
    )

    async with AsyncSessionLocal() as session:
        affected_seats = list((await session.execute(matching)).scalars().all())
        for seat in affected_seats:
            seat.sector_id = new_sector_id
        await session.commit()

    return len(affected_seats)
|
|
|
|
|
|
async def cascade_update_seat_group_reference(
    *,
    scheme_version_id: str,
    old_group_id: str | None,
    new_group_id: str | None,
) -> int:
    """Re-point all seats of a version from ``old_group_id`` to ``new_group_id``.

    Returns:
        Number of seat rows changed. Returns 0 without touching the database
        when ``old_group_id`` is falsy or equals ``new_group_id``.
    """
    nothing_to_do = (not old_group_id) or old_group_id == new_group_id
    if nothing_to_do:
        return 0

    matching = select(SchemeSeatRecord).where(
        SchemeSeatRecord.scheme_version_id == scheme_version_id,
        SchemeSeatRecord.group_id == old_group_id,
    )

    async with AsyncSessionLocal() as session:
        affected_seats = list((await session.execute(matching)).scalars().all())
        for seat in affected_seats:
            seat.group_id = new_group_id
        await session.commit()

    return len(affected_seats)
|
|
|
|
|
|
async def repair_orphan_sector_refs(
    *,
    scheme_version_id: str,
    new_sector_id: str,
    orphan_values: list[str],
) -> int:
    """Reassign seats whose ``sector_id`` is one of ``orphan_values``.

    Every matching seat of ``scheme_version_id`` gets ``new_sector_id``; the
    change is committed in one transaction.

    Returns:
        Number of seat rows changed (0, with no DB access, when
        ``orphan_values`` is empty).
    """
    if not orphan_values:
        return 0

    # Set for O(1) membership instead of scanning the list for every row.
    # Matching stays in Python (not SQL IN) to keep exact string equality and
    # to allow None in orphan_values to match NULL sector ids.
    orphan_set = set(orphan_values)

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id
            )
        )

        changed = 0
        for row in result.scalars().all():
            if row.sector_id in orphan_set:
                row.sector_id = new_sector_id
                changed += 1

        await session.commit()

    return changed
|
|
|
|
|
|
async def repair_orphan_group_refs(
    *,
    scheme_version_id: str,
    new_group_id: str,
    orphan_values: list[str],
) -> int:
    """Reassign seats whose ``group_id`` is one of ``orphan_values``.

    Every matching seat of ``scheme_version_id`` gets ``new_group_id``; the
    change is committed in one transaction.

    Returns:
        Number of seat rows changed (0, with no DB access, when
        ``orphan_values`` is empty).
    """
    if not orphan_values:
        return 0

    # Set for O(1) membership instead of scanning the list for every row.
    # Matching stays in Python (not SQL IN) to keep exact string equality and
    # to allow None in orphan_values to match NULL group ids.
    orphan_set = set(orphan_values)

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeSeatRecord).where(
                SchemeSeatRecord.scheme_version_id == scheme_version_id
            )
        )

        changed = 0
        for row in result.scalars().all():
            if row.group_id in orphan_set:
                row.group_id = new_group_id
                changed += 1

        await session.commit()

    return changed
|