from __future__ import annotations

from collections.abc import Iterable
from importlib import import_module

from sqlalchemy import delete, func, outerjoin, select

from app.db.session import AsyncSessionLocal


def _resolve_model(module_path: str, *candidate_names: str):
    """Return the first non-None attribute of *module_path* named in *candidate_names*.

    Lets this module bind to whichever class name the models module exposes
    (e.g. ``PricingCategory`` or ``PricingCategoryRecord``).

    Raises:
        ImportError: if the module cannot be imported (``ModuleNotFoundError``
            is a subclass), or if none of the candidate names resolves to a
            non-None attribute of it.
    """
    module = import_module(module_path)
    for name in candidate_names:
        model = getattr(module, name, None)
        if model is not None:
            return model
    raise ImportError(
        f"Unable to resolve model from {module_path}. "
        f"Tried: {', '.join(candidate_names)}"
    )


# Resolved once at import time, so a missing model fails fast with ImportError.
PricingCategoryModel = _resolve_model(
    "app.models.pricing_category",
    "PricingCategory",
    "PricingCategoryRecord",
)
PriceRuleModel = _resolve_model(
    "app.models.price_rule",
    "PriceRule",
    "PriceRuleRecord",
)


async def list_pricing_categories_with_rule_counts(
    *,
    scheme_id: str,
) -> list[dict]:
    """List the pricing categories of one scheme, each with its price-rule count.

    Args:
        scheme_id: Only categories whose ``scheme_id`` equals this are returned.

    Returns:
        One dict per category with keys ``pricing_category_id``, ``scheme_id``,
        ``name``, ``code`` and ``rules_count`` (an int). Rows are ordered by
        name, then code, then id.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(
                PricingCategoryModel.pricing_category_id,
                PricingCategoryModel.scheme_id,
                PricingCategoryModel.name,
                PricingCategoryModel.code,
                # COUNT(column) skips NULLs, so a category with no matching
                # rules (all-NULL right side of the outer join) counts as 0.
                func.count(PriceRuleModel.price_rule_id).label("rules_count"),
            )
            .select_from(
                # LEFT OUTER JOIN keeps categories that have no rules at all.
                outerjoin(
                    PricingCategoryModel,
                    PriceRuleModel,
                    PricingCategoryModel.pricing_category_id
                    == PriceRuleModel.pricing_category_id,
                )
            )
            .where(PricingCategoryModel.scheme_id == scheme_id)
            .group_by(
                PricingCategoryModel.pricing_category_id,
                PricingCategoryModel.scheme_id,
                PricingCategoryModel.name,
                PricingCategoryModel.code,
            )
            .order_by(
                PricingCategoryModel.name.asc(),
                PricingCategoryModel.code.asc(),
                # Final tie-breaker makes the ordering deterministic.
                PricingCategoryModel.pricing_category_id.asc(),
            )
        )
        rows = (await session.execute(stmt)).all()

    return [
        {
            "pricing_category_id": row.pricing_category_id,
            "scheme_id": row.scheme_id,
            "name": row.name,
            "code": row.code,
            "rules_count": int(row.rules_count or 0),
        }
        for row in rows
    ]


async def delete_pricing_categories_by_ids(
    *,
    scheme_id: str,
    pricing_category_ids: Iterable[str],
) -> int:
    """Delete the given pricing categories, restricted to one scheme.

    Ids that do not belong to *scheme_id* are silently left untouched.

    Args:
        scheme_id: Only categories in this scheme may be deleted.
        pricing_category_ids: Ids to delete. Any iterable is accepted, and
            duplicates are ignored. ``None`` or an empty iterable is a no-op.

    Returns:
        The number of rows the database reports as deleted (0 when the driver
        reports no rowcount).
    """
    # Materialize once so that generators and sets work, and the emptiness
    # check is reliable (a generator object is always truthy). The
    # dict.fromkeys call drops duplicate ids while preserving their order.
    ids = list(dict.fromkeys(pricing_category_ids or ()))
    if not ids:
        return 0

    # NOTE(review): if price_rule rows reference these categories through a
    # foreign key without ON DELETE CASCADE, this DELETE may fail for
    # categories that still have rules. Confirm the schema.
    async with AsyncSessionLocal() as session:
        stmt = delete(PricingCategoryModel).where(
            PricingCategoryModel.scheme_id == scheme_id,
            PricingCategoryModel.pricing_category_id.in_(ids),
        )
        result = await session.execute(stmt)
        await session.commit()
        return int(result.rowcount or 0)