"""FastAPI endpoints backed by the ``internationale`` PostgreSQL database."""

import logging
import re
from typing import List

import psycopg2
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel, validator

log = logging.getLogger(__name__)

# Lower-case, hyphenated 8-4-4-4-12 hexadecimal form; compiled once at import.
_UUID_RE = re.compile(
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
)


class CountryBands(BaseModel):
    """Request body of ``POST /update``.

    ``bands`` is treated as the complete list of bands the user identified by
    ``uuid`` wants stored for ``country``.
    """

    uuid: str
    country: str
    bands: List[str]

    @validator("uuid")
    def is_uuid(cls, v):
        """Reject any ``uuid`` that is not a lower-case hyphenated UUID string.

        Raises ``ValueError`` (surfaced by pydantic as a validation error)
        rather than using ``assert``, which is removed when Python runs with
        ``-O``. ``fullmatch`` is used so a trailing newline is not accepted.
        """
        if not _UUID_RE.fullmatch(v):
            raise ValueError("uuid must be a lower-case hyphenated UUID")
        return v


app = FastAPI()
# NOTE(review): a single module-level connection is shared by every request.
db = psycopg2.connect(dbname="internationale")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/update")
def update_country(cb: CountryBands):
    """Make the stored bands for ``(cb.uuid, cb.country)`` equal ``cb.bands``.

    Bands that are submitted but not stored are inserted, bands that are
    stored but not submitted are deleted, and a ``users`` row is created the
    first time a uuid adds an entry (``public_id`` is the last 12 characters
    of the uuid).

    Returns:
        The number of distinct bands stored for ``cb.country`` across all
        users, computed inside the same transaction as the changes.

    Raises:
        HTTPException: status 500 if any database operation fails; the
            transaction is rolled back first.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                "select * from entries where uuid = %s and country = %s",
                (cb.uuid, cb.country))
            stored = {row["band"] for row in csr}

            # De-duplicate while keeping submission order, so a band repeated
            # in the request is neither inserted twice nor re-inserted after
            # being matched against an existing row.
            wanted = list(dict.fromkeys(cb.bands))
            to_add = [band for band in wanted if band not in stored]
            to_remove = stored.difference(wanted)

            if to_add:
                # One lookup per request is enough; the user row only has to
                # exist before the first entry is inserted.
                csr.execute(
                    "select * from users where uuid = %s", (cb.uuid,))
                if not csr.fetchall():
                    csr.execute(
                        "insert into users(uuid, public_id) values(%s, %s)",
                        (cb.uuid, cb.uuid[24:]))

            for band in to_add:
                csr.execute(
                    "insert into entries(uuid, country, band)"
                    " values(%s, %s, %s)",
                    (cb.uuid, cb.country, band))

            for band in to_remove:
                csr.execute(
                    """
                    delete from entries
                    where uuid = %s and country = %s and band = %s
                    """,
                    (cb.uuid, cb.country, band))

            # Counted before the commit so no transaction is left open
            # afterwards; the query still sees this request's own changes.
            csr.execute(
                "select count(distinct band) from entries where country = %s",
                (cb.country,))
            count = csr.fetchone()["count"]
        db.commit()
        return count
    except Exception as exc:
        log.exception("caught error")
        db.rollback()
        # Must be raised: a *returned* HTTPException is sent as an ordinary
        # response body instead of producing a 500 status.
        raise HTTPException(500) from exc


@app.get("/countries/{lang}")
def get_countries(lang: str):
    """Return every country joined with its name in language ``lang``.

    Rows come from ``countries`` joined to ``country_names`` on ``code`` and
    are ordered by ``population`` descending.

    Raises:
        HTTPException: status 500 if the query fails.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select * from countries
                join country_names using (code)
                where lang = %s
                order by population desc
                """,
                (lang,))
            return csr.fetchall()
    except Exception as exc:
        log.exception("caught error")
        db.rollback()
        raise HTTPException(500) from exc
def _group_bands_by_country(rows):
    """Fold per-(country, band) rows into one record per country.

    A new record is started whenever the ``country`` value differs from the
    previous row's, so rows for the same country must be adjacent.
    """
    grouped = []
    for row in rows:
        if not grouped or grouped[-1]["code"] != row["country"]:
            grouped.append(
                {"code": row["country"], "name": row["name"], "bands": []})
        grouped[-1]["bands"].append({
            "name": row["band"],
            "total_count": row["total"],
            "by_user": bool(row["by_user"]),
        })
    return grouped


@app.get("/stats/{id}/{lang}")
def get_stats(id: str, lang: str):
    """Return per-country band statistics plus ranking data for one user.

    Args:
        id: ``public_id`` of the user the ``by_user`` flags and rank refer to.
            (The name shadows the builtin but is bound to the ``{id}`` path
            segment, so it is kept.)
        lang: language used to pick each country's name.

    Returns:
        A dict with ``countries`` (list of ``{"code", "name", "bands"}``
        records, each band carrying ``name``, ``total_count`` and
        ``by_user``), ``user_count``, and - when ``id`` has entries -
        ``country_rank`` and ``country_count``.

    Raises:
        HTTPException: status 500 if any query fails.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select country, name, band, count(*) as total,
                       count(*) filter(where public_id = %s) as by_user
                from entries e
                join countries c on c.code = e.country
                join country_names cn on e.country = cn.code and cn.lang = %s
                join users u using (uuid)
                group by country, name, population, band
                order by population desc, name, band
                """,
                (id, lang))
            stats = {"countries": _group_bands_by_country(csr)}

            # Rank users by the number of distinct countries they have entries
            # for.
            csr.execute(
                """
                with a as (
                    select public_id, count(distinct(country)) as c
                    from entries join users using (uuid)
                    group by 1)
                select public_id, c, rank() over (order by c desc) from a;
                """
            )
            for row in csr:
                if row["public_id"] == id:
                    stats["country_rank"] = row["rank"]
                    stats["country_count"] = row["c"]
            # One result row per public_id, so rowcount is the number of users
            # that have at least one entry.
            stats["user_count"] = csr.rowcount

            # Rank users by their total number of entries.
            csr.execute(
                """
                with a as (
                    select public_id, count(*) as c
                    from entries join users using (uuid)
                    group by 1)
                select public_id, c, rank() over (order by c desc) from a;
                """
            )
            for row in csr:
                if row["public_id"] == id:
                    # NOTE(review): these are the same keys the country
                    # ranking above just filled in, so the entry-based values
                    # replace them. Looks like a copy-paste slip - confirm the
                    # intended key names with the API consumer before
                    # renaming; kept as-is to leave the response unchanged.
                    stats["country_rank"] = row["rank"]
                    stats["country_count"] = row["c"]

            return stats
    except Exception as exc:
        log.exception("caught error")
        db.rollback()
        # Must be raised: a *returned* HTTPException is sent as an ordinary
        # response body instead of producing a 500 status.
        raise HTTPException(500) from exc

# vim: tw=99