fix(core): stabilize editor lifecycle, transactional versions, and runtime config

This commit is contained in:
greebo
2026-03-20 12:38:10 +03:00
parent 0f9c2a1cbd
commit 239b32a246
17 changed files with 1224 additions and 457 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Query, Request
from app.core.config import settings
from app.repositories.audit import create_audit_event
@@ -508,6 +508,7 @@ async def delete_draft_group(
@router.patch(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/seats/records/{{seat_record_id}}", response_model=SeatPatchResponse)
async def patch_draft_seat(
request: Request,
scheme_id: str,
seat_record_id: str,
payload: SeatPatchRequest,
@@ -530,14 +531,20 @@ async def patch_draft_seat(
group_id=payload.group_id,
)
raw_json = await request.json()
update_data = {k: v for k, v in payload.model_dump(exclude_unset=True).items() if k in raw_json}
for field in ("seat_id", "sector_id", "group_id"):
if field in update_data and (update_data[field] is None or update_data[field] == ""):
from app.services.api_errors import raise_unprocessable
raise_unprocessable(
code="business_identifier_nullification_forbidden",
message=f"{field} cannot be nullified or explicitly cleared",
)
row = await update_scheme_version_seat_by_record_id(
scheme_version_id=version.scheme_version_id,
seat_record_id=seat_record_id,
seat_id=payload.seat_id,
sector_id=payload.sector_id,
group_id=payload.group_id,
row_label=payload.row_label,
seat_number=payload.seat_number,
**update_data,
)
await create_audit_event(
@@ -569,6 +576,7 @@ async def patch_draft_seat(
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/seats/bulk", response_model=BulkSeatPatchResponse)
async def bulk_patch_draft_seats(
request: Request,
scheme_id: str,
payload: BulkSeatPatchRequest,
expected_scheme_version_id: str | None = Query(default=None),
@@ -579,7 +587,20 @@ async def bulk_patch_draft_seats(
expected_scheme_version_id=expected_scheme_version_id,
)
items = [item.model_dump() for item in payload.items]
raw_json = await request.json()
items = []
for i, item in enumerate(payload.items):
item_raw = raw_json.get("items", [])[i] if "items" in raw_json else {}
items.append({k: item.model_dump(exclude_unset=True).get(k) for k in item_raw})
for item in items:
for field in ("seat_id", "sector_id", "group_id"):
if field in item and (item[field] is None or item[field] == ""):
from app.services.api_errors import raise_unprocessable
raise_unprocessable(
code="business_identifier_nullification_forbidden",
message=f"{field} cannot be nullified or explicitly cleared",
)
await validate_bulk_seat_patch_uniqueness(
scheme_version_id=version.scheme_version_id,
items=items,
@@ -625,6 +646,7 @@ async def bulk_patch_draft_seats(
@router.patch(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/sectors/records/{{sector_record_id}}", response_model=SectorPatchResponse)
async def patch_draft_sector(
request: Request,
scheme_id: str,
sector_record_id: str,
payload: SectorPatchRequest,
@@ -642,20 +664,28 @@ async def patch_draft_sector(
new_sector_id=payload.sector_id,
)
raw_json = await request.json()
update_data = {k: v for k, v in payload.model_dump(exclude_unset=True).items() if k in raw_json}
for field in ("sector_id",):
if field in update_data and (update_data[field] is None or update_data[field] == ""):
from app.services.api_errors import raise_unprocessable
raise_unprocessable(
code="business_identifier_nullification_forbidden",
message=f"{field} cannot be nullified or explicitly cleared",
)
row, old_sector_id = await update_scheme_version_sector_by_record_id(
scheme_version_id=version.scheme_version_id,
sector_record_id=sector_record_id,
sector_id=payload.sector_id,
name=payload.name,
)
cascaded_count = await cascade_update_seat_sector_reference(
scheme_version_id=version.scheme_version_id,
old_sector_id=old_sector_id,
new_sector_id=payload.sector_id,
)
repair_result = await repair_structure_references(
scheme_version_id=version.scheme_version_id,
**update_data,
)
cascaded_count = 0
if "sector_id" in update_data and update_data["sector_id"] and update_data["sector_id"] != old_sector_id:
cascaded_count = await cascade_update_seat_sector_reference(
scheme_version_id=version.scheme_version_id,
old_sector_id=old_sector_id,
new_sector_id=update_data["sector_id"],
)
await create_audit_event(
scheme_id=scheme.scheme_id,
@@ -668,7 +698,6 @@ async def patch_draft_sector(
"new_sector_id": payload.sector_id,
"name": payload.name,
"cascaded_seats_count": cascaded_count,
"repair_result": repair_result,
},
)
@@ -683,6 +712,7 @@ async def patch_draft_sector(
@router.patch(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/groups/records/{{group_record_id}}", response_model=GroupPatchResponse)
async def patch_draft_group(
request: Request,
scheme_id: str,
group_record_id: str,
payload: GroupPatchRequest,
@@ -700,20 +730,28 @@ async def patch_draft_group(
new_group_id=payload.group_id,
)
raw_json = await request.json()
update_data = {k: v for k, v in payload.model_dump(exclude_unset=True).items() if k in raw_json}
for field in ("group_id",):
if field in update_data and (update_data[field] is None or update_data[field] == ""):
from app.services.api_errors import raise_unprocessable
raise_unprocessable(
code="business_identifier_nullification_forbidden",
message=f"{field} cannot be nullified or explicitly cleared",
)
row, old_group_id = await update_scheme_version_group_by_record_id(
scheme_version_id=version.scheme_version_id,
group_record_id=group_record_id,
group_id=payload.group_id,
name=payload.name,
)
cascaded_count = await cascade_update_seat_group_reference(
scheme_version_id=version.scheme_version_id,
old_group_id=old_group_id,
new_group_id=payload.group_id,
)
repair_result = await repair_structure_references(
scheme_version_id=version.scheme_version_id,
**update_data,
)
cascaded_count = 0
if "group_id" in update_data and update_data["group_id"] and update_data["group_id"] != old_group_id:
cascaded_count = await cascade_update_seat_group_reference(
scheme_version_id=version.scheme_version_id,
old_group_id=old_group_id,
new_group_id=update_data["group_id"],
)
await create_audit_event(
scheme_id=scheme.scheme_id,
@@ -726,7 +764,6 @@ async def patch_draft_group(
"new_group_id": payload.group_id,
"name": payload.name,
"cascaded_seats_count": cascaded_count,
"repair_result": repair_result,
},
)

View File

@@ -2,12 +2,10 @@ from fastapi import APIRouter, Depends, Query
from app.core.config import settings
from app.repositories.audit import create_audit_event
from app.repositories.scheme_groups import clone_scheme_version_groups
from app.repositories.scheme_seats import clone_scheme_version_seats
from app.repositories.scheme_sectors import clone_scheme_version_sectors
from app.repositories.scheme_versions import (
count_scheme_versions,
create_next_scheme_version_from_current,
create_next_scheme_version_from_current_checked,
ensure_draft_scheme_version_consistent,
get_current_scheme_version,
list_scheme_versions,
)
@@ -34,26 +32,12 @@ from app.schemas.scheme_versions import (
SchemeVersionListResponse,
)
from app.security.auth import require_api_key
from app.services.api_errors import raise_conflict
from app.services.publish_service import publish_current_draft_scheme
from app.services.scheme_validation import build_scheme_validation_report
router = APIRouter()
def _build_stale_current_version_detail(
*,
expected_scheme_version_id: str,
actual_scheme_version_id: str,
) -> dict:
return {
"code": "stale_current_version",
"message": "Current scheme version changed. Reload scheme state before creating a new version.",
"expected_scheme_version_id": expected_scheme_version_id,
"actual_scheme_version_id": actual_scheme_version_id,
}
@router.get(f"{settings.api_v1_prefix}/schemes", response_model=SchemeListResponse)
async def get_schemes(
limit: int = Query(default=50, ge=1, le=200),
@@ -155,36 +139,9 @@ async def create_next_scheme_version_endpoint(
expected_current_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
current_scheme = await get_scheme_record_by_scheme_id(scheme_id)
current_version = await get_current_scheme_version(
scheme_id=current_scheme.scheme_id,
current_version_number=current_scheme.current_version_number,
)
if (
expected_current_scheme_version_id
and expected_current_scheme_version_id != current_version.scheme_version_id
):
raise_conflict(
_build_stale_current_version_detail(
expected_scheme_version_id=expected_current_scheme_version_id,
actual_scheme_version_id=current_version.scheme_version_id,
)
)
new_version = await create_next_scheme_version_from_current(scheme_id)
await clone_scheme_version_sectors(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_groups(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_seats(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
current_version, new_version = await create_next_scheme_version_from_current_checked(
scheme_id=scheme_id,
expected_current_scheme_version_id=expected_current_scheme_version_id,
)
await create_audit_event(
@@ -214,26 +171,14 @@ async def ensure_draft_scheme_version(
expected_current_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme = await get_scheme_record_by_scheme_id(scheme_id)
current_version = await get_current_scheme_version(
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
current_version, created, source_scheme_version_id = await ensure_draft_scheme_version_consistent(
scheme_id=scheme_id,
expected_current_scheme_version_id=expected_current_scheme_version_id,
)
if (
expected_current_scheme_version_id
and expected_current_scheme_version_id != current_version.scheme_version_id
):
raise_conflict(
_build_stale_current_version_detail(
expected_scheme_version_id=expected_current_scheme_version_id,
actual_scheme_version_id=current_version.scheme_version_id,
)
)
if scheme.status == "draft" and current_version.status == "draft":
if not created:
return EnsureDraftResponse(
scheme_id=scheme.scheme_id,
scheme_id=current_version.scheme_id,
scheme_version_id=current_version.scheme_version_id,
version_number=current_version.version_number,
status=current_version.status,
@@ -242,42 +187,27 @@ async def ensure_draft_scheme_version(
source_scheme_version_id=None,
)
new_version = await create_next_scheme_version_from_current(scheme_id)
await clone_scheme_version_sectors(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_groups(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_seats(
source_scheme_version_id=current_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
event_type="scheme.version.created",
object_type="scheme_version",
object_ref=new_version.scheme_version_id,
object_ref=current_version.scheme_version_id,
details={
"source_scheme_version_id": current_version.scheme_version_id,
"version_number": new_version.version_number,
"normalized_storage_path": new_version.normalized_storage_path,
"source_scheme_version_id": source_scheme_version_id,
"version_number": current_version.version_number,
"normalized_storage_path": current_version.normalized_storage_path,
"reason": "ensure_draft",
},
)
return EnsureDraftResponse(
scheme_id=new_version.scheme_id,
scheme_version_id=new_version.scheme_version_id,
version_number=new_version.version_number,
status=new_version.status,
normalized_storage_path=new_version.normalized_storage_path,
scheme_id=current_version.scheme_id,
scheme_version_id=current_version.scheme_version_id,
version_number=current_version.version_number,
status=current_version.status,
normalized_storage_path=current_version.normalized_storage_path,
created=True,
source_scheme_version_id=current_version.scheme_version_id,
source_scheme_version_id=source_scheme_version_id,
)

View File

@@ -10,8 +10,7 @@ from app.repositories.scheme_artifacts import create_scheme_artifact
from app.repositories.scheme_groups import replace_scheme_version_groups
from app.repositories.scheme_seats import replace_scheme_version_seats
from app.repositories.scheme_sectors import replace_scheme_version_sectors
from app.repositories.scheme_versions import create_initial_scheme_version
from app.repositories.schemes import create_scheme_from_upload
from app.repositories.schemes import create_scheme_from_upload_with_initial_version
from app.repositories.uploads import (
count_upload_records,
create_upload_record,
@@ -202,17 +201,9 @@ async def upload_scheme_svg(
processing_status="completed",
)
scheme_id = await create_scheme_from_upload(
scheme_id, scheme_version_id = await create_scheme_from_upload_with_initial_version(
source_upload_id=upload_id,
name=Path(filename).stem or filename,
normalized_elements_count=summary["elements_count"],
normalized_seats_count=summary["seats_count"],
normalized_groups_count=summary["groups_count"],
normalized_sectors_count=summary["sectors_count"],
)
scheme_version_id = await create_initial_scheme_version(
scheme_id=scheme_id,
normalized_storage_path=normalized_storage_path,
normalized_elements_count=summary["elements_count"],
normalized_seats_count=summary["seats_count"],

View File

@@ -1,29 +1,32 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
app_name: str = "svg-service"
app_env: str = "development"
app_port: int = 9020
api_v1_prefix: str = "/api/v1"
app_name: str = Field(..., validation_alias="APP_NAME")
app_env: str = Field(..., validation_alias="APP_ENV")
app_port: int = Field(..., validation_alias="BACKEND_PORT")
api_v1_prefix: str = Field(..., validation_alias="API_V1_PREFIX")
auth_header_name: str = "X-API-Key"
admin_api_key: str = "admin-local-dev-key"
viewer_api_key: str = "viewer-local-dev-key"
auth_header_name: str = Field(..., validation_alias="AUTH_HEADER_NAME")
api_keys_admin: str = Field(..., validation_alias="API_KEYS_ADMIN")
api_keys_operator: str = Field(..., validation_alias="API_KEYS_OPERATOR")
api_keys_viewer: str = Field(..., validation_alias="API_KEYS_VIEWER")
postgres_host: str = "postgres"
postgres_port: int = 5432
postgres_db: str = "svg_service"
postgres_user: str = "svg_service"
postgres_password: str = "svg_service_dev_password"
postgres_host: str = Field(..., validation_alias="POSTGRES_HOST")
postgres_port: int = Field(..., validation_alias="POSTGRES_PORT")
postgres_db: str = Field(..., validation_alias="POSTGRES_DB")
postgres_user: str = Field(..., validation_alias="POSTGRES_USER")
postgres_password: str = Field(..., validation_alias="POSTGRES_PASSWORD")
database_url_raw: str | None = Field(default=None, validation_alias="DATABASE_URL")
svg_max_file_size_bytes: int = 10 * 1024 * 1024
svg_max_elements: int = 25000
svg_max_file_size_bytes: int = Field(10 * 1024 * 1024, validation_alias="SVG_MAX_FILE_SIZE_BYTES")
svg_max_elements: int = Field(25000, validation_alias="SVG_MAX_ELEMENTS")
svg_allow_internal_use_references_only: bool = True
svg_forbid_foreign_object_v1: bool = True
svg_forbid_style_v1: bool = False
svg_forbid_image_v1: bool = True
svg_allow_internal_use_references_only: bool = Field(True, validation_alias="SVG_ALLOW_INTERNAL_USE_REFERENCES_ONLY")
svg_forbid_foreign_object_v1: bool = Field(True, validation_alias="SVG_FORBID_FOREIGN_OBJECT_V1")
svg_forbid_style_v1: bool = Field(False, validation_alias="SVG_FORBID_STYLE_V1")
svg_forbid_image_v1: bool = Field(True, validation_alias="SVG_FORBID_IMAGE_V1")
svg_display_enabled: bool = True
svg_display_mode: str = "passthrough"
@@ -34,7 +37,7 @@ class Settings(BaseSettings):
svg_display_force_viewbox: bool = True
svg_display_technical_text_patterns: str = "debug,tech,helper,tmp,service"
storage_root_dir: str = "/data"
storage_root_dir: str = Field(..., validation_alias="STORAGE_ROOT")
publish_preview_retention_per_variant: int = 2
publish_require_full_pricing_coverage: bool = False
@@ -45,16 +48,32 @@ class Settings(BaseSettings):
extra="ignore",
)
@model_validator(mode="after")
def validate_database_config(self) -> "Settings":
assembled_database_url = (
f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}"
f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"
)
if self.database_url_raw and self.database_url_raw != assembled_database_url:
raise ValueError("DATABASE_URL must match POSTGRES_HOST/PORT/DB/USER/PASSWORD")
return self
@property
def admin_keys(self) -> set[str]:
return {item.strip() for item in self.admin_api_key.split(",") if item.strip()}
return {item.strip() for item in self.api_keys_admin.split(",") if item.strip()}
@property
def operator_keys(self) -> set[str]:
return {item.strip() for item in self.api_keys_operator.split(",") if item.strip()}
@property
def viewer_keys(self) -> set[str]:
return {item.strip() for item in self.viewer_api_key.split(",") if item.strip()}
return {item.strip() for item in self.api_keys_viewer.split(",") if item.strip()}
@property
def database_url(self) -> str:
if self.database_url_raw:
return self.database_url_raw
return (
f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}"
f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"

View File

@@ -8,6 +8,49 @@ from app.models.scheme_group import SchemeGroupRecord
from app.models.scheme_seat import SchemeSeatRecord
def _conflict(message: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"code": "group_uniqueness_violation",
"message": message,
},
)
async def _ensure_group_uniqueness(
*,
session,
scheme_version_id: str,
group_id: str | None,
element_id: str | None,
exclude_group_record_id: str | None = None,
) -> None:
if group_id:
stmt = select(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == scheme_version_id,
SchemeGroupRecord.group_id == group_id,
)
if exclude_group_record_id:
stmt = stmt.where(SchemeGroupRecord.group_record_id != exclude_group_record_id)
existing = (await session.execute(stmt)).scalar_one_or_none()
if existing is not None:
raise _conflict(f"Group with group_id='{group_id}' already exists in current draft version")
if element_id:
stmt = select(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == scheme_version_id,
SchemeGroupRecord.element_id == element_id,
)
if exclude_group_record_id:
stmt = stmt.where(SchemeGroupRecord.group_record_id != exclude_group_record_id)
existing = (await session.execute(stmt)).scalar_one_or_none()
if existing is not None:
raise _conflict(f"Group with element_id='{element_id}' already exists in current draft version")
async def replace_scheme_version_groups(
*,
scheme_id: str,
@@ -23,13 +66,29 @@ async def replace_scheme_version_groups(
for row in existing_rows:
await session.delete(row)
seen_group_ids: set[str] = set()
seen_element_ids: set[str] = set()
for item in groups:
group_id = item.get("group_id")
element_id = item.get("id")
if group_id:
if group_id in seen_group_ids:
raise _conflict(f"Duplicate group_id='{group_id}' in replacement payload")
seen_group_ids.add(group_id)
if element_id:
if element_id in seen_element_ids:
raise _conflict(f"Duplicate element_id='{element_id}' in replacement payload")
seen_element_ids.add(element_id)
row = SchemeGroupRecord(
group_record_id=item["group_record_id"] if "group_record_id" in item and item["group_record_id"] else uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=item.get("id"),
group_id=item.get("group_id"),
element_id=element_id,
group_id=group_id,
name=item.get("group_id"),
classes_raw=str(item.get("classes")),
)
@@ -44,26 +103,51 @@ async def clone_scheme_version_groups(
target_scheme_version_id: str,
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == source_scheme_version_id)
await clone_scheme_version_groups_in_session(
session=session,
source_scheme_version_id=source_scheme_version_id,
target_scheme_version_id=target_scheme_version_id,
)
rows = list(result.scalars().all())
for row in rows:
cloned = SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
group_id=row.group_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
await session.commit()
async def clone_scheme_version_groups_in_session(
*,
session,
source_scheme_version_id: str,
target_scheme_version_id: str,
) -> None:
result = await session.execute(
select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
seen_group_ids: set[str] = set()
seen_element_ids: set[str] = set()
for row in rows:
if row.group_id:
if row.group_id in seen_group_ids:
raise _conflict(f"Duplicate group_id='{row.group_id}' while cloning draft")
seen_group_ids.add(row.group_id)
if row.element_id:
if row.element_id in seen_element_ids:
raise _conflict(f"Duplicate element_id='{row.element_id}' while cloning draft")
seen_element_ids.add(row.element_id)
cloned = SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
group_id=row.group_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
async def list_scheme_version_groups(scheme_version_id: str) -> list[SchemeGroupRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -78,8 +162,7 @@ async def update_scheme_version_group_by_record_id(
*,
scheme_version_id: str,
group_record_id: str,
group_id: str | None,
name: str | None,
**update_data,
) -> tuple[SchemeGroupRecord, str | None]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -96,9 +179,20 @@ async def update_scheme_version_group_by_record_id(
detail="Group record not found in current draft version",
)
if "group_id" in update_data:
await _ensure_group_uniqueness(
session=session,
scheme_version_id=scheme_version_id,
group_id=update_data["group_id"],
element_id=row.element_id,
exclude_group_record_id=group_record_id,
)
old_group_id = row.group_id
row.group_id = group_id
row.name = name
if "group_id" in update_data:
row.group_id = update_data["group_id"]
if "name" in update_data:
row.name = update_data["name"]
await session.commit()
await session.refresh(row)
@@ -115,6 +209,13 @@ async def create_scheme_version_group(
classes_raw: str | None,
) -> SchemeGroupRecord:
async with AsyncSessionLocal() as session:
await _ensure_group_uniqueness(
session=session,
scheme_version_id=scheme_version_id,
group_id=group_id,
element_id=element_id,
)
row = SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=scheme_id,

View File

@@ -51,36 +51,48 @@ async def clone_scheme_version_seats(
target_scheme_version_id: str,
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == source_scheme_version_id)
await clone_scheme_version_seats_in_session(
session=session,
source_scheme_version_id=source_scheme_version_id,
target_scheme_version_id=target_scheme_version_id,
)
rows = list(result.scalars().all())
for row in rows:
cloned = SchemeSeatRecord(
seat_record_id=__import__("uuid").uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
seat_id=row.seat_id,
sector_id=row.sector_id,
group_id=row.group_id,
row_label=row.row_label,
seat_number=row.seat_number,
tag=row.tag,
classes_raw=row.classes_raw,
x=row.x,
y=row.y,
cx=row.cx,
cy=row.cy,
width=row.width,
height=row.height,
)
session.add(cloned)
await session.commit()
async def clone_scheme_version_seats_in_session(
*,
session,
source_scheme_version_id: str,
target_scheme_version_id: str,
) -> None:
result = await session.execute(
select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
for row in rows:
cloned = SchemeSeatRecord(
seat_record_id=__import__("uuid").uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
seat_id=row.seat_id,
sector_id=row.sector_id,
group_id=row.group_id,
row_label=row.row_label,
seat_number=row.seat_number,
tag=row.tag,
classes_raw=row.classes_raw,
x=row.x,
y=row.y,
cx=row.cx,
cy=row.cy,
width=row.width,
height=row.height,
)
session.add(cloned)
async def list_scheme_version_seats(scheme_version_id: str) -> list[SchemeSeatRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -141,11 +153,7 @@ async def update_scheme_version_seat_by_record_id(
*,
scheme_version_id: str,
seat_record_id: str,
seat_id: str | None,
sector_id: str | None,
group_id: str | None,
row_label: str | None,
seat_number: str | None,
**update_data,
) -> SchemeSeatRecord:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -162,11 +170,16 @@ async def update_scheme_version_seat_by_record_id(
detail="Seat record not found in current draft version",
)
row.seat_id = seat_id
row.sector_id = sector_id
row.group_id = group_id
row.row_label = row_label
row.seat_number = seat_number
if "seat_id" in update_data:
row.seat_id = update_data["seat_id"]
if "sector_id" in update_data:
row.sector_id = update_data["sector_id"]
if "group_id" in update_data:
row.group_id = update_data["group_id"]
if "row_label" in update_data:
row.row_label = update_data["row_label"]
if "seat_number" in update_data:
row.seat_number = update_data["seat_number"]
await session.commit()
await session.refresh(row)
@@ -196,11 +209,16 @@ async def bulk_update_scheme_version_seats_by_record_id(
detail=f"Seat record not found in current draft version: {item['seat_record_id']}",
)
row.seat_id = item.get("seat_id")
row.sector_id = item.get("sector_id")
row.group_id = item.get("group_id")
row.row_label = item.get("row_label")
row.seat_number = item.get("seat_number")
if "seat_id" in item:
row.seat_id = item["seat_id"]
if "sector_id" in item:
row.sector_id = item["sector_id"]
if "group_id" in item:
row.group_id = item["group_id"]
if "row_label" in item:
row.row_label = item["row_label"]
if "seat_number" in item:
row.seat_number = item["seat_number"]
updated_rows.append(row)
await session.commit()

View File

@@ -8,6 +8,49 @@ from app.models.scheme_sector import SchemeSectorRecord
from app.models.scheme_seat import SchemeSeatRecord
def _conflict(message: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"code": "sector_uniqueness_violation",
"message": message,
},
)
async def _ensure_sector_uniqueness(
*,
session,
scheme_version_id: str,
sector_id: str | None,
element_id: str | None,
exclude_sector_record_id: str | None = None,
) -> None:
if sector_id:
stmt = select(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == scheme_version_id,
SchemeSectorRecord.sector_id == sector_id,
)
if exclude_sector_record_id:
stmt = stmt.where(SchemeSectorRecord.sector_record_id != exclude_sector_record_id)
existing = (await session.execute(stmt)).scalar_one_or_none()
if existing is not None:
raise _conflict(f"Sector with sector_id='{sector_id}' already exists in current draft version")
if element_id:
stmt = select(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == scheme_version_id,
SchemeSectorRecord.element_id == element_id,
)
if exclude_sector_record_id:
stmt = stmt.where(SchemeSectorRecord.sector_record_id != exclude_sector_record_id)
existing = (await session.execute(stmt)).scalar_one_or_none()
if existing is not None:
raise _conflict(f"Sector with element_id='{element_id}' already exists in current draft version")
async def replace_scheme_version_sectors(
*,
scheme_id: str,
@@ -23,13 +66,29 @@ async def replace_scheme_version_sectors(
for row in existing_rows:
await session.delete(row)
seen_sector_ids: set[str] = set()
seen_element_ids: set[str] = set()
for item in sectors:
sector_id = item.get("sector_id")
element_id = item.get("id")
if sector_id:
if sector_id in seen_sector_ids:
raise _conflict(f"Duplicate sector_id='{sector_id}' in replacement payload")
seen_sector_ids.add(sector_id)
if element_id:
if element_id in seen_element_ids:
raise _conflict(f"Duplicate element_id='{element_id}' in replacement payload")
seen_element_ids.add(element_id)
row = SchemeSectorRecord(
sector_record_id=item["sector_record_id"] if "sector_record_id" in item and item["sector_record_id"] else uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=item.get("id"),
sector_id=item.get("sector_id"),
element_id=element_id,
sector_id=sector_id,
name=item.get("sector_id"),
classes_raw=str(item.get("classes")),
)
@@ -44,26 +103,51 @@ async def clone_scheme_version_sectors(
target_scheme_version_id: str,
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == source_scheme_version_id)
await clone_scheme_version_sectors_in_session(
session=session,
source_scheme_version_id=source_scheme_version_id,
target_scheme_version_id=target_scheme_version_id,
)
rows = list(result.scalars().all())
for row in rows:
cloned = SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
sector_id=row.sector_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
await session.commit()
async def clone_scheme_version_sectors_in_session(
*,
session,
source_scheme_version_id: str,
target_scheme_version_id: str,
) -> None:
result = await session.execute(
select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
seen_sector_ids: set[str] = set()
seen_element_ids: set[str] = set()
for row in rows:
if row.sector_id:
if row.sector_id in seen_sector_ids:
raise _conflict(f"Duplicate sector_id='{row.sector_id}' while cloning draft")
seen_sector_ids.add(row.sector_id)
if row.element_id:
if row.element_id in seen_element_ids:
raise _conflict(f"Duplicate element_id='{row.element_id}' while cloning draft")
seen_element_ids.add(row.element_id)
cloned = SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
sector_id=row.sector_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
async def list_scheme_version_sectors(scheme_version_id: str) -> list[SchemeSectorRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -78,8 +162,7 @@ async def update_scheme_version_sector_by_record_id(
*,
scheme_version_id: str,
sector_record_id: str,
sector_id: str | None,
name: str | None,
**update_data,
) -> tuple[SchemeSectorRecord, str | None]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -96,9 +179,20 @@ async def update_scheme_version_sector_by_record_id(
detail="Sector record not found in current draft version",
)
if "sector_id" in update_data:
await _ensure_sector_uniqueness(
session=session,
scheme_version_id=scheme_version_id,
sector_id=update_data["sector_id"],
element_id=row.element_id,
exclude_sector_record_id=sector_record_id,
)
old_sector_id = row.sector_id
row.sector_id = sector_id
row.name = name
if "sector_id" in update_data:
row.sector_id = update_data["sector_id"]
if "name" in update_data:
row.name = update_data["name"]
await session.commit()
await session.refresh(row)
@@ -115,6 +209,13 @@ async def create_scheme_version_sector(
classes_raw: str | None,
) -> SchemeSectorRecord:
async with AsyncSessionLocal() as session:
await _ensure_sector_uniqueness(
session=session,
scheme_version_id=scheme_version_id,
sector_id=sector_id,
element_id=element_id,
)
row = SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=scheme_id,

View File

@@ -7,6 +7,125 @@ from sqlalchemy import asc, desc, func, select
from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.repositories.scheme_groups import clone_scheme_version_groups_in_session
from app.repositories.scheme_seats import clone_scheme_version_seats_in_session
from app.repositories.scheme_sectors import clone_scheme_version_sectors_in_session
from app.services.api_errors import raise_conflict
def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
raise_conflict(
code="current_version_inconsistent",
message="Scheme current version pointer is inconsistent with scheme_versions state.",
details={
"scheme_id": scheme_id,
"current_version_number": current_version_number,
},
)
def _raise_stale_current_version(*, expected_scheme_version_id: str, actual_scheme_version_id: str) -> None:
raise_conflict(
code="stale_current_version",
message="Current scheme version changed. Reload scheme state before creating a new version.",
details={
"expected_scheme_version_id": expected_scheme_version_id,
"actual_scheme_version_id": actual_scheme_version_id,
},
)
async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
scheme_result = await session.execute(
select(SchemeRecord)
.where(SchemeRecord.scheme_id == scheme_id)
.with_for_update()
)
scheme = scheme_result.scalar_one_or_none()
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
)
return scheme
async def _get_current_scheme_version_for_update(
session,
*,
scheme_id: str,
current_version_number: int,
) -> SchemeVersionRecord:
current_result = await session.execute(
select(SchemeVersionRecord)
.where(
SchemeVersionRecord.scheme_id == scheme_id,
SchemeVersionRecord.version_number == current_version_number,
)
.with_for_update()
)
current_version = current_result.scalar_one_or_none()
if current_version is None:
_raise_current_version_inconsistent(
scheme_id=scheme_id,
current_version_number=current_version_number,
)
return current_version
async def _build_next_draft_version(
session,
*,
scheme: SchemeRecord,
source_version: SchemeVersionRecord,
) -> SchemeVersionRecord:
max_version_result = await session.execute(
select(func.coalesce(func.max(SchemeVersionRecord.version_number), 0)).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id
)
)
next_version_number = int(max_version_result.scalar_one()) + 1
new_version = SchemeVersionRecord(
scheme_version_id=uuid4().hex,
scheme_id=scheme.scheme_id,
version_number=next_version_number,
status="draft",
normalized_storage_path=source_version.normalized_storage_path,
normalized_elements_count=source_version.normalized_elements_count,
normalized_seats_count=source_version.normalized_seats_count,
normalized_groups_count=source_version.normalized_groups_count,
normalized_sectors_count=source_version.normalized_sectors_count,
display_svg_storage_path=source_version.display_svg_storage_path,
display_svg_status=source_version.display_svg_status,
display_svg_generated_at=source_version.display_svg_generated_at,
)
session.add(new_version)
await session.flush()
await clone_scheme_version_sectors_in_session(
session=session,
source_scheme_version_id=source_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_groups_in_session(
session=session,
source_scheme_version_id=source_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
await clone_scheme_version_seats_in_session(
session=session,
source_scheme_version_id=source_version.scheme_version_id,
target_scheme_version_id=new_version.scheme_version_id,
)
scheme.current_version_number = new_version.version_number
scheme.status = "draft"
scheme.published_at = None
scheme.normalized_elements_count = source_version.normalized_elements_count
scheme.normalized_seats_count = source_version.normalized_seats_count
scheme.normalized_groups_count = source_version.normalized_groups_count
scheme.normalized_sectors_count = source_version.normalized_sectors_count
return new_version
async def create_initial_scheme_version(
@@ -75,9 +194,9 @@ async def get_current_scheme_version(scheme_id: str, current_version_number: int
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Current scheme version not found",
_raise_current_version_inconsistent(
scheme_id=scheme_id,
current_version_number=current_version_number,
)
return row
@@ -113,57 +232,87 @@ async def update_scheme_version_display_artifact(
async def create_next_scheme_version_from_current(scheme_id: str) -> SchemeVersionRecord:
async with AsyncSessionLocal() as session:
scheme_result = await session.execute(
select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
)
scheme = scheme_result.scalar_one_or_none()
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
current_version = await _get_current_scheme_version_for_update(
session,
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
)
new_version = await _build_next_draft_version(
session,
scheme=scheme,
source_version=current_version,
)
current_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == scheme.current_version_number,
)
)
current_version = current_result.scalar_one_or_none()
if current_version is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Current scheme version not found",
)
next_version_number = current_version.version_number + 1
new_version = SchemeVersionRecord(
scheme_version_id=uuid4().hex,
scheme_id=scheme.scheme_id,
version_number=next_version_number,
status="draft",
normalized_storage_path=current_version.normalized_storage_path,
normalized_elements_count=current_version.normalized_elements_count,
normalized_seats_count=current_version.normalized_seats_count,
normalized_groups_count=current_version.normalized_groups_count,
normalized_sectors_count=current_version.normalized_sectors_count,
display_svg_storage_path=current_version.display_svg_storage_path,
display_svg_status=current_version.display_svg_status,
display_svg_generated_at=current_version.display_svg_generated_at,
)
session.add(new_version)
scheme.current_version_number = next_version_number
scheme.status = "draft"
scheme.published_at = None
scheme.normalized_elements_count = current_version.normalized_elements_count
scheme.normalized_seats_count = current_version.normalized_seats_count
scheme.normalized_groups_count = current_version.normalized_groups_count
scheme.normalized_sectors_count = current_version.normalized_sectors_count
await session.commit()
await session.refresh(new_version)
return new_version
async def create_next_scheme_version_from_current_checked(
*,
scheme_id: str,
expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, SchemeVersionRecord]:
async with AsyncSessionLocal() as session:
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
current_version = await _get_current_scheme_version_for_update(
session,
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
)
if (
expected_current_scheme_version_id
and expected_current_scheme_version_id != current_version.scheme_version_id
):
_raise_stale_current_version(
expected_scheme_version_id=expected_current_scheme_version_id,
actual_scheme_version_id=current_version.scheme_version_id,
)
new_version = await _build_next_draft_version(
session,
scheme=scheme,
source_version=current_version,
)
await session.refresh(current_version)
await session.refresh(new_version)
return current_version, new_version
async def ensure_draft_scheme_version_consistent(
*,
scheme_id: str,
expected_current_scheme_version_id: str | None = None,
) -> tuple[SchemeVersionRecord, bool, str | None]:
async with AsyncSessionLocal() as session:
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
current_version = await _get_current_scheme_version_for_update(
session,
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
)
if (
expected_current_scheme_version_id
and expected_current_scheme_version_id != current_version.scheme_version_id
):
_raise_stale_current_version(
expected_scheme_version_id=expected_current_scheme_version_id,
actual_scheme_version_id=current_version.scheme_version_id,
)
if scheme.status == "draft" and current_version.status == "draft":
await session.refresh(current_version)
return current_version, False, None
new_version = await _build_next_draft_version(
session,
scheme=scheme,
source_version=current_version,
)
source_scheme_version_id = current_version.scheme_version_id
await session.refresh(new_version)
return new_version, True, source_scheme_version_id

View File

@@ -6,6 +6,51 @@ from sqlalchemy import desc, func, select
from app.db.session import AsyncSessionLocal
from app.models.scheme import SchemeRecord
from app.models.scheme_version import SchemeVersionRecord
from app.services.api_errors import raise_conflict
def _raise_current_version_inconsistent(*, scheme_id: str, current_version_number: int) -> None:
raise_conflict(
code="current_version_inconsistent",
message="Scheme current version pointer is inconsistent with scheme_versions state.",
details={
"scheme_id": scheme_id,
"current_version_number": current_version_number,
},
)
async def _get_scheme_for_update(session, scheme_id: str) -> SchemeRecord:
scheme_result = await session.execute(
select(SchemeRecord)
.where(SchemeRecord.scheme_id == scheme_id)
.with_for_update()
)
scheme = scheme_result.scalar_one_or_none()
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
)
return scheme
async def _get_current_version_for_scheme(session, scheme: SchemeRecord) -> SchemeVersionRecord:
version_result = await session.execute(
select(SchemeVersionRecord)
.where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == scheme.current_version_number,
)
.with_for_update()
)
version = version_result.scalar_one_or_none()
if version is None:
_raise_current_version_inconsistent(
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
)
return version
async def create_scheme_from_upload(
@@ -37,6 +82,55 @@ async def create_scheme_from_upload(
return scheme_id
async def create_scheme_from_upload_with_initial_version(
*,
source_upload_id: str,
name: str,
normalized_storage_path: str,
normalized_elements_count: int,
normalized_seats_count: int,
normalized_groups_count: int,
normalized_sectors_count: int,
display_svg_storage_path: str | None = None,
display_svg_status: str = "pending",
display_svg_generated_at=None,
) -> tuple[str, str]:
scheme_id = uuid4().hex
scheme_version_id = uuid4().hex
async with AsyncSessionLocal() as session:
scheme = SchemeRecord(
scheme_id=scheme_id,
source_upload_id=source_upload_id,
name=name,
status="draft",
current_version_number=1,
normalized_elements_count=normalized_elements_count,
normalized_seats_count=normalized_seats_count,
normalized_groups_count=normalized_groups_count,
normalized_sectors_count=normalized_sectors_count,
)
version = SchemeVersionRecord(
scheme_version_id=scheme_version_id,
scheme_id=scheme_id,
version_number=1,
status="draft",
normalized_storage_path=normalized_storage_path,
normalized_elements_count=normalized_elements_count,
normalized_seats_count=normalized_seats_count,
normalized_groups_count=normalized_groups_count,
normalized_sectors_count=normalized_sectors_count,
display_svg_storage_path=display_svg_storage_path,
display_svg_status=display_svg_status,
display_svg_generated_at=display_svg_generated_at,
)
session.add(scheme)
session.add(version)
await session.commit()
return scheme_id, scheme_version_id
async def list_scheme_records(limit: int = 50, offset: int = 0) -> list[SchemeRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
@@ -72,127 +166,60 @@ async def get_scheme_record_by_scheme_id(scheme_id: str) -> SchemeRecord:
async def publish_scheme(scheme_id: str) -> SchemeRecord:
async with AsyncSessionLocal() as session:
scheme_result = await session.execute(
select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
)
scheme = scheme_result.scalar_one_or_none()
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
version = await _get_current_version_for_scheme(session, scheme)
scheme.status = "published"
scheme.published_at = func.now()
version.status = "published"
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
)
version_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == scheme.current_version_number,
)
)
version = version_result.scalar_one_or_none()
if version is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Current scheme version not found",
)
scheme.status = "published"
scheme.published_at = func.now()
version.status = "published"
await session.commit()
await session.refresh(scheme)
return scheme
async def unpublish_scheme(scheme_id: str) -> SchemeRecord:
async with AsyncSessionLocal() as session:
scheme_result = await session.execute(
select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
)
scheme = scheme_result.scalar_one_or_none()
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
version = await _get_current_version_for_scheme(session, scheme)
scheme.status = "draft"
scheme.published_at = None
version.status = "draft"
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
)
version_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == scheme.current_version_number,
)
)
version = version_result.scalar_one_or_none()
if version is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Current scheme version not found",
)
scheme.status = "draft"
scheme.published_at = None
version.status = "draft"
await session.commit()
await session.refresh(scheme)
return scheme
async def rollback_scheme_to_version(scheme_id: str, target_version_number: int) -> SchemeRecord:
async with AsyncSessionLocal() as session:
scheme_result = await session.execute(
select(SchemeRecord).where(SchemeRecord.scheme_id == scheme_id)
)
scheme = scheme_result.scalar_one_or_none()
async with session.begin():
scheme = await _get_scheme_for_update(session, scheme_id)
current_version = await _get_current_version_for_scheme(session, scheme)
if scheme is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Scheme not found",
target_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == target_version_number,
)
)
target_version = target_result.scalar_one_or_none()
target_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == target_version_number,
)
)
target_version = target_result.scalar_one_or_none()
if target_version is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Target scheme version not found",
)
if target_version is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Target scheme version not found",
)
current_result = await session.execute(
select(SchemeVersionRecord).where(
SchemeVersionRecord.scheme_id == scheme.scheme_id,
SchemeVersionRecord.version_number == scheme.current_version_number,
)
)
current_version = current_result.scalar_one_or_none()
if current_version is not None:
current_version.status = "draft"
target_version.status = "draft"
scheme.current_version_number = target_version.version_number
scheme.status = "draft"
scheme.published_at = None
target_version.status = "draft"
scheme.current_version_number = target_version.version_number
scheme.status = "draft"
scheme.published_at = None
scheme.normalized_elements_count = target_version.normalized_elements_count
scheme.normalized_seats_count = target_version.normalized_seats_count
scheme.normalized_groups_count = target_version.normalized_groups_count
scheme.normalized_sectors_count = target_version.normalized_sectors_count
scheme.normalized_elements_count = target_version.normalized_elements_count
scheme.normalized_seats_count = target_version.normalized_seats_count
scheme.normalized_groups_count = target_version.normalized_groups_count
scheme.normalized_sectors_count = target_version.normalized_sectors_count
await session.commit()
await session.refresh(scheme)
return scheme

View File

@@ -14,7 +14,9 @@ def resolve_role(api_key: str) -> str | None:
return None
async def require_api_key(x_api_key: str | None = Header(default=None, alias="X-API-Key")) -> str:
async def require_api_key(
x_api_key: str | None = Header(default=None, alias=settings.auth_header_name),
) -> str:
if not x_api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,

View File

@@ -7,13 +7,21 @@ from app.services.api_errors import raise_unprocessable
def _raise_uniqueness_error(message: str, detail: dict | None = None) -> None:
payload = detail or {"code": "editor_uniqueness_error", "message": message}
raise_unprocessable(**payload)
if detail:
code = detail.pop("code", "editor_uniqueness_error")
msg = detail.pop("message", message)
raise_unprocessable(code=code, message=msg, details=detail)
else:
raise_unprocessable(code="editor_uniqueness_error", message=message)
def _raise_reference_error(message: str, detail: dict | None = None) -> None:
payload = detail or {"code": "editor_reference_error", "message": message}
raise_unprocessable(**payload)
if detail:
code = detail.pop("code", "editor_reference_error")
msg = detail.pop("message", message)
raise_unprocessable(code=code, message=msg, details=detail)
else:
raise_unprocessable(code="editor_reference_error", message=message)
async def validate_single_seat_patch_uniqueness(

View File

@@ -6,9 +6,8 @@ from app.repositories.scheme_sectors import list_scheme_version_sectors
from app.services.baseline_selector import select_baseline_scheme_version
def _serialize_sector(row) -> dict:
def _sector_compare_value(row) -> dict:
return {
"sector_record_id": row.sector_record_id,
"element_id": row.element_id,
"sector_id": row.sector_id,
"name": row.name,
@@ -16,9 +15,14 @@ def _serialize_sector(row) -> dict:
}
def _serialize_group(row) -> dict:
def _sector_response_value(row) -> dict:
payload = _sector_compare_value(row)
payload["sector_record_id"] = row.sector_record_id
return payload
def _group_compare_value(row) -> dict:
return {
"group_record_id": row.group_record_id,
"element_id": row.element_id,
"group_id": row.group_id,
"name": row.name,
@@ -26,9 +30,14 @@ def _serialize_group(row) -> dict:
}
def _serialize_seat(row) -> dict:
def _group_response_value(row) -> dict:
payload = _group_compare_value(row)
payload["group_record_id"] = row.group_record_id
return payload
def _seat_compare_value(row) -> dict:
return {
"seat_record_id": row.seat_record_id,
"element_id": row.element_id,
"seat_id": row.seat_id,
"sector_id": row.sector_id,
@@ -38,19 +47,33 @@ def _serialize_seat(row) -> dict:
}
def _build_diff(before_map: dict, after_map: dict) -> list[dict]:
keys = sorted(set(before_map.keys()) | set(after_map.keys()))
def _seat_response_value(row) -> dict:
payload = _seat_compare_value(row)
payload["seat_record_id"] = row.seat_record_id
return payload
def _build_diff(
*,
before_compare_map: dict,
after_compare_map: dict,
before_payload_map: dict,
after_payload_map: dict,
) -> list[dict]:
keys = sorted(set(before_payload_map.keys()) | set(after_payload_map.keys()))
result: list[dict] = []
for key in keys:
before = before_map.get(key)
after = after_map.get(key)
before_compare = before_compare_map.get(key)
after_compare = after_compare_map.get(key)
before_payload = before_payload_map.get(key)
after_payload = after_payload_map.get(key)
if before is None and after is not None:
if before_compare is None and after_compare is not None:
status = "added"
elif before is not None and after is None:
elif before_compare is not None and after_compare is None:
status = "removed"
elif before != after:
elif before_compare != after_compare:
status = "changed"
else:
status = "unchanged"
@@ -59,13 +82,22 @@ def _build_diff(before_map: dict, after_map: dict) -> list[dict]:
{
"key": key,
"status": status,
"before": before,
"after": after,
"before": before_payload,
"after": after_payload,
}
)
return result
def _sector_key(row) -> str:
return row.sector_id if row.sector_id else (row.element_id if row.element_id else row.sector_record_id)
def _group_key(row) -> str:
return row.group_id if row.group_id else (row.element_id if row.element_id else row.group_record_id)
def _seat_key(row) -> str:
return row.seat_id if row.seat_id else (row.element_id if row.element_id else row.seat_record_id)
async def build_structure_diff(
*,
scheme_id: str,
@@ -83,32 +115,68 @@ async def build_structure_diff(
draft_seats = await list_scheme_version_seats(draft_scheme_version_id)
if baseline is None:
baseline_sector_map = {}
baseline_group_map = {}
baseline_seat_map = {}
baseline_sector_compare_map = {}
baseline_group_compare_map = {}
baseline_seat_compare_map = {}
baseline_sector_payload_map = {}
baseline_group_payload_map = {}
baseline_seat_payload_map = {}
baseline_scheme_version_id = None
else:
baseline_scheme_version_id = baseline.scheme_version_id
baseline_sector_map = {
row.sector_record_id: _serialize_sector(row)
for row in await list_scheme_version_sectors(baseline.scheme_version_id)
baseline_sectors = await list_scheme_version_sectors(baseline.scheme_version_id)
baseline_groups = await list_scheme_version_groups(baseline.scheme_version_id)
baseline_seats = await list_scheme_version_seats(baseline.scheme_version_id)
baseline_sector_compare_map = {
_sector_key(row): _sector_compare_value(row)
for row in baseline_sectors
}
baseline_group_map = {
row.group_record_id: _serialize_group(row)
for row in await list_scheme_version_groups(baseline.scheme_version_id)
baseline_sector_payload_map = {
_sector_key(row): _sector_response_value(row)
for row in baseline_sectors
}
baseline_seat_map = {
row.seat_record_id: _serialize_seat(row)
for row in await list_scheme_version_seats(baseline.scheme_version_id)
baseline_group_compare_map = {
_group_key(row): _group_compare_value(row)
for row in baseline_groups
}
baseline_group_payload_map = {
_group_key(row): _group_response_value(row)
for row in baseline_groups
}
baseline_seat_compare_map = {
_seat_key(row): _seat_compare_value(row)
for row in baseline_seats
}
baseline_seat_payload_map = {
_seat_key(row): _seat_response_value(row)
for row in baseline_seats
}
draft_sector_map = {row.sector_record_id: _serialize_sector(row) for row in draft_sectors}
draft_group_map = {row.group_record_id: _serialize_group(row) for row in draft_groups}
draft_seat_map = {row.seat_record_id: _serialize_seat(row) for row in draft_seats}
draft_sector_compare_map = {_sector_key(row): _sector_compare_value(row) for row in draft_sectors}
draft_sector_payload_map = {_sector_key(row): _sector_response_value(row) for row in draft_sectors}
draft_group_compare_map = {_group_key(row): _group_compare_value(row) for row in draft_groups}
draft_group_payload_map = {_group_key(row): _group_response_value(row) for row in draft_groups}
draft_seat_compare_map = {_seat_key(row): _seat_compare_value(row) for row in draft_seats}
draft_seat_payload_map = {_seat_key(row): _seat_response_value(row) for row in draft_seats}
sector_diff = _build_diff(baseline_sector_map, draft_sector_map)
group_diff = _build_diff(baseline_group_map, draft_group_map)
seat_diff = _build_diff(baseline_seat_map, draft_seat_map)
sector_diff = _build_diff(
before_compare_map=baseline_sector_compare_map,
after_compare_map=draft_sector_compare_map,
before_payload_map=baseline_sector_payload_map,
after_payload_map=draft_sector_payload_map,
)
group_diff = _build_diff(
before_compare_map=baseline_group_compare_map,
after_compare_map=draft_group_compare_map,
before_payload_map=baseline_group_payload_map,
after_payload_map=draft_group_payload_map,
)
seat_diff = _build_diff(
before_compare_map=baseline_seat_compare_map,
after_compare_map=draft_seat_compare_map,
before_payload_map=baseline_seat_payload_map,
after_payload_map=draft_seat_payload_map,
)
return {
"baseline_scheme_version_id": baseline_scheme_version_id,