Implement display artifacts, pricing integrity, draft base and publish preview bundle

This commit is contained in:
greebo
2026-03-19 17:58:17 +03:00
parent 85fb2f4bb9
commit c91c5abf15
35 changed files with 3283 additions and 302 deletions

View File

@@ -9,6 +9,34 @@ from app.models.price_rule import PriceRuleRecord
from app.models.pricing_category import PricingCategoryRecord
async def _ensure_unique_price_rule_target(
*,
session,
scheme_id: str,
target_type: str,
target_ref: str,
exclude_price_rule_id: str | None = None,
) -> None:
stmt = select(PriceRuleRecord).where(
PriceRuleRecord.scheme_id == scheme_id,
PriceRuleRecord.target_type == target_type,
PriceRuleRecord.target_ref == target_ref,
)
result = await session.execute(stmt)
row = result.scalar_one_or_none()
if row is None:
return
if exclude_price_rule_id and row.price_rule_id == exclude_price_rule_id:
return
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Pricing rule already exists for this target",
)
async def create_pricing_category(
*,
scheme_id: str,
@@ -96,6 +124,13 @@ async def create_price_rule(
price_rule_id = uuid4().hex
async with AsyncSessionLocal() as session:
await _ensure_unique_price_rule_target(
session=session,
scheme_id=scheme_id,
target_type=target_type,
target_ref=target_ref,
)
row = PriceRuleRecord(
price_rule_id=price_rule_id,
scheme_id=scheme_id,
@@ -136,6 +171,14 @@ async def update_price_rule(
detail="Price rule not found",
)
await _ensure_unique_price_rule_target(
session=session,
scheme_id=scheme_id,
target_type=target_type,
target_ref=target_ref,
exclude_price_rule_id=price_rule_id,
)
row.pricing_category_id = pricing_category_id
row.target_type = target_type
row.target_ref = target_ref

View File

@@ -0,0 +1,101 @@
from uuid import uuid4
from sqlalchemy import asc, desc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_artifact import SchemeArtifactRecord
async def create_scheme_artifact(
*,
scheme_id: str,
scheme_version_id: str,
artifact_type: str,
artifact_variant: str,
storage_path: str,
status: str = "ready",
meta_json: dict | None = None,
) -> SchemeArtifactRecord:
async with AsyncSessionLocal() as session:
row = SchemeArtifactRecord(
artifact_id=uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
artifact_type=artifact_type,
artifact_variant=artifact_variant,
storage_path=storage_path,
status=status,
meta_json=meta_json,
)
session.add(row)
await session.commit()
await session.refresh(row)
return row
async def list_scheme_artifacts(
*,
scheme_version_id: str,
artifact_type: str | None = None,
artifact_variant: str | None = None,
) -> list[SchemeArtifactRecord]:
async with AsyncSessionLocal() as session:
stmt = select(SchemeArtifactRecord).where(
SchemeArtifactRecord.scheme_version_id == scheme_version_id
)
if artifact_type is not None:
stmt = stmt.where(SchemeArtifactRecord.artifact_type == artifact_type)
if artifact_variant is not None:
stmt = stmt.where(SchemeArtifactRecord.artifact_variant == artifact_variant)
stmt = stmt.order_by(
asc(SchemeArtifactRecord.created_at),
asc(SchemeArtifactRecord.id),
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def artifact_exists(
*,
scheme_version_id: str,
artifact_type: str,
artifact_variant: str,
) -> bool:
async with AsyncSessionLocal() as session:
stmt = (
select(SchemeArtifactRecord.id)
.where(SchemeArtifactRecord.scheme_version_id == scheme_version_id)
.where(SchemeArtifactRecord.artifact_type == artifact_type)
.where(SchemeArtifactRecord.artifact_variant == artifact_variant)
.limit(1)
)
result = await session.execute(stmt)
return result.scalar_one_or_none() is not None
async def get_latest_scheme_artifact(
*,
scheme_version_id: str,
artifact_type: str,
artifact_variant: str | None = None,
) -> SchemeArtifactRecord | None:
async with AsyncSessionLocal() as session:
stmt = select(SchemeArtifactRecord).where(
SchemeArtifactRecord.scheme_version_id == scheme_version_id,
SchemeArtifactRecord.artifact_type == artifact_type,
)
if artifact_variant is not None:
stmt = stmt.where(SchemeArtifactRecord.artifact_variant == artifact_variant)
stmt = stmt.order_by(
desc(SchemeArtifactRecord.created_at),
desc(SchemeArtifactRecord.id),
).limit(1)
result = await session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -1,10 +1,11 @@
import json
from uuid import uuid4
from sqlalchemy import asc, delete, select
from fastapi import HTTPException, status
from sqlalchemy import asc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_group import SchemeGroupRecord
from app.models.scheme_seat import SchemeSeatRecord
async def replace_scheme_version_groups(
@@ -14,37 +15,29 @@ async def replace_scheme_version_groups(
groups: list[dict],
) -> None:
async with AsyncSessionLocal() as session:
await session.execute(
delete(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == scheme_version_id
)
existing_result = await session.execute(
select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == scheme_version_id)
)
existing_rows = list(existing_result.scalars().all())
for row in existing_rows:
await session.delete(row)
for item in groups:
row = SchemeGroupRecord(
group_record_id=uuid4().hex,
group_record_id=item["group_record_id"] if "group_record_id" in item and item["group_record_id"] else uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=item.get("id"),
group_id=item.get("group_id"),
name=item.get("group_id") or item.get("id") or "unnamed-group",
classes_raw=json.dumps(item.get("classes", []), ensure_ascii=False),
name=item.get("group_id"),
classes_raw=str(item.get("classes")),
)
session.add(row)
await session.commit()
async def list_scheme_version_groups(scheme_version_id: str) -> list[SchemeGroupRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeGroupRecord)
.where(SchemeGroupRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeGroupRecord.id))
)
return list(result.scalars().all())
async def clone_scheme_version_groups(
*,
source_scheme_version_id: str,
@@ -52,23 +45,124 @@ async def clone_scheme_version_groups(
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == source_scheme_version_id
)
select(SchemeGroupRecord).where(SchemeGroupRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
for row in rows:
session.add(
SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
group_id=row.group_id,
name=row.name,
classes_raw=row.classes_raw,
)
cloned = SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
group_id=row.group_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
await session.commit()
async def list_scheme_version_groups(scheme_version_id: str) -> list[SchemeGroupRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeGroupRecord)
.where(SchemeGroupRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeGroupRecord.created_at), asc(SchemeGroupRecord.id))
)
return list(result.scalars().all())
async def update_scheme_version_group_by_record_id(
*,
scheme_version_id: str,
group_record_id: str,
group_id: str | None,
name: str | None,
) -> tuple[SchemeGroupRecord, str | None]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == scheme_version_id,
SchemeGroupRecord.group_record_id == group_record_id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Group record not found in current draft version",
)
old_group_id = row.group_id
row.group_id = group_id
row.name = name
await session.commit()
await session.refresh(row)
return row, old_group_id
async def create_scheme_version_group(
*,
scheme_id: str,
scheme_version_id: str,
element_id: str | None,
group_id: str,
name: str | None,
classes_raw: str | None,
) -> SchemeGroupRecord:
async with AsyncSessionLocal() as session:
row = SchemeGroupRecord(
group_record_id=uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=element_id,
group_id=group_id,
name=name,
classes_raw=classes_raw,
)
session.add(row)
await session.commit()
await session.refresh(row)
return row
async def delete_scheme_version_group_by_record_id(
*,
scheme_version_id: str,
group_record_id: str,
) -> None:
async with AsyncSessionLocal() as session:
group_result = await session.execute(
select(SchemeGroupRecord).where(
SchemeGroupRecord.scheme_version_id == scheme_version_id,
SchemeGroupRecord.group_record_id == group_record_id,
)
)
group = group_result.scalar_one_or_none()
if group is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Group record not found in current draft version",
)
if group.group_id:
seats_result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.group_id == group.group_id,
)
)
seats = list(seats_result.scalars().all())
if seats:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete group while seats still reference it",
)
await session.delete(group)
await session.commit()

View File

@@ -1,8 +1,5 @@
import json
from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import asc, delete, select
from sqlalchemy import asc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_seat import SchemeSeatRecord
@@ -15,15 +12,17 @@ async def replace_scheme_version_seats(
seats: list[dict],
) -> None:
async with AsyncSessionLocal() as session:
await session.execute(
delete(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id
)
existing_result = await session.execute(
select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
)
existing_rows = list(existing_result.scalars().all())
for row in existing_rows:
await session.delete(row)
for item in seats:
row = SchemeSeatRecord(
seat_record_id=uuid4().hex,
seat_record_id=item["seat_record_id"] if "seat_record_id" in item and item["seat_record_id"] else __import__("uuid").uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=item.get("id"),
@@ -33,7 +32,7 @@ async def replace_scheme_version_seats(
row_label=item.get("row"),
seat_number=item.get("seat_number"),
tag=item.get("tag"),
classes_raw=json.dumps(item.get("classes", []), ensure_ascii=False),
classes_raw=str(item.get("classes")),
x=item.get("x"),
y=item.get("y"),
cx=item.get("cx"),
@@ -46,12 +45,48 @@ async def replace_scheme_version_seats(
await session.commit()
async def clone_scheme_version_seats(
*,
source_scheme_version_id: str,
target_scheme_version_id: str,
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(SchemeSeatRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
for row in rows:
cloned = SchemeSeatRecord(
seat_record_id=__import__("uuid").uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
seat_id=row.seat_id,
sector_id=row.sector_id,
group_id=row.group_id,
row_label=row.row_label,
seat_number=row.seat_number,
tag=row.tag,
classes_raw=row.classes_raw,
x=row.x,
y=row.y,
cx=row.cx,
cy=row.cy,
width=row.width,
height=row.height,
)
session.add(cloned)
await session.commit()
async def list_scheme_version_seats(scheme_version_id: str) -> list[SchemeSeatRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord)
.where(SchemeSeatRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeSeatRecord.id))
.order_by(asc(SchemeSeatRecord.created_at), asc(SchemeSeatRecord.id))
)
return list(result.scalars().all())
@@ -79,40 +114,199 @@ async def get_scheme_version_seat_by_seat_id(
return row
async def clone_scheme_version_seats(
async def update_scheme_version_seat_by_record_id(
*,
source_scheme_version_id: str,
target_scheme_version_id: str,
) -> None:
scheme_version_id: str,
seat_record_id: str,
seat_id: str | None,
sector_id: str | None,
group_id: str | None,
row_label: str | None,
seat_number: str | None,
) -> SchemeSeatRecord:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == source_scheme_version_id
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.seat_record_id == seat_record_id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Seat record not found in current draft version",
)
row.seat_id = seat_id
row.sector_id = sector_id
row.group_id = group_id
row.row_label = row_label
row.seat_number = seat_number
await session.commit()
await session.refresh(row)
return row
async def bulk_update_scheme_version_seats_by_record_id(
*,
scheme_version_id: str,
items: list[dict],
) -> list[SchemeSeatRecord]:
updated_rows: list[SchemeSeatRecord] = []
async with AsyncSessionLocal() as session:
for item in items:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.seat_record_id == item["seat_record_id"],
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Seat record not found in current draft version: {item['seat_record_id']}",
)
row.seat_id = item.get("seat_id")
row.sector_id = item.get("sector_id")
row.group_id = item.get("group_id")
row.row_label = item.get("row_label")
row.seat_number = item.get("seat_number")
updated_rows.append(row)
await session.commit()
for row in updated_rows:
await session.refresh(row)
return updated_rows
async def bulk_remap_scheme_version_seats(
*,
scheme_version_id: str,
items: list[dict],
) -> None:
async with AsyncSessionLocal() as session:
for item in items:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.seat_record_id == item["seat_record_id"],
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Seat record not found in current draft version: {item['seat_record_id']}",
)
row.sector_id = item["after_sector_id"]
row.group_id = item["after_group_id"]
await session.commit()
async def cascade_update_seat_sector_reference(
*,
scheme_version_id: str,
old_sector_id: str | None,
new_sector_id: str | None,
) -> int:
if not old_sector_id or old_sector_id == new_sector_id:
return 0
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.sector_id == old_sector_id,
)
)
rows = list(result.scalars().all())
for row in rows:
session.add(
SchemeSeatRecord(
seat_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
seat_id=row.seat_id,
sector_id=row.sector_id,
group_id=row.group_id,
row_label=row.row_label,
seat_number=row.seat_number,
tag=row.tag,
classes_raw=row.classes_raw,
x=row.x,
y=row.y,
cx=row.cx,
cy=row.cy,
width=row.width,
height=row.height,
)
)
row.sector_id = new_sector_id
await session.commit()
return len(rows)
async def cascade_update_seat_group_reference(
*,
scheme_version_id: str,
old_group_id: str | None,
new_group_id: str | None,
) -> int:
if not old_group_id or old_group_id == new_group_id:
return 0
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.group_id == old_group_id,
)
)
rows = list(result.scalars().all())
for row in rows:
row.group_id = new_group_id
await session.commit()
return len(rows)
async def repair_orphan_sector_refs(
*,
scheme_version_id: str,
new_sector_id: str,
orphan_values: list[str],
) -> int:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id
)
)
rows = list(result.scalars().all())
changed = 0
for row in rows:
if row.sector_id in orphan_values:
row.sector_id = new_sector_id
changed += 1
await session.commit()
return changed
async def repair_orphan_group_refs(
*,
scheme_version_id: str,
new_group_id: str,
orphan_values: list[str],
) -> int:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id
)
)
rows = list(result.scalars().all())
changed = 0
for row in rows:
if row.group_id in orphan_values:
row.group_id = new_group_id
changed += 1
await session.commit()
return changed

View File

@@ -1,10 +1,11 @@
import json
from uuid import uuid4
from sqlalchemy import asc, delete, select
from fastapi import HTTPException, status
from sqlalchemy import asc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_sector import SchemeSectorRecord
from app.models.scheme_seat import SchemeSeatRecord
async def replace_scheme_version_sectors(
@@ -14,37 +15,29 @@ async def replace_scheme_version_sectors(
sectors: list[dict],
) -> None:
async with AsyncSessionLocal() as session:
await session.execute(
delete(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == scheme_version_id
)
existing_result = await session.execute(
select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
)
existing_rows = list(existing_result.scalars().all())
for row in existing_rows:
await session.delete(row)
for item in sectors:
row = SchemeSectorRecord(
sector_record_id=uuid4().hex,
sector_record_id=item["sector_record_id"] if "sector_record_id" in item and item["sector_record_id"] else uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=item.get("id"),
sector_id=item.get("sector_id"),
name=item.get("sector_id") or item.get("id") or "unnamed-sector",
classes_raw=json.dumps(item.get("classes", []), ensure_ascii=False),
name=item.get("sector_id"),
classes_raw=str(item.get("classes")),
)
session.add(row)
await session.commit()
async def list_scheme_version_sectors(scheme_version_id: str) -> list[SchemeSectorRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSectorRecord)
.where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeSectorRecord.id))
)
return list(result.scalars().all())
async def clone_scheme_version_sectors(
*,
source_scheme_version_id: str,
@@ -52,23 +45,124 @@ async def clone_scheme_version_sectors(
) -> None:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == source_scheme_version_id
)
select(SchemeSectorRecord).where(SchemeSectorRecord.scheme_version_id == source_scheme_version_id)
)
rows = list(result.scalars().all())
for row in rows:
session.add(
SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
sector_id=row.sector_id,
name=row.name,
classes_raw=row.classes_raw,
)
cloned = SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=row.scheme_id,
scheme_version_id=target_scheme_version_id,
element_id=row.element_id,
sector_id=row.sector_id,
name=row.name,
classes_raw=row.classes_raw,
)
session.add(cloned)
await session.commit()
async def list_scheme_version_sectors(scheme_version_id: str) -> list[SchemeSectorRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSectorRecord)
.where(SchemeSectorRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeSectorRecord.created_at), asc(SchemeSectorRecord.id))
)
return list(result.scalars().all())
async def update_scheme_version_sector_by_record_id(
*,
scheme_version_id: str,
sector_record_id: str,
sector_id: str | None,
name: str | None,
) -> tuple[SchemeSectorRecord, str | None]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == scheme_version_id,
SchemeSectorRecord.sector_record_id == sector_record_id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Sector record not found in current draft version",
)
old_sector_id = row.sector_id
row.sector_id = sector_id
row.name = name
await session.commit()
await session.refresh(row)
return row, old_sector_id
async def create_scheme_version_sector(
*,
scheme_id: str,
scheme_version_id: str,
element_id: str | None,
sector_id: str,
name: str | None,
classes_raw: str | None,
) -> SchemeSectorRecord:
async with AsyncSessionLocal() as session:
row = SchemeSectorRecord(
sector_record_id=uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
element_id=element_id,
sector_id=sector_id,
name=name,
classes_raw=classes_raw,
)
session.add(row)
await session.commit()
await session.refresh(row)
return row
async def delete_scheme_version_sector_by_record_id(
*,
scheme_version_id: str,
sector_record_id: str,
) -> None:
async with AsyncSessionLocal() as session:
sector_result = await session.execute(
select(SchemeSectorRecord).where(
SchemeSectorRecord.scheme_version_id == scheme_version_id,
SchemeSectorRecord.sector_record_id == sector_record_id,
)
)
sector = sector_result.scalar_one_or_none()
if sector is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Sector record not found in current draft version",
)
if sector.sector_id:
seats_result = await session.execute(
select(SchemeSeatRecord).where(
SchemeSeatRecord.scheme_version_id == scheme_version_id,
SchemeSeatRecord.sector_id == sector.sector_id,
)
)
seats = list(seats_result.scalars().all())
if seats:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Cannot delete sector while seats still reference it",
)
await session.delete(sector)
await session.commit()

View File

@@ -0,0 +1,144 @@
from uuid import uuid4
from fastapi import HTTPException, status
from sqlalchemy import asc, desc, select
from app.db.session import AsyncSessionLocal
from app.models.scheme_version_pricing import (
SchemeVersionPriceRuleRecord,
SchemeVersionPricingCategoryRecord,
)
from app.repositories.pricing import list_price_rules, list_pricing_categories
async def replace_scheme_version_pricing_snapshot(
*,
scheme_id: str,
scheme_version_id: str,
) -> dict:
categories = await list_pricing_categories(scheme_id)
rules = await list_price_rules(scheme_id)
async with AsyncSessionLocal() as session:
old_categories = await session.execute(
select(SchemeVersionPricingCategoryRecord).where(
SchemeVersionPricingCategoryRecord.scheme_version_id == scheme_version_id
)
)
for row in list(old_categories.scalars().all()):
await session.delete(row)
old_rules = await session.execute(
select(SchemeVersionPriceRuleRecord).where(
SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id
)
)
for row in list(old_rules.scalars().all()):
await session.delete(row)
mapping: dict[str, str] = {}
for category in categories:
snapshot_category_id = uuid4().hex
mapping[category.pricing_category_id] = snapshot_category_id
session.add(
SchemeVersionPricingCategoryRecord(
snapshot_category_id=snapshot_category_id,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
source_pricing_category_id=category.pricing_category_id,
name=category.name,
code=category.code,
)
)
for rule in rules:
session.add(
SchemeVersionPriceRuleRecord(
snapshot_price_rule_id=uuid4().hex,
scheme_id=scheme_id,
scheme_version_id=scheme_version_id,
source_price_rule_id=rule.price_rule_id,
snapshot_category_id=mapping.get(rule.pricing_category_id) if rule.pricing_category_id else None,
target_type=rule.target_type,
target_ref=rule.target_ref,
amount=rule.amount,
currency=rule.currency,
)
)
await session.commit()
return {
"categories_count": len(categories),
"rules_count": len(rules),
}
async def list_scheme_version_snapshot_categories(
scheme_version_id: str,
) -> list[SchemeVersionPricingCategoryRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeVersionPricingCategoryRecord)
.where(SchemeVersionPricingCategoryRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeVersionPricingCategoryRecord.created_at), asc(SchemeVersionPricingCategoryRecord.id))
)
return list(result.scalars().all())
async def list_scheme_version_snapshot_rules(
scheme_version_id: str,
) -> list[SchemeVersionPriceRuleRecord]:
async with AsyncSessionLocal() as session:
result = await session.execute(
select(SchemeVersionPriceRuleRecord)
.where(SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id)
.order_by(asc(SchemeVersionPriceRuleRecord.created_at), asc(SchemeVersionPriceRuleRecord.id))
)
return list(result.scalars().all())
async def find_effective_snapshot_price_rule(
*,
scheme_version_id: str,
seat_id: str | None,
group_id: str | None,
sector_id: str | None,
) -> tuple[str, dict]:
async with AsyncSessionLocal() as session:
checks = [
("seat", seat_id),
("group", group_id),
("sector", sector_id),
]
for level, ref in checks:
if not ref:
continue
result = await session.execute(
select(SchemeVersionPriceRuleRecord)
.where(
SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id,
SchemeVersionPriceRuleRecord.target_type == level,
SchemeVersionPriceRuleRecord.target_ref == ref,
)
.order_by(desc(SchemeVersionPriceRuleRecord.created_at), desc(SchemeVersionPriceRuleRecord.id))
.limit(1)
)
row = result.scalar_one_or_none()
if row is not None:
return level, {
"snapshot_price_rule_id": row.snapshot_price_rule_id,
"snapshot_category_id": row.snapshot_category_id,
"target_type": row.target_type,
"target_ref": row.target_ref,
"amount": row.amount,
"currency": row.currency,
}
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No snapshot pricing rule matched current seat",
)