293 lines
9.5 KiB
Python
293 lines
9.5 KiB
Python
from uuid import uuid4
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import asc, select
|
|
|
|
from app.db.session import AsyncSessionLocal
|
|
from app.models.scheme_group import SchemeGroupRecord
|
|
from app.models.scheme_seat import SchemeSeatRecord
|
|
|
|
|
|
def _conflict(message: str) -> HTTPException:
    """Build (but do not raise) the error used for group uniqueness violations.

    Args:
        message: Human-readable explanation placed in the error detail.

    Returns:
        An ``HTTPException`` with status 422 whose ``detail`` is a dict holding
        the machine-readable ``"code"`` ``"group_uniqueness_violation"`` and the
        given ``"message"``. Callers are expected to ``raise`` the result.
    """
    # NOTE(review): despite the helper's name this reports 422, not 409.
    payload = {"code": "group_uniqueness_violation", "message": message}
    return HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=payload)
|
|
|
|
|
|
async def _group_value_in_use(
    session,
    *,
    scheme_version_id: str,
    column,
    value: str,
    exclude_group_record_id: str | None,
) -> bool:
    """Return True if some group of ``scheme_version_id`` has ``column == value``.

    The record identified by ``exclude_group_record_id`` (when truthy) is ignored.
    """
    stmt = select(SchemeGroupRecord.group_record_id).where(
        SchemeGroupRecord.scheme_version_id == scheme_version_id,
        column == value,
    )
    if exclude_group_record_id:
        stmt = stmt.where(SchemeGroupRecord.group_record_id != exclude_group_record_id)
    # Only existence matters. LIMIT 1 also prevents MultipleResultsFound (and a
    # 500 response) if duplicate rows already exist for this version.
    result = await session.execute(stmt.limit(1))
    return result.first() is not None


async def _ensure_group_uniqueness(
    *,
    session,
    scheme_version_id: str,
    group_id: str | None,
    element_id: str | None,
    exclude_group_record_id: str | None = None,
) -> None:
    """Reject a ``group_id`` / ``element_id`` already used by another group of the version.

    Args:
        session: Open async SQLAlchemy session used for the lookups.
        scheme_version_id: Version whose groups are checked.
        group_id: Candidate group id. Empty or ``None`` skips this check.
        element_id: Candidate element id. Empty or ``None`` skips this check.
        exclude_group_record_id: Record to ignore, e.g. the record being updated.

    Raises:
        HTTPException: 422 built by ``_conflict``. The ``group_id`` check runs
            first, so a ``group_id`` clash is reported before an ``element_id`` one.
    """
    if group_id and await _group_value_in_use(
        session,
        scheme_version_id=scheme_version_id,
        column=SchemeGroupRecord.group_id,
        value=group_id,
        exclude_group_record_id=exclude_group_record_id,
    ):
        raise _conflict(f"Group with group_id='{group_id}' already exists in current draft version")

    if element_id and await _group_value_in_use(
        session,
        scheme_version_id=scheme_version_id,
        column=SchemeGroupRecord.element_id,
        value=element_id,
        exclude_group_record_id=exclude_group_record_id,
    ):
        raise _conflict(f"Group with element_id='{element_id}' already exists in current draft version")
|
|
|
|
|
|
def _check_replacement_payload(groups: list[dict]) -> None:
    """Raise ``_conflict`` if ``groups`` repeats a non-empty ``group_id`` or ``id``."""
    seen_group_ids: set[str] = set()
    seen_element_ids: set[str] = set()

    for item in groups:
        group_id = item.get("group_id")
        element_id = item.get("id")

        if group_id:
            if group_id in seen_group_ids:
                raise _conflict(f"Duplicate group_id='{group_id}' in replacement payload")
            seen_group_ids.add(group_id)

        if element_id:
            if element_id in seen_element_ids:
                raise _conflict(f"Duplicate element_id='{element_id}' in replacement payload")
            seen_element_ids.add(element_id)


def _group_row_from_payload(item: dict, *, scheme_id: str, scheme_version_id: str) -> SchemeGroupRecord:
    """Map one replacement-payload item onto a new, unsaved ``SchemeGroupRecord``."""
    group_id = item.get("group_id")
    classes = item.get("classes")
    return SchemeGroupRecord(
        # Keep a caller-supplied record id when truthy, otherwise mint a new one.
        group_record_id=item.get("group_record_id") or uuid4().hex,
        scheme_id=scheme_id,
        scheme_version_id=scheme_version_id,
        element_id=item.get("id"),
        group_id=group_id,
        # NOTE(review): name is filled from group_id, not from a "name" key — confirm intended.
        name=group_id,
        # A missing "classes" value is stored as NULL rather than the literal string "None".
        classes_raw=None if classes is None else str(classes),
    )


async def replace_scheme_version_groups(
    *,
    scheme_id: str,
    scheme_version_id: str,
    groups: list[dict],
) -> None:
    """Replace every group of a scheme version with the rows described by ``groups``.

    Existing groups of ``scheme_version_id`` are deleted, then one new
    ``SchemeGroupRecord`` is inserted per payload item, all committed together.

    Args:
        scheme_id: Scheme the new rows belong to.
        scheme_version_id: Version whose groups are replaced.
        groups: Payload items. Keys read are ``"group_id"``, ``"id"`` (stored as
            ``element_id``), ``"classes"`` and the optional ``"group_record_id"``.

    Raises:
        HTTPException: 422 (via ``_conflict``) when the payload repeats a
            non-empty ``group_id`` or element id. Nothing is changed in that case.
    """
    # Validate the whole payload before touching the database.
    _check_replacement_payload(groups)

    async with AsyncSessionLocal() as session:
        existing_result = await session.execute(
            select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == scheme_version_id)
        )
        for row in existing_result.scalars().all():
            await session.delete(row)

        # Within a single flush SQLAlchemy emits INSERTs before DELETEs for a
        # mapper. Flushing the deletes first keeps re-used group_record_id /
        # group_id values from colliding with the rows being replaced.
        await session.flush()

        for item in groups:
            session.add(
                _group_row_from_payload(item, scheme_id=scheme_id, scheme_version_id=scheme_version_id)
            )

        await session.commit()
|
|
|
|
|
|
async def clone_scheme_version_groups(
    *,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Copy every group of one scheme version into another and commit.

    Opens a dedicated session, delegates the copying to
    ``clone_scheme_version_groups_in_session`` and commits the result.
    Any error raised by the delegate propagates without a commit.
    """
    async with AsyncSessionLocal() as db:
        await clone_scheme_version_groups_in_session(
            session=db,
            source_scheme_version_id=source_scheme_version_id,
            target_scheme_version_id=target_scheme_version_id,
        )
        await db.commit()
|
|
|
|
|
|
async def clone_scheme_version_groups_in_session(
    *,
    session,
    source_scheme_version_id: str,
    target_scheme_version_id: str,
) -> None:
    """Stage copies of all groups of the source version under the target version.

    New records get fresh ``group_record_id`` values. Every other copied field
    (``scheme_id``, ``element_id``, ``group_id``, ``name``, ``classes_raw``)
    is taken from the source row. The caller owns the transaction: nothing is
    flushed or committed here.

    Raises:
        HTTPException: 422 (via ``_conflict``) if the source version contains a
            repeated non-empty ``group_id`` or ``element_id``.
    """
    query = select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == source_scheme_version_id)
    source_rows = (await session.execute(query)).scalars().all()

    group_ids_seen: set[str] = set()
    element_ids_seen: set[str] = set()
    # Fields that must be unique within the cloned version, checked in this order.
    unique_fields = (("group_id", group_ids_seen), ("element_id", element_ids_seen))

    for source in source_rows:
        for field_name, seen in unique_fields:
            value = getattr(source, field_name)
            if not value:
                continue
            if value in seen:
                raise _conflict(f"Duplicate {field_name}='{value}' while cloning draft")
            seen.add(value)

        session.add(
            SchemeGroupRecord(
                group_record_id=uuid4().hex,
                scheme_id=source.scheme_id,
                scheme_version_id=target_scheme_version_id,
                element_id=source.element_id,
                group_id=source.group_id,
                name=source.name,
                classes_raw=source.classes_raw,
            )
        )
|
|
|
|
|
|
async def list_scheme_version_groups(scheme_version_id: str) -> list[SchemeGroupRecord]:
    """Return all groups of a scheme version.

    Rows are sorted ascending by ``created_at``, with ``id`` as the tie-breaker.
    """
    query = (
        select(SchemeGroupRecord)
        .where(SchemeGroupRecord.scheme_version_id == scheme_version_id)
        .order_by(asc(SchemeGroupRecord.created_at), asc(SchemeGroupRecord.id))
    )
    async with AsyncSessionLocal() as session:
        scalars = (await session.execute(query)).scalars()
        return list(scalars.all())
|
|
|
|
|
|
async def update_scheme_version_group_by_record_id(
    *,
    scheme_version_id: str,
    group_record_id: str,
    **update_data,
) -> tuple[SchemeGroupRecord, str | None]:
    """Apply a partial update to one group record of a scheme version.

    Only the ``"group_id"`` and ``"name"`` keys of ``update_data`` are applied.
    Any other keys are ignored.

    Returns:
        A tuple of the refreshed record and its ``group_id`` as it was before
        the update.

    Raises:
        HTTPException: 404 if the record does not exist in the version. 422 (via
            ``_ensure_group_uniqueness``) if the new ``group_id`` clashes with
            another group.
    """
    async with AsyncSessionLocal() as session:
        lookup = select(SchemeGroupRecord).where(
            SchemeGroupRecord.scheme_version_id == scheme_version_id,
            SchemeGroupRecord.group_record_id == group_record_id,
        )
        row = (await session.execute(lookup)).scalar_one_or_none()
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Group record not found in current draft version",
            )

        previous_group_id = row.group_id

        if "group_id" in update_data:
            new_group_id = update_data["group_id"]
            # The row's own element_id is re-checked as well; the row itself is
            # excluded so an unchanged value does not clash with itself.
            await _ensure_group_uniqueness(
                session=session,
                scheme_version_id=scheme_version_id,
                group_id=new_group_id,
                element_id=row.element_id,
                exclude_group_record_id=group_record_id,
            )
            row.group_id = new_group_id

        if "name" in update_data:
            row.name = update_data["name"]

        await session.commit()
        await session.refresh(row)
        return row, previous_group_id
|
|
|
|
|
|
async def create_scheme_version_group(
    *,
    scheme_id: str,
    scheme_version_id: str,
    element_id: str | None,
    group_id: str,
    name: str | None,
    classes_raw: str | None,
) -> SchemeGroupRecord:
    """Insert a new group into a scheme version and return the refreshed record.

    A fresh hex UUID is generated as ``group_record_id``.

    Raises:
        HTTPException: 422 (via ``_ensure_group_uniqueness``) if ``group_id`` or
            ``element_id`` is already used in the version.
    """
    async with AsyncSessionLocal() as session:
        # Run the uniqueness check before the new record is added to the
        # session, so the check's queries cannot see (autoflush) it.
        await _ensure_group_uniqueness(
            session=session,
            scheme_version_id=scheme_version_id,
            group_id=group_id,
            element_id=element_id,
        )

        record = SchemeGroupRecord(
            group_record_id=uuid4().hex,
            scheme_id=scheme_id,
            scheme_version_id=scheme_version_id,
            element_id=element_id,
            group_id=group_id,
            name=name,
            classes_raw=classes_raw,
        )
        session.add(record)
        await session.commit()
        await session.refresh(record)
        return record
|
|
|
|
|
|
async def delete_scheme_version_group_by_record_id(
    *,
    scheme_version_id: str,
    group_record_id: str,
) -> None:
    """Delete one group record of a scheme version.

    Raises:
        HTTPException: 404 if the record does not exist in the version. 409 if
            the group has a ``group_id`` and at least one seat of the same
            version references that ``group_id``.
    """
    async with AsyncSessionLocal() as session:
        group_result = await session.execute(
            select(SchemeGroupRecord).where(
                SchemeGroupRecord.scheme_version_id == scheme_version_id,
                SchemeGroupRecord.group_record_id == group_record_id,
            )
        )
        group = group_result.scalar_one_or_none()

        if group is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Group record not found in current draft version",
            )

        if group.group_id:
            # Only existence matters, so fetch at most one referencing seat
            # instead of loading every seat row into memory.
            seat_result = await session.execute(
                select(SchemeSeatRecord)
                .where(
                    SchemeSeatRecord.scheme_version_id == scheme_version_id,
                    SchemeSeatRecord.group_id == group.group_id,
                )
                .limit(1)
            )
            if seat_result.scalars().first() is not None:
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail="Cannot delete group while seats still reference it",
                )

        await session.delete(group)
        await session.commit()
|
|
|
|
|
|
async def get_scheme_version_group_by_record_id(
    *,
    scheme_version_id: str,
    group_record_id: str,
) -> SchemeGroupRecord:
    """Fetch a single group record of a scheme version by its record id.

    Raises:
        HTTPException: 404 if no such record exists in the version.
    """
    lookup = select(SchemeGroupRecord).where(
        SchemeGroupRecord.scheme_version_id == scheme_version_id,
        SchemeGroupRecord.group_record_id == group_record_id,
    )
    async with AsyncSessionLocal() as session:
        found = (await session.execute(lookup)).scalar_one_or_none()
        if found is not None:
            return found
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Group record not found in current draft version",
        )
|