from decimal import Decimal
from typing import TYPE_CHECKING
from uuid import uuid4

from fastapi import HTTPException, status
from sqlalchemy import asc, desc, select

from app.db.session import AsyncSessionLocal
from app.models.price_rule import PriceRuleRecord
from app.models.pricing_category import PricingCategoryRecord

if TYPE_CHECKING:  # annotation-only import; no runtime cost
    from sqlalchemy.ext.asyncio import AsyncSession


# Target levels consulted by ``find_effective_price_rule``, most specific
# first: a seat-level rule wins over a group-level rule, which wins over a
# sector-level rule.
_PRICE_RULE_LEVELS = ("seat", "group", "sector")


async def _get_pricing_category_or_404(
    session: "AsyncSession",
    *,
    scheme_id: str,
    pricing_category_id: str,
) -> PricingCategoryRecord:
    """Load one pricing category scoped to ``scheme_id`` or raise HTTP 404."""
    result = await session.execute(
        select(PricingCategoryRecord).where(
            PricingCategoryRecord.scheme_id == scheme_id,
            PricingCategoryRecord.pricing_category_id == pricing_category_id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Pricing category not found",
        )
    return row


async def _get_price_rule_or_404(
    session: "AsyncSession",
    *,
    scheme_id: str,
    price_rule_id: str,
) -> PriceRuleRecord:
    """Load one price rule scoped to ``scheme_id`` or raise HTTP 404."""
    result = await session.execute(
        select(PriceRuleRecord).where(
            PriceRuleRecord.scheme_id == scheme_id,
            PriceRuleRecord.price_rule_id == price_rule_id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Price rule not found",
        )
    return row


def _price_rule_to_dict(row: PriceRuleRecord) -> dict:
    """Serialize the public fields of a price rule row into a plain dict."""
    return {
        "price_rule_id": row.price_rule_id,
        "scheme_id": row.scheme_id,
        "pricing_category_id": row.pricing_category_id,
        "target_type": row.target_type,
        "target_ref": row.target_ref,
        "amount": row.amount,
        "currency": row.currency,
    }


async def create_pricing_category(
    *,
    scheme_id: str,
    name: str,
    code: str | None,
) -> str:
    """Insert a new pricing category under ``scheme_id``.

    Args:
        scheme_id: Scheme the category belongs to.
        name: Display name of the category.
        code: Optional short code for the category.

    Returns:
        The newly generated ``pricing_category_id`` (a uuid4 hex string).
    """
    pricing_category_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        session.add(
            PricingCategoryRecord(
                pricing_category_id=pricing_category_id,
                scheme_id=scheme_id,
                name=name,
                code=code,
            )
        )
        await session.commit()
    return pricing_category_id


async def update_pricing_category(
    *,
    scheme_id: str,
    pricing_category_id: str,
    name: str,
    code: str | None,
) -> PricingCategoryRecord:
    """Overwrite ``name`` and ``code`` of an existing pricing category.

    Returns:
        The updated row, refreshed from the database after commit.

    Raises:
        HTTPException: 404 if no category with this id exists in the scheme.
    """
    async with AsyncSessionLocal() as session:
        row = await _get_pricing_category_or_404(
            session,
            scheme_id=scheme_id,
            pricing_category_id=pricing_category_id,
        )
        row.name = name
        row.code = code
        await session.commit()
        # Reload column values after commit so the returned row is populated
        # even once the session is closed.
        await session.refresh(row)
        return row


async def delete_pricing_category(
    *,
    scheme_id: str,
    pricing_category_id: str,
) -> None:
    """Delete a pricing category from the scheme.

    NOTE(review): price rules referencing this category are not touched
    here; what happens to them depends on the FK definition — confirm.

    Raises:
        HTTPException: 404 if no category with this id exists in the scheme.
    """
    async with AsyncSessionLocal() as session:
        row = await _get_pricing_category_or_404(
            session,
            scheme_id=scheme_id,
            pricing_category_id=pricing_category_id,
        )
        await session.delete(row)
        await session.commit()


async def create_price_rule(
    *,
    scheme_id: str,
    pricing_category_id: str | None,
    target_type: str,
    target_ref: str,
    amount: Decimal,
    currency: str,
) -> str:
    """Insert a new price rule under ``scheme_id``.

    Args:
        scheme_id: Scheme the rule belongs to.
        pricing_category_id: Optional category the rule is linked to.
        target_type: Level the rule targets; ``find_effective_price_rule``
            only matches "seat", "group" and "sector".
        target_ref: Identifier of the targeted seat/group/sector.
        amount: Price amount.
        currency: Currency of ``amount``.

    Returns:
        The newly generated ``price_rule_id`` (a uuid4 hex string).
    """
    price_rule_id = uuid4().hex
    async with AsyncSessionLocal() as session:
        session.add(
            PriceRuleRecord(
                price_rule_id=price_rule_id,
                scheme_id=scheme_id,
                pricing_category_id=pricing_category_id,
                target_type=target_type,
                target_ref=target_ref,
                amount=amount,
                currency=currency,
            )
        )
        await session.commit()
    return price_rule_id


async def update_price_rule(
    *,
    scheme_id: str,
    price_rule_id: str,
    pricing_category_id: str | None,
    target_type: str,
    target_ref: str,
    amount: Decimal,
    currency: str,
) -> PriceRuleRecord:
    """Overwrite all mutable fields of an existing price rule.

    Returns:
        The updated row, refreshed from the database after commit.

    Raises:
        HTTPException: 404 if no rule with this id exists in the scheme.
    """
    async with AsyncSessionLocal() as session:
        row = await _get_price_rule_or_404(
            session,
            scheme_id=scheme_id,
            price_rule_id=price_rule_id,
        )
        row.pricing_category_id = pricing_category_id
        row.target_type = target_type
        row.target_ref = target_ref
        row.amount = amount
        row.currency = currency
        await session.commit()
        # Reload column values after commit so the returned row is populated
        # even once the session is closed.
        await session.refresh(row)
        return row


async def delete_price_rule(
    *,
    scheme_id: str,
    price_rule_id: str,
) -> None:
    """Delete a price rule from the scheme.

    Raises:
        HTTPException: 404 if no rule with this id exists in the scheme.
    """
    async with AsyncSessionLocal() as session:
        row = await _get_price_rule_or_404(
            session,
            scheme_id=scheme_id,
            price_rule_id=price_rule_id,
        )
        await session.delete(row)
        await session.commit()


async def list_pricing_categories(scheme_id: str) -> list[PricingCategoryRecord]:
    """Return all pricing categories of a scheme, oldest first.

    Ties on ``created_at`` are broken by ascending ``id`` so the order is stable.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(PricingCategoryRecord)
            .where(PricingCategoryRecord.scheme_id == scheme_id)
            .order_by(asc(PricingCategoryRecord.created_at), asc(PricingCategoryRecord.id))
        )
        return list(result.scalars().all())


async def list_price_rules(scheme_id: str) -> list[PriceRuleRecord]:
    """Return all price rules of a scheme, oldest first.

    Ties on ``created_at`` are broken by ascending ``id`` so the order is stable.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(PriceRuleRecord)
            .where(PriceRuleRecord.scheme_id == scheme_id)
            .order_by(asc(PriceRuleRecord.created_at), asc(PriceRuleRecord.id))
        )
        return list(result.scalars().all())


async def find_effective_price_rule(
    *,
    scheme_id: str,
    seat_id: str | None,
    group_id: str | None,
    sector_id: str | None,
) -> tuple[str, dict]:
    """Resolve the price rule that applies to a seat.

    Levels are tried in the order seat, group, sector; a level whose
    reference is falsy (``None`` or ``""``) is skipped. Within a level the
    most recently created rule wins (ties broken by highest ``id``).

    Returns:
        ``(level, rule)`` where ``level`` is the matching target type and
        ``rule`` holds the matched rule's public fields.

    Raises:
        HTTPException: 404 if no level produced a matching rule.
    """
    refs = dict(zip(_PRICE_RULE_LEVELS, (seat_id, group_id, sector_id)))
    async with AsyncSessionLocal() as session:
        for level in _PRICE_RULE_LEVELS:
            ref = refs[level]
            if not ref:
                continue
            result = await session.execute(
                select(PriceRuleRecord)
                .where(
                    PriceRuleRecord.scheme_id == scheme_id,
                    PriceRuleRecord.target_type == level,
                    PriceRuleRecord.target_ref == ref,
                )
                .order_by(desc(PriceRuleRecord.created_at), desc(PriceRuleRecord.id))
                .limit(1)
            )
            row = result.scalar_one_or_none()
            if row is not None:
                return level, _price_rule_to_dict(row)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No pricing rule matched current seat",
    )