156 lines
5.4 KiB
Python
156 lines
5.4 KiB
Python
import logging
|
|
import re
|
|
from typing import List
|
|
|
|
import psycopg2
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from psycopg2.extras import RealDictCursor
|
|
from pydantic import BaseModel, validator
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
class CountryBands(BaseModel):
    """Request body for ``POST /update``: one user's bands for one country."""

    # Client-supplied user identifier; must be a lowercase, hyphenated UUID string.
    uuid: str
    # Country identifier, compared against ``entries.country``.
    country: str
    # The full set of bands the user should have for ``country`` after the update.
    bands: List[str]

    @validator("uuid")
    def is_uuid(cls, v):
        """Accept only lowercase 8-4-4-4-12 hexadecimal UUID strings.

        Raises:
            ValueError: if ``v`` has any other shape; pydantic turns this into a
                validation error for the request.
        """
        pattern = (
            r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-"
            r"[0-9a-f]{4}-[0-9a-f]{12}"
        )
        # fullmatch instead of match + "$": "$" also matches just before a trailing
        # newline, which would let "<uuid>\n" through.
        # An explicit raise instead of assert: asserts are stripped under ``python -O``.
        if not re.fullmatch(pattern, v):
            raise ValueError("uuid must be a lowercase hyphenated UUID")
        return v
|
|
|
|
# ASGI application object; the route decorators in this module register handlers on it.
app = FastAPI()

# CORS: any origin, method and header is accepted, but credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Single module-wide PostgreSQL connection, opened when the module is imported.
db = psycopg2.connect(dbname="internationale")
|
|
|
|
@app.post("/update")
def update_country(cb: CountryBands):
    """Make the bands stored for ``cb.uuid`` in ``cb.country`` equal to ``cb.bands``.

    Bands in ``cb.bands`` that are not stored yet are inserted, stored bands that are
    absent from ``cb.bands`` are deleted, and a ``users`` row is created the first time
    a uuid adds a band. All changes are committed together.

    Returns:
        The number of distinct bands stored for ``cb.country`` across all users.

    Raises:
        HTTPException: status 500 if anything fails; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                "select * from entries where uuid = %s and country = %s",
                (cb.uuid, cb.country))
            stored = {r["band"] for r in csr}

            # De-duplicate the request (keeping its order). A band listed twice must
            # not be inserted twice, nor removed and then re-inserted.
            requested = dict.fromkeys(cb.bands)
            to_add = [b for b in requested if b not in stored]
            to_remove = [b for b in stored if b not in requested]

            if to_add:
                # The user row is only needed once entries are about to be inserted,
                # so check for it once instead of once per new band.
                csr.execute(
                    "select * from users where uuid = %s",
                    (cb.uuid,))
                if not csr.fetchall():
                    # public_id is the last 12 hex digits of the uuid.
                    csr.execute(
                        "insert into users(uuid, public_id) values(%s, %s)",
                        (cb.uuid, cb.uuid[24:]))

            for b in to_add:
                csr.execute(
                    "insert into entries(uuid, country, band) values(%s, %s, %s)",
                    (cb.uuid, cb.country, b))

            for b in to_remove:
                csr.execute(
                    """
                    delete from entries
                    where uuid = %s and country = %s and band = %s
                    """,
                    (cb.uuid, cb.country, b))

            # Count inside the same transaction (it sees the changes made above) so
            # that no new transaction is left open on the shared connection after
            # the commit.
            csr.execute(
                "select count(distinct band) from entries where country = %s",
                (cb.country,))
            count = csr.fetchone()["count"]

        db.commit()
        return count
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Raise, do not return: a returned HTTPException is serialised as a normal
        # response body instead of producing an HTTP 500.
        raise HTTPException(500) from e
|
|
|
|
@app.get("/countries/{lang}")
def get_countries(lang: str):
    """Return all countries with their names in language ``lang``.

    Rows come from ``countries`` joined with ``country_names`` on ``code``, restricted
    to ``lang`` and ordered by descending population.

    Returns:
        A list with one dict-like row per country; empty if ``lang`` has no names.

    Raises:
        HTTPException: status 500 if the query fails; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select * from countries join country_names using (code)
                where lang = %s
                order by population desc
                """,
                (lang,))
            return csr.fetchall()
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Raise, do not return: a returned HTTPException is serialised as a normal
        # response body instead of producing an HTTP 500.
        raise HTTPException(500) from e
|
|
|
|
@app.get("/stats/{id}/{lang}")
def get_stats(id: str, lang: str):
    """Return per-country band statistics plus the rankings of user ``id``.

    Args:
        id: The user's ``public_id``. The name shadows the builtin but must match the
            ``{id}`` path parameter.
        lang: Language used to pick the country names.

    Returns:
        A dict with:

        * ``countries``: list of ``{"code", "name", "bands"}`` ordered by descending
          population; each band is ``{"name", "total_count", "by_user"}`` where
          ``total_count`` is the number of entries for that band in that country and
          ``by_user`` tells whether one of them belongs to ``id``.
        * ``user_count``: number of distinct public ids that have entries.
        * ``country_rank`` / ``country_count``: rank of ``id`` by number of distinct
          countries, and that number. Absent if ``id`` has no entries.
        * ``band_rank`` / ``band_count``: rank of ``id`` by total number of entries,
          and that number. Absent if ``id`` has no entries.

    Raises:
        HTTPException: status 500 if a query fails; the transaction is rolled back.
    """
    try:
        with db.cursor(cursor_factory=RealDictCursor) as csr:
            csr.execute(
                """
                select country, name,
                band,
                count(*) as total,
                count(*) filter(where public_id = %s) as by_user
                from entries e
                join countries c on c.code = e.country
                join country_names cn on e.country = cn.code and cn.lang = %s
                join users u using (uuid)
                group by country, name, population, band
                order by population desc, name, band
                """,
                (id, lang))

            # Rows of one country are adjacent thanks to the ORDER BY, so a new
            # country record starts whenever the country code changes.
            country_stats = []
            for r in csr:
                if not country_stats or country_stats[-1]["code"] != r["country"]:
                    country_stats.append({
                        "code": r["country"],
                        "name": r["name"],
                        "bands": [],
                    })
                country_stats[-1]["bands"].append({
                    "name": r["band"],
                    "total_count": r["total"],
                    "by_user": bool(r["by_user"]),
                })

            stats = {"countries": country_stats}

            # Ranking by number of distinct countries per user.
            csr.execute(
                """
                with a as (
                select public_id, count(distinct(country)) as c
                from entries join users using (uuid)
                group by 1)
                select public_id, c, rank() over (order by c desc) from a;
                """
            )
            for r in csr:
                if r["public_id"] == id:
                    stats["country_rank"] = r["rank"]
                    stats["country_count"] = r["c"]
            # One ranking row per public id, so rowcount is the number of users.
            stats["user_count"] = csr.rowcount

            # Ranking by total number of entries per user. This used to be written
            # to the country_* keys, overwriting the country ranking above.
            # NOTE(review): the band_* key names are chosen here; confirm that the
            # client reads these names.
            csr.execute(
                """
                with a as (
                select public_id, count(*) as c
                from entries join users using (uuid)
                group by 1)
                select public_id, c, rank() over (order by c desc) from a;
                """
            )
            for r in csr:
                if r["public_id"] == id:
                    stats["band_rank"] = r["rank"]
                    stats["band_count"] = r["c"]

            return stats
    except Exception as e:
        log.exception("caught error")
        db.rollback()
        # Raise, do not return: a returned HTTPException is serialised as a normal
        # response body instead of producing an HTTP 500.
        raise HTTPException(500) from e
|
|
|
|
# vim: tw=99
|