view src/main.rs @ 8:6e4ad5da0f58

Restructure database access
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 10 Jul 2022 09:29:04 -0700
parents 204877b11751
children 71c5459ec0cc
line wrap: on
line source

use anyhow::{self, Error};
use either::Either;
use log::{debug, error, info, Level};

#[macro_use]
use rocket::{self};

use rocket::form::Form;
use rocket::fs::NamedFile;
use rocket::http::{Cookie, CookieJar, Header, HeaderMap, Status};
use rocket::request::{self, FlashMessage, FromRequest, Outcome, Request};
use rocket::response::{self, Flash, Redirect, Responder, Response};

use rocket_db_pools::sqlx::{
    self, pool::PoolConnection, Executor, PgPool, Postgres, Row, Sqlite, SqlitePool, Statement,
};
use rocket_db_pools::{Connection, Database, Pool};

use rocket_dyn_templates::{context, Template};

use std::net::IpAddr;
use std::path::Path;

use tokio::fs::{self, File};

// Concrete sqlx database/pool types, selected by Cargo feature.
// NOTE(review): presumably exactly one of `sqlite` / `postgres` is meant to be enabled —
// both would define these aliases twice, and neither would leave them undefined.
#[cfg(feature = "sqlite")]
type DBType = Sqlite;
#[cfg(feature = "sqlite")]
type PoolType = SqlitePool;

// Current SQL queries don't work with postgres: they use SQLite-specific syntax
// (`?` placeholders, `strftime('%s', 'now')`), and the session helper impls below
// are written for `Sqlite` only.
#[cfg(feature = "postgres")]
type DBType = Postgres;
#[cfg(feature = "postgres")]
type PoolType = PgPool;

// TO DO: use other databases?
/// Connection pool for the configuration database (the `users` table queried by
/// `ConfigDBSession::check_user_password`), registered with `rocket_db_pools` under
/// the name `sqlite_main`.
#[derive(Database)]
#[database("sqlite_main")]
struct ConfigDB(PoolType);

/// Thin wrapper around a borrowed pooled connection to the configuration database.
/// Query helpers are implemented per backend on this type (currently only for `Sqlite`).
struct ConfigDBSession<'p, DB: sqlx::Database>(&'p mut sqlx::pool::PoolConnection<DB>);

impl<'p> ConfigDBSession<'p, Sqlite> {
    /// Checks a username/password pair against the `users` table.
    ///
    /// The password is hashed with `sha256::digest` and compared with the stored
    /// `password_hash`. Returns `Ok(true)` only when exactly one row matches, and `Ok(false)`
    /// otherwise.
    ///
    /// # Errors
    /// Propagates any error from executing the query.
    async fn check_user_password<S: AsRef<str>>(
        &mut self,
        user: S,
        password: S,
    ) -> Result<bool, Error> {
        // TODO: salt passwords.
        let hashed = sha256::digest(password.as_ref());
        let matching_rows = self
            .0
            .fetch_all(
                sqlx::query("SELECT username FROM users WHERE username = ? AND password_hash = ?")
                    .bind(user.as_ref())
                    .bind(hashed),
            )
            .await?;
        Ok(matches!(matching_rows.len(), 1))
    }
}

/// Connection pool for the request-logging database (tables `RequestLog`, `RequestTags`
/// and `Sessions` written by `LogsDBSession`), registered with `rocket_db_pools` under the
/// name `sqlite_logs`.
#[derive(Database)]
#[database("sqlite_logs")]
struct LogsDB(PoolType);

/// Thin wrapper around a borrowed pooled connection to the logging database.
/// Logging helpers are implemented per backend on this type (currently only for `Sqlite`).
struct LogsDBSession<'p, DB: sqlx::Database>(&'p mut sqlx::pool::PoolConnection<DB>);

impl<'p> LogsDBSession<'p, Sqlite> {
    async fn log_request<
        S1: AsRef<str>,
        S2: AsRef<str>,
        S3: AsRef<str>,
        S4: AsRef<str>,
        S5: AsRef<str>,
        S6: AsRef<str>,
    >(
        &mut self,
        session: Option<u32>,
        ip: S1,
        domain: S2,
        path: S3,
        status: u32,
        page: Option<S4>,
        refer: Option<S5>,
        ua: S6,
        ntags: u32,
    ) -> Result<u32, Error> {
        let q = sqlx::query::<DBType>("INSERT INTO RequestLog (session, ip, atime, domain, path, status, pagename, refer, ua, ntags) VALUES (?, ?, strftime('%s', 'now'), ?, ?, ?, ?, ?, ?, ?) RETURNING id");
        let q = q
            .bind(session)
            .bind(ip.as_ref())
            .bind(domain.as_ref())
            .bind(path.as_ref())
            .bind(status)
            .bind(page.map(|s| s.as_ref().to_string()))
            .bind(refer.map(|s| s.as_ref().to_string()))
            .bind(ua.as_ref())
            .bind(ntags);
        let row: u32 = q.fetch_one(&mut *self.0).await?.get(0);
        Ok(row)
    }

    async fn log_tags<S: AsRef<str>, I: Iterator<Item = S>>(
        &mut self,
        requestid: u32,
        tags: I,
    ) -> Result<usize, Error> {
        let mut ntags = 0;
        for tag in tags {
            let (k, v) = tag.as_ref().split_once("=").unwrap_or((tag.as_ref(), ""));
            sqlx::query("INSERT INTO RequestTags (requestid, key, value) VALUES (?, ?, ?)")
                .bind(requestid)
                .bind(k)
                .bind(v)
                .execute(&mut *self.0)
                .await
                .map_err(|e| error!("Couldn't insert tag {}={}: {}", k, v, e))
                .unwrap();
            ntags += 1;
        }

        Ok(ntags)
    }

    async fn start_session(&mut self) -> Result<u32, Error> {
        Ok(sqlx::query("INSERT INTO Sessions (start, last) VALUES (strftime('%s', 'now'), strftime('%s', 'now')) RETURNING id")
        .fetch_one(&mut *self.0)
        .await?.get(0))
    }

    async fn update_session_time(&mut self, id: u32) -> Result<(), Error> {
        sqlx::query("UPDATE Sessions SET last = strftime('%s', 'now') WHERE id = ?")
            .bind(id)
            .execute(&mut *self.0)
            .await?;
        Ok(())
    }
}

/// Name of the private cookie (set via `add_private` on successful login) that holds the
/// logged-in username.
const USER_ID_COOKIE_KEY: &str = "user_id";

/// Request guard for authenticated visitors.
///
/// Succeeds with the username stored in the private `USER_ID_COOKIE_KEY` cookie. When that
/// cookie is absent it forwards, so a lower-ranked route can handle the request.
struct LoggedInGuard(String);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for LoggedInGuard {
    type Error = Error;

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        match req.cookies().get_private(USER_ID_COOKIE_KEY) {
            Some(login_cookie) => {
                Outcome::Success(LoggedInGuard(login_cookie.value().to_string()))
            }
            None => Outcome::Forward(()),
        }
    }
}

/// Request guard that hands a handler its own copy of the request's header map.
/// It always succeeds.
struct HeadersGuard<'h>(HeaderMap<'h>);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for HeadersGuard<'r> {
    type Error = Error;

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let header_copy = req.headers().clone();
        Outcome::Success(HeadersGuard(header_copy))
    }
}

/// Possible successful responses of `GET /login`.
#[derive(Responder)]
enum LoginResponse {
    /// The rendered `login` template.
    #[response(status = 200, content_type = "html")]
    Ok { body: Template },
    /// Redirect for visitors who already carry the login cookie; the status is set to 302.
    #[response(status = 302, content_type = "html")]
    LoggedInAlready { redirect: Redirect },
}

/// Plain-text HTTP 500 response whose body is an error description.
/// NOTE(review): `route_login_form` names this as its error type, but nothing in the visible
/// code constructs it.
#[derive(Responder)]
#[response(content_type = "text", status = 500)]
struct InternalServerError {
    body: String,
}

#[rocket::get("/login")]
async fn route_login_form<'r>(
    flash: Option<FlashMessage<'_>>,
    cookies: &CookieJar<'_>,
) -> Result<LoginResponse, InternalServerError> {
    let f;
    if let Some(ref flash) = flash {
        f = Some(format!("{}: {}", flash.kind(), flash.message()));
    } else {
        f = None;
    }
    if let Some(cookie) = cookies.get_private(USER_ID_COOKIE_KEY) {
        Ok(LoginResponse::LoggedInAlready {
            redirect: Redirect::to(rocket::uri!("/")),
        })
    } else {
        Ok(LoginResponse::Ok {
            body: Template::render("login", context![flash: f]),
        })
    }
}

/// `POST /logout`: removes the private login cookie if the visitor has one, then
/// redirects to `/`.
#[rocket::post("/logout")]
fn route_logout(cookies: &CookieJar<'_>) -> Redirect {
    let home = Redirect::to(rocket::uri!("/"));
    if let Some(existing_login) = cookies.get_private(USER_ID_COOKIE_KEY) {
        cookies.remove_private(existing_login);
    }
    home
}

/// Form body submitted to `POST /login`.
#[derive(rocket::FromForm)]
struct LoginForm {
    /// Username to look up in the `users` table.
    username: String,
    /// Plaintext password; it is hashed before comparison by `check_user_password`.
    password: String,
}

#[rocket::post("/login", data = "<login>")]
async fn route_login_post(
    mut db: Connection<ConfigDB>,
    cookies: &CookieJar<'_>,
    login: Form<LoginForm>,
) -> Flash<Redirect> {
    // TO DO: look up user in database.
    let mut conn = ConfigDBSession(&mut db);
    match conn
        .check_user_password(&login.username, &login.password)
        .await
    {
        Ok(true) => {
            let c = Cookie::new(USER_ID_COOKIE_KEY, login.username.clone());
            cookies.add_private(c);
            Flash::success(Redirect::to(rocket::uri!("/")), "Successfully logged in.")
        }
        Ok(false) => Flash::error(
            Redirect::to(rocket::uri!("/login")),
            "User/password not found",
        ),
        Err(e) => Flash::error(
            Redirect::to(rocket::uri!("/login")),
            format!("User/password lookup failed: {}", e),
        ),
    }
}

#[rocket::get("/", rank = 1)]
async fn route_index_loggedin(lig: LoggedInGuard, flash: Option<FlashMessage<'_>>) -> Template {
    let f;
    if let Some(ref flash) = flash {
        f = Some(format!("{}: {}", flash.kind(), flash.message()));
    } else {
        f = None;
    }
    Template::render("index", context![loggedin: true, username: lig.0, flash: f])
}

#[rocket::get("/", rank = 2)]
async fn route_index_loggedout(flash: Option<FlashMessage<'_>>) -> Template {
    let f;
    if let Some(ref flash) = flash {
        f = Some(format!("{}: {}", flash.kind(), flash.message()));
    } else {
        f = None;
    }
    Template::render("index", context![loggedin: false, flash: f])
}

// TODO: ignore requests when logged in.
#[rocket::get("/log?<host>&<status>&<path>&<pagename>&<referer>&<tags>")]
async fn route_log(
    mut conn: Connection<LogsDB>,
    cookies: &CookieJar<'_>,
    host: Option<String>,
    status: Option<u32>,
    path: Option<String>,
    pagename: Option<String>,
    referer: Option<String>,
    tags: Vec<String>,
    config: &rocket::State<CustomConfig>,
    ip: IpAddr,
    headers: HeadersGuard<'_>,
) -> (Status, Either<NamedFile, &'static str>) {
    let mut conn = LogsDBSession(&mut conn);
    let mut session_id = None;
    // Get session ID from cookie, or start new session.
    if let Some(sessioncookie) = cookies.get_private("analyrics_session") {
        if let Ok(id) = u32::from_str_radix(sessioncookie.value(), 10) {
            session_id = Some(id);
            match conn.update_session_time(id).await {
                Ok(()) => {}
                Err(e) => error!("Couldn't update session time: {}", e),
            }
        }
    }
    if session_id.is_none() {
        match conn.start_session().await {
            Ok(id) => {
                session_id = Some(id);
                let c = Cookie::build("analyrics_session", id.to_string())
                    .max_age(time::Duration::hours(12))
                    .finish();
                cookies.add_private(c);
            }
            Err(e) => error!("Couldn't start session: {}", e),
        }
    }

    let ntags = tags.len() as u32;
    let ua = headers.0.get_one("user-agent").unwrap_or("");
    let ip = ip.to_string();
    let ip = headers.0.get_one("x-real-ip").unwrap_or(&ip);
    let host: String = host.unwrap_or(
        headers
            .0
            .get("host")
            .take(1)
            .map(|s| s.to_string())
            .collect::<Vec<String>>()
            .pop()
            .unwrap_or(String::new()),
    );
    match conn
        .log_request(
            session_id,
            ip,
            host,
            path.unwrap_or(String::new()),
            status.unwrap_or(200),
            pagename,
            referer,
            ua,
            ntags,
        )
        .await
    {
        Err(e) => error!("Couldn't log request: {}", e),
        Ok(id) => {
            conn.log_tags(id, tags.iter()).await.ok();
        }
    }

    if let Ok(f) = NamedFile::open(Path::new(config.asset_path.as_str()).join("pixel.png")).await {
        (Status::Ok, Either::Left(f))
    } else {
        (Status::Ok, Either::Right(""))
    }
}

/// Application-specific settings read from Rocket's configuration
/// (attached with `AdHoc::config::<CustomConfig>()` in `rocketmain`).
#[derive(rocket::serde::Deserialize)]
#[serde(crate = "rocket::serde")]
struct CustomConfig {
    /// Directory containing static assets; `route_log` serves `pixel.png` from it.
    asset_path: String,
}

/// Application entry point.
///
/// Initializes `env_logger`, then builds the Rocket instance:
/// - attaches both database pools, the template fairing and the `CustomConfig` extractor;
/// - mounts every route under `/`.
#[rocket::launch]
fn rocketmain() -> _ {
    env_logger::init();

    let all_routes = rocket::routes![
        route_index_loggedin,
        route_index_loggedout,
        route_logout,
        route_login_form,
        route_login_post,
        route_log,
    ];
    let app = rocket::build()
        .attach(ConfigDB::init())
        .attach(LogsDB::init())
        .attach(Template::fairing())
        .attach(rocket::fairing::AdHoc::config::<CustomConfig>());
    app.mount("/", all_routes)
}