"""Read/write helpers for per-scheme-version pricing snapshot records.

A snapshot is a copy of a scheme's pricing categories and price rules stored in
``SchemeVersionPricingCategoryRecord`` / ``SchemeVersionPriceRuleRecord`` rows
keyed by ``scheme_version_id``.
"""

from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import asc, delete, desc, select

from app.db.session import AsyncSessionLocal
from app.models.scheme_version_pricing import (
    SchemeVersionPriceRuleRecord,
    SchemeVersionPricingCategoryRecord,
)
from app.repositories.pricing import list_price_rules, list_pricing_categories


async def replace_scheme_version_pricing_snapshot(
    *,
    scheme_id: str,
    scheme_version_id: str,
) -> dict:
    """Replace the pricing snapshot stored for ``scheme_version_id``.

    Loads the scheme's pricing categories and price rules through the pricing
    repository, removes every snapshot category/rule row already stored for the
    version, and inserts fresh copies with newly generated snapshot ids. Each
    copied rule is linked to the snapshot id of the category it referenced.
    All snapshot writes happen in a single session and are committed together.

    Args:
        scheme_id: Scheme whose pricing categories and rules are copied.
        scheme_version_id: Version whose snapshot rows are replaced.

    Returns:
        ``{"categories_count": int, "rules_count": int}`` with the number of
        categories and rules copied into the snapshot.
    """
    categories = await list_pricing_categories(scheme_id)
    rules = await list_price_rules(scheme_id)

    async with AsyncSessionLocal() as session:
        # Clear the previous snapshot with one DELETE per table. These execute
        # immediately, so the old rows are gone before the INSERTs below are
        # flushed. With session.delete() + session.add() in the same flush,
        # SQLAlchemy emits the INSERTs first, which would trip any unique
        # constraint spanning old and new rows. Rules are removed first because
        # they may reference snapshot categories.
        await session.execute(
            delete(SchemeVersionPriceRuleRecord)
            .where(SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id)
            .execution_options(synchronize_session=False)
        )
        await session.execute(
            delete(SchemeVersionPricingCategoryRecord)
            .where(SchemeVersionPricingCategoryRecord.scheme_version_id == scheme_version_id)
            .execution_options(synchronize_session=False)
        )

        # source pricing_category_id -> newly generated snapshot_category_id
        snapshot_category_ids: dict[str, str] = {}
        for category in categories:
            snapshot_category_id = uuid4().hex
            snapshot_category_ids[category.pricing_category_id] = snapshot_category_id
            session.add(
                SchemeVersionPricingCategoryRecord(
                    snapshot_category_id=snapshot_category_id,
                    scheme_id=scheme_id,
                    scheme_version_id=scheme_version_id,
                    source_pricing_category_id=category.pricing_category_id,
                    name=category.name,
                    code=category.code,
                )
            )

        # Write the categories before the rules that point at them, independent
        # of how the ORM would order inserts across the two tables.
        await session.flush()

        for rule in rules:
            # A rule whose category is not among ``categories`` ends up with no
            # snapshot category (None), same as a rule without a category.
            snapshot_category_id = (
                snapshot_category_ids.get(rule.pricing_category_id)
                if rule.pricing_category_id
                else None
            )
            session.add(
                SchemeVersionPriceRuleRecord(
                    snapshot_price_rule_id=uuid4().hex,
                    scheme_id=scheme_id,
                    scheme_version_id=scheme_version_id,
                    source_price_rule_id=rule.price_rule_id,
                    snapshot_category_id=snapshot_category_id,
                    target_type=rule.target_type,
                    target_ref=rule.target_ref,
                    amount=rule.amount,
                    currency=rule.currency,
                )
            )

        await session.commit()

    return {
        "categories_count": len(categories),
        "rules_count": len(rules),
    }


async def list_scheme_version_snapshot_categories(
    scheme_version_id: str,
) -> list[SchemeVersionPricingCategoryRecord]:
    """Return the snapshot categories of a version, oldest first (ties by ``id``)."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeVersionPricingCategoryRecord)
            .where(SchemeVersionPricingCategoryRecord.scheme_version_id == scheme_version_id)
            .order_by(
                asc(SchemeVersionPricingCategoryRecord.created_at),
                asc(SchemeVersionPricingCategoryRecord.id),
            )
        )
        return list(result.scalars().all())


async def list_scheme_version_snapshot_rules(
    scheme_version_id: str,
) -> list[SchemeVersionPriceRuleRecord]:
    """Return the snapshot price rules of a version, oldest first (ties by ``id``)."""
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SchemeVersionPriceRuleRecord)
            .where(SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id)
            .order_by(
                asc(SchemeVersionPriceRuleRecord.created_at),
                asc(SchemeVersionPriceRuleRecord.id),
            )
        )
        return list(result.scalars().all())


def _snapshot_rule_to_dict(row: SchemeVersionPriceRuleRecord) -> dict:
    """Serialise a snapshot price rule row into the dict returned to callers."""
    return {
        "snapshot_price_rule_id": row.snapshot_price_rule_id,
        "snapshot_category_id": row.snapshot_category_id,
        "target_type": row.target_type,
        "target_ref": row.target_ref,
        "amount": row.amount,
        "currency": row.currency,
    }


async def find_effective_snapshot_price_rule(
    *,
    scheme_version_id: str,
    seat_id: str | None,
    group_id: str | None,
    sector_id: str | None,
) -> tuple[str, dict]:
    """Find the most specific snapshot price rule for a seat.

    Target levels are tried in the order seat, group, sector. Levels whose id
    is ``None`` or empty are skipped. Within a level, the rule with the latest
    ``created_at`` (then highest ``id``) wins.

    Args:
        scheme_version_id: Version whose snapshot rules are searched.
        seat_id: Seat to match against rules with ``target_type == "seat"``.
        group_id: Group to match against rules with ``target_type == "group"``.
        sector_id: Sector to match against rules with ``target_type == "sector"``.

    Returns:
        ``(level, rule)`` where ``level`` is the matching target type and
        ``rule`` holds the matched row's snapshot ids, target, amount and currency.

    Raises:
        HTTPException: 404 when no level has a matching rule.
    """
    levels = (("seat", seat_id), ("group", group_id), ("sector", sector_id))
    async with AsyncSessionLocal() as session:
        for level, ref in levels:
            if not ref:
                continue
            result = await session.execute(
                select(SchemeVersionPriceRuleRecord)
                .where(
                    SchemeVersionPriceRuleRecord.scheme_version_id == scheme_version_id,
                    SchemeVersionPriceRuleRecord.target_type == level,
                    SchemeVersionPriceRuleRecord.target_ref == ref,
                )
                .order_by(
                    desc(SchemeVersionPriceRuleRecord.created_at),
                    desc(SchemeVersionPriceRuleRecord.id),
                )
                .limit(1)
            )
            row = result.scalar_one_or_none()
            if row is not None:
                return level, _snapshot_rule_to_dict(row)

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No snapshot pricing rule matched current seat",
    )