feat: add optimistic concurrency guards for draft editor, pricing and publish flows

add optimistic concurrency guards via expected scheme version id

protect draft editor, pricing snapshot, remap and publish flows from stale mutations
protect version creation from stale current version state

keep backward compatibility with optional query guards

verify 409 conflict behavior for stale clients and 200 for valid flows
This commit is contained in:
greebo
2026-03-19 18:58:03 +03:00
parent 76710372c4
commit c7c9184a71
8 changed files with 410 additions and 70 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from app.core.config import settings
from app.repositories.audit import create_audit_event
@@ -152,9 +152,13 @@ async def get_draft_compare_preview(
async def create_draft_sector(
scheme_id: str,
payload: CreateSectorRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
row = await create_scheme_version_sector(
scheme_id=scheme.scheme_id,
@@ -191,9 +195,13 @@ async def create_draft_sector(
async def create_draft_group(
scheme_id: str,
payload: CreateGroupRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
row = await create_scheme_version_group(
scheme_id=scheme.scheme_id,
@@ -230,9 +238,13 @@ async def create_draft_group(
async def delete_draft_sector(
scheme_id: str,
sector_record_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await delete_scheme_version_sector_by_record_id(
scheme_version_id=version.scheme_version_id,
@@ -259,9 +271,13 @@ async def delete_draft_sector(
async def delete_draft_group(
scheme_id: str,
group_record_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await delete_scheme_version_group_by_record_id(
scheme_version_id=version.scheme_version_id,
@@ -289,9 +305,13 @@ async def patch_draft_seat(
scheme_id: str,
seat_record_id: str,
payload: SeatPatchRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await validate_single_seat_patch_uniqueness(
scheme_version_id=version.scheme_version_id,
@@ -340,9 +360,13 @@ async def patch_draft_seat(
async def bulk_patch_draft_seats(
scheme_id: str,
payload: BulkSeatPatchRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
items = [item.model_dump() for item in payload.items]
await validate_bulk_seat_patch_uniqueness(
@@ -389,9 +413,13 @@ async def patch_draft_sector(
scheme_id: str,
sector_record_id: str,
payload: SectorPatchRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await validate_sector_patch_uniqueness(
scheme_version_id=version.scheme_version_id,
@@ -421,7 +449,8 @@ async def patch_draft_sector(
object_ref=sector_record_id,
details={
"scheme_version_id": version.scheme_version_id,
"sector_id": payload.sector_id,
"old_sector_id": old_sector_id,
"new_sector_id": payload.sector_id,
"name": payload.name,
"cascaded_seats_count": cascaded_count,
"repair_result": repair_result,
@@ -442,9 +471,13 @@ async def patch_draft_group(
scheme_id: str,
group_record_id: str,
payload: GroupPatchRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await validate_group_patch_uniqueness(
scheme_version_id=version.scheme_version_id,
@@ -474,7 +507,8 @@ async def patch_draft_group(
object_ref=group_record_id,
details={
"scheme_version_id": version.scheme_version_id,
"group_id": payload.group_id,
"old_group_id": old_group_id,
"new_group_id": payload.group_id,
"name": payload.name,
"cascaded_seats_count": cascaded_count,
"repair_result": repair_result,
@@ -493,9 +527,14 @@ async def patch_draft_group(
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/repair-references", response_model=RepairReferencesResponse)
async def repair_draft_references(
scheme_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
result = await repair_structure_references(
scheme_version_id=version.scheme_version_id,
)
@@ -511,7 +550,7 @@ async def repair_draft_references(
return RepairReferencesResponse(
scheme_id=scheme.scheme_id,
scheme_version_id=version.scheme_version_id,
repaired_sector_refs_count=result["repaired_sector_refs_count"],
repaired_group_refs_count=result["repaired_group_refs_count"],
details=result["details"],
repaired_sector_refs_count=result.get("repaired_sector_refs_count", 0),
repaired_group_refs_count=result.get("repaired_group_refs_count", 0),
details=result,
)

View File

@@ -1,6 +1,6 @@
from decimal import Decimal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.config import settings
from app.repositories.audit import create_audit_event
@@ -16,7 +16,6 @@ from app.repositories.pricing import (
)
from app.repositories.scheme_version_pricing import replace_scheme_version_pricing_snapshot
from app.repositories.schemes import get_scheme_record_by_scheme_id
from app.repositories.scheme_versions import get_current_scheme_version
from app.schemas.pricing import (
PriceRuleCreateRequest,
PriceRuleItem,
@@ -27,18 +26,20 @@ from app.schemas.pricing import (
PricingCategoryUpdateRequest,
)
from app.security.auth import require_api_key
from app.services.draft_guard import get_current_draft_context
router = APIRouter()
async def _refresh_current_draft_snapshot_if_possible(scheme_id: str) -> dict | None:
scheme = await get_scheme_record_by_scheme_id(scheme_id)
version = await get_current_scheme_version(
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
async def _refresh_current_draft_snapshot_if_possible(
*,
scheme_id: str,
expected_scheme_version_id: str | None = None,
) -> dict | None:
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
if scheme.status != "draft" or version.status != "draft":
return None
return await replace_scheme_version_pricing_snapshot(
scheme_id=scheme.scheme_id,
@@ -82,26 +83,41 @@ async def get_pricing_bundle(scheme_id: str, role: str = Depends(require_api_key
async def create_pricing_category_endpoint(
scheme_id: str,
payload: PricingCategoryCreateRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
pricing_category_id = await create_pricing_category(
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
pricing_category_id = await create_pricing_category(
scheme_id=scheme.scheme_id,
name=payload.name,
code=payload.code,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.category.created",
object_type="pricing_category",
object_ref=pricing_category_id,
details={"name": payload.name, "code": payload.code, "snapshot": snapshot},
details={
"name": payload.name,
"code": payload.code,
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
return {
"pricing_category_id": pricing_category_id,
"scheme_id": scheme_id,
"scheme_id": scheme.scheme_id,
"scheme_version_id": version.scheme_version_id,
"name": payload.name,
"code": payload.code,
}
@@ -112,27 +128,42 @@ async def update_pricing_category_endpoint(
scheme_id: str,
pricing_category_id: str,
payload: PricingCategoryUpdateRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
row = await update_pricing_category(
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
row = await update_pricing_category(
scheme_id=scheme.scheme_id,
pricing_category_id=pricing_category_id,
name=payload.name,
code=payload.code,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.category.updated",
object_type="pricing_category",
object_ref=pricing_category_id,
details={"name": payload.name, "code": payload.code, "snapshot": snapshot},
details={
"name": payload.name,
"code": payload.code,
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
return {
"pricing_category_id": row.pricing_category_id,
"scheme_id": row.scheme_id,
"scheme_version_id": version.scheme_version_id,
"name": row.name,
"code": row.code,
}
@@ -142,31 +173,53 @@ async def update_pricing_category_endpoint(
async def delete_pricing_category_endpoint(
scheme_id: str,
pricing_category_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
await delete_pricing_category(
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await delete_pricing_category(
scheme_id=scheme.scheme_id,
pricing_category_id=pricing_category_id,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.category.deleted",
object_type="pricing_category",
object_ref=pricing_category_id,
details={"snapshot": snapshot},
details={
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
return {"deleted": True, "pricing_category_id": pricing_category_id}
return {
"deleted": True,
"pricing_category_id": pricing_category_id,
"scheme_version_id": version.scheme_version_id,
}
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/pricing/rules")
async def create_price_rule_endpoint(
scheme_id: str,
payload: PriceRuleCreateRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
try:
amount = Decimal(payload.amount)
except Exception:
@@ -176,17 +229,20 @@ async def create_price_rule_endpoint(
)
price_rule_id = await create_price_rule(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
pricing_category_id=payload.pricing_category_id,
target_type=payload.target_type,
target_ref=payload.target_ref,
amount=amount,
currency=payload.currency,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.rule.created",
object_type="price_rule",
object_ref=price_rule_id,
@@ -196,13 +252,15 @@ async def create_price_rule_endpoint(
"target_ref": payload.target_ref,
"amount": payload.amount,
"currency": payload.currency,
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
return {
"price_rule_id": price_rule_id,
"scheme_id": scheme_id,
"scheme_id": scheme.scheme_id,
"scheme_version_id": version.scheme_version_id,
"pricing_category_id": payload.pricing_category_id,
"target_type": payload.target_type,
"target_ref": payload.target_ref,
@@ -216,8 +274,14 @@ async def update_price_rule_endpoint(
scheme_id: str,
price_rule_id: str,
payload: PriceRuleUpdateRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
try:
amount = Decimal(payload.amount)
except Exception:
@@ -227,7 +291,7 @@ async def update_price_rule_endpoint(
)
row = await update_price_rule(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
price_rule_id=price_rule_id,
pricing_category_id=payload.pricing_category_id,
target_type=payload.target_type,
@@ -235,10 +299,13 @@ async def update_price_rule_endpoint(
amount=amount,
currency=payload.currency,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.rule.updated",
object_type="price_rule",
object_ref=price_rule_id,
@@ -248,6 +315,7 @@ async def update_price_rule_endpoint(
"target_ref": payload.target_ref,
"amount": payload.amount,
"currency": payload.currency,
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
@@ -255,6 +323,7 @@ async def update_price_rule_endpoint(
return {
"price_rule_id": row.price_rule_id,
"scheme_id": row.scheme_id,
"scheme_version_id": version.scheme_version_id,
"pricing_category_id": row.pricing_category_id,
"target_type": row.target_type,
"target_ref": row.target_ref,
@@ -267,20 +336,36 @@ async def update_price_rule_endpoint(
async def delete_price_rule_endpoint(
scheme_id: str,
price_rule_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
await delete_price_rule(
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
await delete_price_rule(
scheme_id=scheme.scheme_id,
price_rule_id=price_rule_id,
)
snapshot = await _refresh_current_draft_snapshot_if_possible(scheme_id)
snapshot = await _refresh_current_draft_snapshot_if_possible(
scheme_id=scheme.scheme_id,
expected_scheme_version_id=version.scheme_version_id,
)
await create_audit_event(
scheme_id=scheme_id,
scheme_id=scheme.scheme_id,
event_type="pricing.rule.deleted",
object_type="price_rule",
object_ref=price_rule_id,
details={"snapshot": snapshot},
details={
"scheme_version_id": version.scheme_version_id,
"snapshot": snapshot,
},
)
return {"deleted": True, "price_rule_id": price_rule_id}
return {
"deleted": True,
"price_rule_id": price_rule_id,
"scheme_version_id": version.scheme_version_id,
}

View File

@@ -22,9 +22,13 @@ router = APIRouter()
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/draft/pricing/snapshot")
async def create_draft_pricing_snapshot(
scheme_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
result = await replace_scheme_version_pricing_snapshot(
scheme_id=scheme.scheme_id,
scheme_version_id=version.scheme_version_id,
@@ -50,9 +54,13 @@ async def get_publish_preview(
scheme_id: str,
baseline_scheme_version_id: str | None = Query(default=None),
refresh: bool = Query(default=False),
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
bundle = await get_or_build_publish_preview_bundle(
scheme_id=scheme.scheme_id,
scheme_version_id=version.scheme_version_id,
@@ -74,9 +82,13 @@ async def get_publish_preview(
async def preview_draft_remap(
scheme_id: str,
payload: RemapPreviewRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
items = await preview_remap(
scheme_version_id=version.scheme_version_id,
seat_record_ids=payload.seat_record_ids,
@@ -98,9 +110,13 @@ async def preview_draft_remap(
async def apply_draft_remap(
scheme_id: str,
payload: RemapApplyRequest,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
scheme, version = await get_current_draft_context(scheme_id)
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
items = await apply_remap(
scheme_version_id=version.scheme_version_id,
seat_record_ids=payload.seat_record_ids,

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.config import settings
from app.repositories.audit import create_audit_event
@@ -137,6 +137,7 @@ async def get_scheme_versions(
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/versions", response_model=SchemeVersionCreateResponse)
async def create_next_scheme_version_endpoint(
scheme_id: str,
expected_current_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
current_scheme = await get_scheme_record_by_scheme_id(scheme_id)
@@ -145,6 +146,20 @@ async def create_next_scheme_version_endpoint(
current_version_number=current_scheme.current_version_number,
)
if (
expected_current_scheme_version_id is not None
and expected_current_scheme_version_id != current_version.scheme_version_id
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"code": "stale_current_version",
"message": "Current scheme version changed. Reload scheme state before creating a new version.",
"expected_scheme_version_id": expected_current_scheme_version_id,
"actual_scheme_version_id": current_version.scheme_version_id,
},
)
new_version = await create_next_scheme_version_from_current(scheme_id)
await clone_scheme_version_sectors(
@@ -200,8 +215,15 @@ async def get_publish_validation(scheme_id: str, role: str = Depends(require_api
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/publish")
async def publish_scheme_endpoint(scheme_id: str, role: str = Depends(require_api_key)):
return await publish_current_draft_scheme(scheme_id=scheme_id)
async def publish_scheme_endpoint(
scheme_id: str,
expected_scheme_version_id: str | None = Query(default=None),
role: str = Depends(require_api_key),
):
return await publish_current_draft_scheme(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
@router.post(f"{settings.api_v1_prefix}/schemes/{{scheme_id}}/unpublish", response_model=SchemePublishResponse)

View File

@@ -4,7 +4,30 @@ from app.repositories.scheme_versions import get_current_scheme_version
from app.repositories.schemes import get_scheme_record_by_scheme_id
async def get_current_draft_context(scheme_id: str):
def ensure_expected_scheme_version_id(
*,
actual_scheme_version_id: str,
expected_scheme_version_id: str | None,
) -> None:
if expected_scheme_version_id is None:
return
if expected_scheme_version_id != actual_scheme_version_id:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"code": "stale_draft_version",
"message": "Draft scheme version is stale. Reload current draft state before applying mutation.",
"expected_scheme_version_id": expected_scheme_version_id,
"actual_scheme_version_id": actual_scheme_version_id,
},
)
async def get_current_draft_context(
scheme_id: str,
expected_scheme_version_id: str | None = None,
):
scheme = await get_scheme_record_by_scheme_id(scheme_id)
version = await get_current_scheme_version(
scheme_id=scheme.scheme_id,
@@ -17,4 +40,9 @@ async def get_current_draft_context(scheme_id: str):
detail="Current scheme version is not editable because it is not in draft state",
)
ensure_expected_scheme_version_id(
actual_scheme_version_id=version.scheme_version_id,
expected_scheme_version_id=expected_scheme_version_id,
)
return scheme, version

View File

@@ -2,27 +2,21 @@ from fastapi import HTTPException, status
from app.repositories.audit import create_audit_event
from app.repositories.scheme_version_pricing import replace_scheme_version_pricing_snapshot
from app.repositories.scheme_versions import get_current_scheme_version
from app.repositories.schemes import get_scheme_record_by_scheme_id, publish_scheme
from app.services.draft_guard import get_current_draft_context
from app.repositories.schemes import publish_scheme
from app.services.scheme_validation import build_scheme_validation_report
async def publish_current_draft_scheme(
*,
scheme_id: str,
expected_scheme_version_id: str | None = None,
) -> dict:
scheme = await get_scheme_record_by_scheme_id(scheme_id)
version = await get_current_scheme_version(
scheme_id=scheme.scheme_id,
current_version_number=scheme.current_version_number,
scheme, version = await get_current_draft_context(
scheme_id=scheme_id,
expected_scheme_version_id=expected_scheme_version_id,
)
if scheme.status != "draft" or version.status != "draft":
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Current scheme version is not publishable because it is not in draft state",
)
validation = await build_scheme_validation_report(
scheme_id=scheme.scheme_id,
scheme_version_id=version.scheme_version_id,