Files

226 lines
7.8 KiB
Python

from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import desc, func, select
from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.services.api_errors import raise_conflict
def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
    """Report that a scheme's current-version pointer has no matching version row.

    Delegates to ``raise_conflict`` with the ``current_version_inconsistent``
    code. That helper lives outside this module and is presumably what
    raises the error response -- verify against ``app.services.api_errors``.

    Args:
        scheme_id: Identifier of the affected scheme, echoed in the details.
        current_version_number: Version number the scheme currently points at.
    """
    conflict_details = {
        "scheme_id": scheme_id,
        "current_version_number": current_version_number,
    }
    raise_conflict(
        code="current_version_inconsistent",
        message="Scheme current version pointer is inconsistent with scheme_versions state.",
        details=conflict_details,
    )
async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
    """Fetch a single scheme by ``scheme_id`` with a ``FOR UPDATE`` select.

    Args:
        session: Session used to execute the query (awaited, so presumably
            an async session -- confirm against callers).
        scheme_id: Public identifier of the scheme to load.

    Returns:
        The matching ``SchemeRecord``.

    Raises:
        HTTPException: 404 ``"Scheme not found"`` when no row matches.
    """
    locked_query = (
        select(SchemeRecord)
        .where(SchemeRecord.scheme_id == scheme_id)
        .with_for_update()
    )
    found = (await session.execute(locked_query)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Scheme not found",
    )
async def _get_current_version_for_scheme(session, scheme: SchemeRecord) -> SchemeVersionRecord:
    """Fetch the version row that ``scheme.current_version_number`` points at.

    The row is selected with ``FOR UPDATE``. When no version row matches the
    scheme's pointer, ``_raise_current_version_inconsistent`` is invoked to
    report the inconsistency.

    Args:
        session: Session used to execute the query.
        scheme: Scheme whose ``scheme_id`` and ``current_version_number``
            identify the version to load.

    Returns:
        The matching ``SchemeVersionRecord``.
    """
    pointer_query = (
        select(SchemeVersionRecord)
        .where(
            SchemeVersionRecord.scheme_id == scheme.scheme_id,
            SchemeVersionRecord.version_number == scheme.current_version_number,
        )
        .with_for_update()
    )
    current = (await session.execute(pointer_query)).scalar_one_or_none()
    if current is not None:
        return current
    _raise_current_version_inconsistent(
        scheme_id=scheme.scheme_id,
        current_version_number=scheme.current_version_number,
    )
    # Only reached if the helper above returns instead of raising.
    return current
async def create_scheme_from_upload(
    *,
    source_upload_id: str,
    name: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
) -> str:
    """Insert a new draft scheme for an upload and return its identifier.

    The scheme receives a freshly generated hex UUID, status ``"draft"`` and
    ``current_version_number`` 1. This function does not create a
    ``SchemeVersionRecord``.

    Args:
        source_upload_id: Identifier of the upload the scheme originates from.
        name: Name stored on the scheme.
        normalized_elements_count: Stored as-is on the scheme row.
        normalized_seats_count: Stored as-is on the scheme row.
        normalized_groups_count: Stored as-is on the scheme row.
        normalized_sectors_count: Stored as-is on the scheme row.

    Returns:
        The generated ``scheme_id``.
    """
    new_scheme_id = uuid4().hex

    async with AsyncSessionLocal() as session:
        session.add(
            SchemeRecord(
                scheme_id=new_scheme_id,
                source_upload_id=source_upload_id,
                name=name,
                status="draft",
                current_version_number=1,
                normalized_elements_count=normalized_elements_count,
                normalized_seats_count=normalized_seats_count,
                normalized_groups_count=normalized_groups_count,
                normalized_sectors_count=normalized_sectors_count,
            )
        )
        await session.commit()

    return new_scheme_id
async def create_scheme_from_upload_with_initial_version(
    *,
    source_upload_id: str,
    name: str,
    normalized_storage_path: str,
    normalized_elements_count: int,
    normalized_seats_count: int,
    normalized_groups_count: int,
    normalized_sectors_count: int,
    display_svg_storage_path: str | None = None,
    display_svg_status: str = "pending",
    display_svg_generated_at=None,
) -> tuple[str, str]:
    """Create a draft scheme together with its version 1 in a single commit.

    Both rows get freshly generated hex UUIDs, status ``"draft"`` and the
    same four normalized counts. The scheme points at version number 1.

    Args:
        source_upload_id: Identifier of the upload the scheme originates from.
        name: Name stored on the scheme.
        normalized_storage_path: Stored on the version row.
        normalized_elements_count: Stored on both the scheme and the version.
        normalized_seats_count: Stored on both the scheme and the version.
        normalized_groups_count: Stored on both the scheme and the version.
        normalized_sectors_count: Stored on both the scheme and the version.
        display_svg_storage_path: Stored on the version row; may be ``None``.
        display_svg_status: Stored on the version row; defaults to ``"pending"``.
        display_svg_generated_at: Stored on the version row; may be ``None``.

    Returns:
        ``(scheme_id, scheme_version_id)`` of the newly created rows.
    """
    new_scheme_id = uuid4().hex
    new_version_id = uuid4().hex

    # The scheme row and its first version carry identical counts.
    shared_counts = {
        "normalized_elements_count": normalized_elements_count,
        "normalized_seats_count": normalized_seats_count,
        "normalized_groups_count": normalized_groups_count,
        "normalized_sectors_count": normalized_sectors_count,
    }

    async with AsyncSessionLocal() as session:
        scheme_row = SchemeRecord(
            scheme_id=new_scheme_id,
            source_upload_id=source_upload_id,
            name=name,
            status="draft",
            current_version_number=1,
            **shared_counts,
        )
        version_row = SchemeVersionRecord(
            scheme_version_id=new_version_id,
            scheme_id=new_scheme_id,
            version_number=1,
            status="draft",
            normalized_storage_path=normalized_storage_path,
            display_svg_storage_path=display_svg_storage_path,
            display_svg_status=display_svg_status,
            display_svg_generated_at=display_svg_generated_at,
            **shared_counts,
        )
        session.add_all([scheme_row, version_row])
        await session.commit()

    return new_scheme_id, new_version_id
async def list_scheme_records(limit: int = 50, offset: int = 0) -> list[SchemeRecord]:
    """Return one page of schemes, newest first.

    Rows are ordered by ``created_at`` descending, with ``id`` descending as
    the tie-breaker.

    Args:
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.

    Returns:
        The selected ``SchemeRecord`` rows as a list.
    """
    async with AsyncSessionLocal() as session:
        page_query = select(SchemeRecord).order_by(
            desc(SchemeRecord.created_at),
            desc(SchemeRecord.id),
        )
        page_query = page_query.limit(limit).offset(offset)
        page = await session.execute(page_query)
        records = page.scalars().all()
        return list(records)
async def count_scheme_records() -> int:
    """Return the total number of scheme rows via ``SELECT count(*)``."""
    async with AsyncSessionLocal() as session:
        count_query = select(func.count()).select_from(SchemeRecord)
        counted = await session.execute(count_query)
        total = counted.scalar_one()
        return int(total)
async def get_scheme_record_by_scheme_id(scheme_id: str) -> SchemeRecord:
    """Look up a scheme by its ``scheme_id`` without locking the row.

    Args:
        scheme_id: Public identifier of the scheme.

    Returns:
        The matching ``SchemeRecord``.

    Raises:
        HTTPException: 404 ``"Scheme not found"`` when no row matches.
    """
    async with AsyncSessionLocal() as session:
        lookup = select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
        found = (await session.execute(lookup)).scalar_one_or_none()
        if found is not None:
            return found
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scheme not found",
        )
async def publish_scheme(scheme_id: str) -> SchemeRecord:
    """Mark a scheme and its current version as published.

    Within one transaction the scheme row and its current version row are
    loaded with ``FOR UPDATE``, both statuses are set to ``"published"`` and
    ``published_at`` is set to the database's ``now()``.

    Args:
        scheme_id: Public identifier of the scheme to publish.

    Returns:
        The refreshed ``SchemeRecord``.

    Raises:
        HTTPException: 404 when the scheme does not exist (raised by
            ``_get_scheme_for_update``).
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            record = await _get_scheme_for_update(session, scheme_id)
            current = await _get_current_version_for_scheme(session, record)

            record.status = "published"
            # SQL expression: the timestamp is computed by the database.
            record.published_at = func.now()
            current.status = "published"

        # Reload once the transaction block has exited so the returned object
        # carries database-side values such as the computed timestamp.
        # NOTE(review): assumes the refresh runs after the commit -- confirm.
        await session.refresh(record)
        return record
async def unpublish_scheme(scheme_id: str) -> SchemeRecord:
    """Return a scheme and its current version to draft state.

    Within one transaction the scheme row and its current version row are
    loaded with ``FOR UPDATE``, both statuses are set to ``"draft"`` and the
    scheme's ``published_at`` is cleared.

    Args:
        scheme_id: Public identifier of the scheme to unpublish.

    Returns:
        The refreshed ``SchemeRecord``.

    Raises:
        HTTPException: 404 when the scheme does not exist (raised by
            ``_get_scheme_for_update``).
    """
    async with AsyncSessionLocal() as session:
        async with session.begin():
            record = await _get_scheme_for_update(session, scheme_id)
            current = await _get_current_version_for_scheme(session, record)

            record.status = "draft"
            record.published_at = None
            current.status = "draft"

        # Reload once the transaction block has exited.
        # NOTE(review): assumes the refresh runs after the commit -- confirm.
        await session.refresh(record)
        return record
async def rollback_scheme_to_version(scheme_id: str, target_version_number: int) -> SchemeRecord:
    """Point a scheme back at an earlier version and reset it to draft.

    Within one transaction the scheme and its current version are loaded with
    ``FOR UPDATE`` and the target version is looked up (without a lock
    clause). Both versions and the scheme are then set to ``"draft"``,
    ``published_at`` is cleared, and the scheme's version pointer and four
    normalized counts are copied from the target version.

    Args:
        scheme_id: Public identifier of the scheme to roll back.
        target_version_number: Version number the scheme should point at.

    Returns:
        The refreshed ``SchemeRecord``.

    Raises:
        HTTPException: 404 when the scheme does not exist (raised by
            ``_get_scheme_for_update``) or when the target version is
            missing (``"Target scheme version not found"``).
    """
    copied_count_fields = (
        "normalized_elements_count",
        "normalized_seats_count",
        "normalized_groups_count",
        "normalized_sectors_count",
    )

    async with AsyncSessionLocal() as session:
        async with session.begin():
            record = await _get_scheme_for_update(session, scheme_id)
            active_version = await _get_current_version_for_scheme(session, record)

            target_query = select(SchemeVersionRecord).where(
                SchemeVersionRecord.scheme_id == record.scheme_id,
                SchemeVersionRecord.version_number == target_version_number,
            )
            target = (await session.execute(target_query)).scalar_one_or_none()
            if target is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Target scheme version not found",
                )

            # Both the version being left and the one being restored end up
            # as drafts.
            active_version.status = "draft"
            target.status = "draft"

            record.current_version_number = target.version_number
            record.status = "draft"
            record.published_at = None
            # Mirror the restored version's counts onto the scheme row.
            for field_name in copied_count_fields:
                setattr(record, field_name, getattr(target, field_name))

        # Reload once the transaction block has exited.
        # NOTE(review): assumes the refresh runs after the commit -- confirm.
        await session.refresh(record)
        return record