Files
svg-backend/backend/app/repositories/scheme_versions.py

319 lines
12 KiB
Python

from datetime import datetime
from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import asc, desc, func, select
from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.repositories.scheme_groups import clone_scheme_version_groups_in_session
from app.repositories.scheme_seats import clone_scheme_version_seats_in_session
from app.repositories.scheme_sectors import clone_scheme_version_sectors_in_session
from app.services.api_errors import raise_conflict
def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
    """Report that a scheme's current-version pointer has no matching version row.

    Delegates to ``raise_conflict`` with the ``current_version_inconsistent``
    code; the details payload carries the scheme id and the version number
    that was looked up.

    Args:
        scheme_id: Identifier of the scheme whose pointer was followed.
        current_version_number: Version number the scheme points at.
    """
    # NOTE(review): raise_conflict is presumably always-raising — confirm in
    # app.services.api_errors; callers continue past this call otherwise.
    conflict_details = {
        "scheme_id": scheme_id,
        "current_version_number": current_version_number,
    }
    raise_conflict(
        code="current_version_inconsistent",
        message="Scheme current version pointer is inconsistent with scheme_versions state.",
        details=conflict_details,
    )
def _raise_stale_current_version(*, expected_scheme_version_id: str, actual_scheme_version_id: str) -> None:
    """Report that the caller's expected current version id no longer matches.

    Delegates to ``raise_conflict`` with the ``stale_current_version`` code;
    the details payload carries both the expected and the actual version id.

    Args:
        expected_scheme_version_id: Version id the caller believed was current.
        actual_scheme_version_id: Version id that is current in the database.
    """
    # NOTE(review): raise_conflict is presumably always-raising — confirm in
    # app.services.api_errors.
    conflict_details = {
        "expected_scheme_version_id": expected_scheme_version_id,
        "actual_scheme_version_id": actual_scheme_version_id,
    }
    raise_conflict(
        code="stale_current_version",
        message="Current scheme version changed. Reload scheme state before creating a new version.",
        details=conflict_details,
    )
async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
    """Load the scheme row for ``scheme_id`` with a ``FOR UPDATE`` row lock.

    Args:
        session: Session in which the locking SELECT is executed.
        scheme_id: Identifier of the scheme to load.

    Returns:
        The matching ``SchemeRecord``.

    Raises:
        HTTPException: 404 with detail "Scheme not found" when no row matches.
    """
    locking_query = (
        select(SchemeRecord)
        .where(SchemeRecord.scheme_id == scheme_id)
        .with_for_update()
    )
    locked_scheme = (await session.execute(locking_query)).scalar_one_or_none()
    if locked_scheme is not None:
        return locked_scheme
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Scheme not found",
    )
async def _get_current_scheme_version_for_update(
    session,
    *,
    scheme_id: str,
    current_version_number: int,
) -> SchemeVersionRecord:
    """Load the version row a scheme points at, with a ``FOR UPDATE`` row lock.

    Args:
        session: Session in which the locking SELECT is executed.
        scheme_id: Identifier of the owning scheme.
        current_version_number: Version number the scheme currently points at.

    Returns:
        The matching ``SchemeVersionRecord``.

    When no row matches, ``_raise_current_version_inconsistent`` is invoked
    with the same scheme id and version number.
    """
    belongs_to_scheme = SchemeVersionRecord.scheme_id == scheme_id
    has_current_number = SchemeVersionRecord.version_number == current_version_number
    locking_query = (
        select(SchemeVersionRecord)
        .where(belongs_to_scheme, has_current_number)
        .with_for_update()
    )
    locked_version = (await session.execute(locking_query)).scalar_one_or_none()
    if locked_version is None:
        _raise_current_version_inconsistent(
            scheme_id=scheme_id,
            current_version_number=current_version_number,
        )
    return locked_version
async def _build_next_draft_version(
    session,
    *,
    scheme: SchemeRecord,
    source_version: SchemeVersionRecord,
) -> SchemeVersionRecord:
    """Create the next draft version of ``scheme`` as a copy of ``source_version``.

    Within the caller's session this:
      1. picks ``max(version_number) + 1`` for the scheme (1 if none exist);
      2. adds a new ``draft`` version row that copies the normalized storage
         path, the normalized counters and the display-SVG fields from
         ``source_version``, then flushes;
      3. clones sectors, then groups, then seats from the source version to
         the new one;
      4. repoints ``scheme`` at the new version, marks it ``draft``, clears
         ``published_at`` and copies the normalized counters onto it.

    No commit is issued here.

    Args:
        session: Session that owns ``scheme`` and ``source_version``.
        scheme: Scheme row to repoint at the new version.
        source_version: Version whose data is copied.

    Returns:
        The newly added ``SchemeVersionRecord``.
    """
    counter_fields = (
        "normalized_elements_count",
        "normalized_seats_count",
        "normalized_groups_count",
        "normalized_sectors_count",
    )
    carried_over_fields = (
        "normalized_storage_path",
        *counter_fields,
        "display_svg_storage_path",
        "display_svg_status",
        "display_svg_generated_at",
    )

    highest_number_query = select(
        func.coalesce(func.max(SchemeVersionRecord.version_number), 0)
    ).where(SchemeVersionRecord.scheme_id == scheme.scheme_id)
    highest_number = (await session.execute(highest_number_query)).scalar_one()

    draft = SchemeVersionRecord(
        scheme_version_id=uuid4().hex,
        scheme_id=scheme.scheme_id,
        version_number=int(highest_number) + 1,
        status="draft",
        **{name: getattr(source_version, name) for name in carried_over_fields},
    )
    session.add(draft)
    # Flush before cloning so the new version row is written to the database
    # ahead of the child rows created by the clone helpers.
    await session.flush()

    # Clone order is kept as sectors -> groups -> seats.
    for clone_children in (
        clone_scheme_version_sectors_in_session,
        clone_scheme_version_groups_in_session,
        clone_scheme_version_seats_in_session,
    ):
        await clone_children(
            session=session,
            source_scheme_version_id=source_version.scheme_version_id,
            target_scheme_version_id=draft.scheme_version_id,
        )

    scheme.current_version_number = draft.version_number
    scheme.status = "draft"
    scheme.published_at = None
    for name in counter_fields:
        setattr(scheme, name, getattr(source_version, name))
    return draft
async def create_initial_scheme_version(
    *,
    scheme_id: str,
    normalized_storage_path: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
    display_svg_storage_path: str | None = None,
    display_svg_status: str = "pending",
    display_svg_generated_at: datetime | None = None,
) -> str:
    """Insert version 1 of a scheme as a draft and return its generated id.

    A fresh session is opened, the row is added and the session is committed.

    Args:
        scheme_id: Identifier of the owning scheme.
        normalized_storage_path: Stored on the row as ``normalized_storage_path``.
        normalized_elements_count: Stored on the row unchanged.
        normalized_seats_count: Stored on the row unchanged.
        normalized_groups_count: Stored on the row unchanged.
        normalized_sectors_count: Stored on the row unchanged.
        display_svg_storage_path: Optional display-SVG path; defaults to ``None``.
        display_svg_status: Display-SVG status; defaults to ``"pending"``.
        display_svg_generated_at: Optional generation timestamp.

    Returns:
        The ``uuid4().hex`` string used as ``scheme_version_id``.
    """
    new_version_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        session.add(
            SchemeVersionRecord(
                scheme_version_id=new_version_id,
                scheme_id=scheme_id,
                version_number=1,
                status="draft",
                normalized_storage_path=normalized_storage_path,
                normalized_elements_count=normalized_elements_count,
                normalized_seats_count=normalized_seats_count,
                normalized_groups_count=normalized_groups_count,
                normalized_sectors_count=normalized_sectors_count,
                display_svg_storage_path=display_svg_storage_path,
                display_svg_status=display_svg_status,
                display_svg_generated_at=display_svg_generated_at,
            )
        )
        await session.commit()
    return new_version_id
async def list_scheme_versions(scheme_id: str, limit: int = 100, offset: int = 0) -> list[SchemeVersionRecord]:
    """Return one page of a scheme's versions.

    Rows are ordered by ``version_number`` ascending, then by ``id``
    descending, and sliced with ``limit`` / ``offset``.

    Args:
        scheme_id: Identifier of the scheme whose versions are listed.
        limit: Maximum number of rows to return (default 100).
        offset: Number of rows to skip (default 0).

    Returns:
        A list of ``SchemeVersionRecord`` rows; empty when nothing matches.
    """
    page_query = (
        select(SchemeVersionRecord)
        .where(SchemeVersionRecord.scheme_id == scheme_id)
        .order_by(asc(SchemeVersionRecord.version_number), desc(SchemeVersionRecord.id))
        .limit(limit)
        .offset(offset)
    )
    async with AsyncSessionLocal() as session:
        page = await session.execute(page_query)
        return list(page.scalars().all())
async def count_scheme_versions(scheme_id: str) -> int:
    """Return how many version rows exist for ``scheme_id``.

    Args:
        scheme_id: Identifier of the scheme whose versions are counted.

    Returns:
        The ``COUNT(*)`` result converted to ``int``.
    """
    count_query = (
        select(func.count())
        .select_from(SchemeVersionRecord)
        .where(SchemeVersionRecord.scheme_id == scheme_id)
    )
    async with AsyncSessionLocal() as session:
        total = (await session.execute(count_query)).scalar_one()
        return int(total)
async def get_current_scheme_version(scheme_id: str, current_version_number: int) -> SchemeVersionRecord:
    """Fetch the version row identified by scheme id and version number.

    Unlike ``_get_current_scheme_version_for_update`` this opens its own
    session and takes no row lock.

    Args:
        scheme_id: Identifier of the owning scheme.
        current_version_number: Version number the scheme currently points at.

    Returns:
        The matching ``SchemeVersionRecord``.

    When no row matches, ``_raise_current_version_inconsistent`` is invoked
    with the same scheme id and version number.
    """
    lookup_query = select(SchemeVersionRecord).where(
        SchemeVersionRecord.scheme_id == scheme_id,
        SchemeVersionRecord.version_number == current_version_number,
    )
    async with AsyncSessionLocal() as session:
        found_version = (await session.execute(lookup_query)).scalar_one_or_none()
        if found_version is None:
            _raise_current_version_inconsistent(
                scheme_id=scheme_id,
                current_version_number=current_version_number,
            )
        return found_version
async def update_scheme_version_display_artifact(
    *,
    scheme_version_id: str,
    display_svg_storage_path: str,
    display_svg_status: str,
    display_svg_generated_at: datetime,
) -> None:
    """Overwrite the display-SVG fields of one scheme version and commit.

    Args:
        scheme_version_id: Identifier of the version row to update.
        display_svg_storage_path: New value for ``display_svg_storage_path``.
        display_svg_status: New value for ``display_svg_status``.
        display_svg_generated_at: New value for ``display_svg_generated_at``.

    Raises:
        HTTPException: 404 with detail "Scheme version not found" when no row
            has the given ``scheme_version_id``.
    """
    lookup_query = select(SchemeVersionRecord).where(
        SchemeVersionRecord.scheme_version_id == scheme_version_id
    )
    async with AsyncSessionLocal() as session:
        version = (await session.execute(lookup_query)).scalar_one_or_none()
        if version is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Scheme version not found",
            )
        version.display_svg_storage_path = display_svg_storage_path
        version.display_svg_status = display_svg_status
        version.display_svg_generated_at = display_svg_generated_at
        await session.commit()
async def create_next_scheme_version_from_current(scheme_id: str) -> SchemeVersionRecord:
    """Create a new draft version of a scheme by copying its current version.

    Opens a session and, inside one ``session.begin()`` transaction, locks the
    scheme row, locks the version row it currently points at, and calls
    ``_build_next_draft_version`` to add the copy and repoint the scheme.

    Args:
        scheme_id: Identifier of the scheme to version.

    Returns:
        The newly created ``SchemeVersionRecord``, refreshed from the database.

    Raises:
        HTTPException: 404 from ``_get_scheme_for_update`` when the scheme
            does not exist.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            # Lock order: scheme row first, then its current version row.
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            new_version = await _build_next_draft_version(
                session,
                scheme=scheme,
                source_version=current_version,
            )
        # NOTE(review): assumed to run after the session.begin() block has
        # exited — confirm the nesting against the original file.
        await session.refresh(new_version)
        return new_version
async def create_next_scheme_version_from_current_checked(
    *,
    scheme_id: str,
    expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, SchemeVersionRecord]:
    """Create a new draft version, optionally guarding against a stale caller.

    Same flow as ``create_next_scheme_version_from_current``, plus a check
    that the locked current version still has the id the caller expects.

    Args:
        scheme_id: Identifier of the scheme to version.
        expected_current_scheme_version_id: Version id the caller believes is
            current. The check is skipped when this is ``None`` or empty.

    Returns:
        ``(current_version, new_version)``: the version that was copied and
        the newly created draft, both refreshed from the database.

    Raises:
        HTTPException: 404 from ``_get_scheme_for_update`` when the scheme
            does not exist.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            # Lock order: scheme row first, then its current version row.
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            # Truthiness test: a falsy expected id disables the stale check.
            if (
                expected_current_scheme_version_id
                and expected_current_scheme_version_id != current_version.scheme_version_id
            ):
                _raise_stale_current_version(
                    expected_scheme_version_id=expected_current_scheme_version_id,
                    actual_scheme_version_id=current_version.scheme_version_id,
                )
            new_version = await _build_next_draft_version(
                session,
                scheme=scheme,
                source_version=current_version,
            )
        # NOTE(review): both refreshes are assumed to run after the
        # session.begin() block has exited — confirm the nesting against the
        # original file.
        await session.refresh(current_version)
        await session.refresh(new_version)
        return current_version, new_version
async def ensure_draft_scheme_version_consistent(
    *,
    scheme_id: str,
    expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, bool, str | None]:
    """Return a draft version for the scheme, creating one only when needed.

    With the scheme row and its current version row locked, the current
    version is reused when both ``scheme.status`` and the version's ``status``
    are ``"draft"``. Otherwise a new draft is built from the current version
    via ``_build_next_draft_version``.

    Args:
        scheme_id: Identifier of the scheme.
        expected_current_scheme_version_id: Version id the caller believes is
            current. The stale check is skipped when this is ``None`` or empty.

    Returns:
        A 3-tuple ``(version, created, source_scheme_version_id)``:
        ``(current_version, False, None)`` when the existing draft is reused,
        or ``(new_version, True, <id of the copied version>)`` when a new
        draft was created.

    Raises:
        HTTPException: 404 from ``_get_scheme_for_update`` when the scheme
            does not exist.
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            # Lock order: scheme row first, then its current version row.
            scheme = await _get_scheme_for_update(session, scheme_id)
            current_version = await _get_current_scheme_version_for_update(
                session,
                scheme_id=scheme.scheme_id,
                current_version_number=scheme.current_version_number,
            )
            # Truthiness test: a falsy expected id disables the stale check.
            if (
                expected_current_scheme_version_id
                and expected_current_scheme_version_id != current_version.scheme_version_id
            ):
                _raise_stale_current_version(
                    expected_scheme_version_id=expected_current_scheme_version_id,
                    actual_scheme_version_id=current_version.scheme_version_id,
                )
            # Already a draft on both the scheme and the version: reuse it.
            if scheme.status == "draft" and current_version.status == "draft":
                await session.refresh(current_version)
                return current_version, False, None
            new_version = await _build_next_draft_version(
                session,
                scheme=scheme,
                source_version=current_version,
            )
            # Capture the source id into a plain local before leaving the
            # transaction block, so the return value does not read it from the
            # ORM object afterwards.
            source_scheme_version_id = current_version.scheme_version_id
        # NOTE(review): assumed to run after the session.begin() block has
        # exited — confirm the nesting against the original file.
        await session.refresh(new_version)
        return new_version, True, source_scheme_version_id