i12e/backend/main.py

156 lines
5.4 KiB
Python

import logging
import re
from typing import List
import psycopg2
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel, validator
# Module-level logger, named after this module so records can be filtered per file.
log = logging.getLogger(name=__name__)
class CountryBands(BaseModel):
uuid: str
country: str
bands: List[str]
@validator("uuid")
def is_uuid(cls, v):
assert re.match(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", v)
return v
app = FastAPI()

# NOTE(review): one module-level connection is shared by every endpoint. If FastAPI
# runs these sync `def` endpoints concurrently (it presumably uses a thread pool),
# requests can interleave inside one transaction, and a rollback in one handler
# undoes another's work. Consider a connection pool. TODO confirm.
db = psycopg2.connect(dbname="internationale")

# Fully open CORS: any origin, method and header, without credentials.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.post("/update")
def update_country(cb: CountryBands):
    """Make the stored bands for (``cb.uuid``, ``cb.country``) equal ``cb.bands``.

    Bands in ``cb.bands`` that are not stored yet are inserted into ``entries``.
    The ``users`` row is created on first insert, with ``public_id`` set to the
    last 12 characters of the uuid. Stored bands missing from ``cb.bands`` are
    deleted.

    Returns:
        The number of distinct bands recorded for ``cb.country`` across all users.

    Raises:
        HTTPException: status 500 on any error; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                "select band from entries where uuid = %s and country = %s",
                (cb.uuid, cb.country))
            stored = {r["band"] for r in csr}

            # dict.fromkeys de-duplicates the request while keeping its order,
            # so a band repeated in the request is inserted only once.
            wanted = list(dict.fromkeys(cb.bands))
            to_add = [b for b in wanted if b not in stored]
            # A set difference also avoids deleting a wanted band that happens
            # to have duplicate rows in the table.
            to_delete = sorted(stored.difference(wanted))

            if to_add:
                # Register the user once, the first time they add anything.
                csr.execute("select 1 from users where uuid = %s", (cb.uuid,))
                if csr.fetchone() is None:
                    csr.execute(
                        "insert into users(uuid, public_id) values(%s, %s)",
                        (cb.uuid, cb.uuid[24:]))
            for b in to_add:
                csr.execute(
                    "insert into entries(uuid, country, band) values(%s, %s, %s)",
                    (cb.uuid, cb.country, b))
            for b in to_delete:
                csr.execute(
                    """
                    delete from entries
                    where uuid = %s and country = %s and band = %s
                    """,
                    (cb.uuid, cb.country, b))

            # Count inside the same transaction (it sees this request's changes),
            # then commit so the connection is not left idle in a transaction.
            csr.execute(
                "select count(distinct band) from entries where country = %s",
                (cb.country,))
            count = csr.fetchone()["count"]
        db.commit()
        return count
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Must be raised: a returned HTTPException is serialized as an ordinary
        # successful response body instead of producing an error status.
        raise HTTPException(status_code=500) from e
@app.get("/countries/{lang}")
def get_countries(lang: str):
    """List countries with their names in language ``lang``, most populous first.

    Returns:
        Every column of ``countries`` joined with ``country_names`` on ``code``,
        filtered to ``lang``, as a list of dict rows.

    Raises:
        HTTPException: status 500 if the query fails; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select * from countries join country_names using (code)
                where lang = %s
                order by population desc
                """,
                (lang,))
            return csr.fetchall()
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Must be raised: a returned HTTPException is serialized as an ordinary
        # successful response body instead of producing an error status.
        raise HTTPException(status_code=500) from e
# Per-user ranking by number of distinct countries with at least one entry.
_COUNTRY_RANKING_SQL = """
    with a as (
        select public_id, count(distinct(country)) as c
        from entries join users using (uuid)
        group by 1)
    select public_id, c, rank() over (order by c desc) from a;
"""

# Per-user ranking by total number of entries (country/band picks).
_BAND_RANKING_SQL = """
    with a as (
        select public_id, count(*) as c
        from entries join users using (uuid)
        group by 1)
    select public_id, c, rank() over (order by c desc) from a;
"""


def _group_band_rows(rows):
    """Nest per-(country, band) stat rows into one dict per country.

    A new country dict starts whenever the country code changes, so each
    country's rows must be contiguous. The stats query guarantees this through
    its ORDER BY.
    """
    country_stats = []
    for r in rows:
        if not country_stats or country_stats[-1]["code"] != r["country"]:
            country_stats.append({"code": r["country"], "name": r["name"], "bands": []})
        country_stats[-1]["bands"].append({
            "name": r["band"],
            "total_count": r["total"],
            "by_user": bool(r["by_user"]),
        })
    return country_stats


def _user_ranking(csr, public_id, query):
    """Run a ranking ``query`` (columns public_id, c, rank) and find ``public_id``.

    Returns:
        ``(row, n_users)``. ``row`` is the matching dict row, or None if the
        user has no entries. ``n_users`` is the number of ranked users.
    """
    csr.execute(query)
    n_users = csr.rowcount
    row = next((r for r in csr if r["public_id"] == public_id), None)
    return row, n_users


@app.get("/stats/{id}/{lang}")
def get_stats(id: str, lang: str):
    """Return band statistics per country plus the user's rankings.

    Args:
        id: the user's ``public_id``. It marks which bands the user picked and
            which ranking rows belong to them.
        lang: language code used to choose names from ``country_names``.

    Returns:
        A dict with the following keys:

        - ``countries``: one entry per country, ordered by population, each
          with a ``bands`` list of ``{name, total_count, by_user}``.
        - ``user_count``: the number of users with at least one entry.
        - ``country_rank``/``country_count`` (by distinct countries) and
          ``band_rank``/``band_count`` (by total entries). These keys are
          present only when the user has entries.

    Raises:
        HTTPException: status 500 on any error; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select country, name,
                       band,
                       count(*) as total,
                       count(*) filter(where public_id = %s) as by_user
                from entries e
                join countries c on c.code = e.country
                join country_names cn on e.country = cn.code and cn.lang = %s
                join users u using (uuid)
                group by country, name, population, band
                order by population desc, name, band
                """,
                (id, lang))
            stats = {"countries": _group_band_rows(csr)}

            mine, n_users = _user_ranking(csr, id, _COUNTRY_RANKING_SQL)
            if mine is not None:
                stats["country_rank"] = mine["rank"]
                stats["country_count"] = mine["c"]
            stats["user_count"] = n_users

            # Previously this ranking was written to the country_* keys, which
            # overwrote the country ranking computed just above.
            mine, _ = _user_ranking(csr, id, _BAND_RANKING_SQL)
            if mine is not None:
                stats["band_rank"] = mine["rank"]
                stats["band_count"] = mine["c"]
            return stats
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Must be raised: a returned HTTPException is serialized as an ordinary
        # successful response body instead of producing an error status.
        raise HTTPException(status_code=500) from e
# vim: tw=99