changeset 8:6e4ad5da0f58

Restructure database access
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 10 Jul 2022 09:29:04 -0700
parents f06a74b14e94
children 71c5459ec0cc
files src/main.rs
diffstat 1 files changed, 107 insertions(+), 94 deletions(-) [+]
line wrap: on
line diff
--- a/src/main.rs	Sat Jul 09 21:10:10 2022 -0700
+++ b/src/main.rs	Sun Jul 10 09:29:04 2022 -0700
@@ -39,92 +39,100 @@
 #[database("sqlite_main")]
 struct ConfigDB(PoolType);
 
-async fn check_user_password<S: AsRef<str>>(
-    conn: &mut PoolConnection<DBType>,
-    user: S,
-    password: S,
-) -> Result<bool, Error> {
-    // TODO: salt passwords.
-    let pwdhash = sha256::digest(password.as_ref());
-    let q = sqlx::query("SELECT username FROM users WHERE username = ? AND password_hash = ?")
-        .bind(user.as_ref())
-        .bind(pwdhash);
-    let result = conn.fetch_all(q).await?;
-    Ok(result.len() == 1)
+struct ConfigDBSession<'p, DB: sqlx::Database>(&'p mut sqlx::pool::PoolConnection<DB>);
+
+impl<'p> ConfigDBSession<'p, Sqlite> {
+    async fn check_user_password<S: AsRef<str>>(
+        &mut self,
+        user: S,
+        password: S,
+    ) -> Result<bool, Error> {
+        // TODO: salt passwords.
+        let pwdhash = sha256::digest(password.as_ref());
+        let q = sqlx::query("SELECT username FROM users WHERE username = ? AND password_hash = ?")
+            .bind(user.as_ref())
+            .bind(pwdhash);
+        let result = self.0.fetch_all(q).await?;
+        Ok(result.len() == 1)
+    }
 }
 
 #[derive(Database)]
 #[database("sqlite_logs")]
 struct LogsDB(PoolType);
 
-async fn log_request<
-    S1: AsRef<str>,
-    S2: AsRef<str>,
-    S3: AsRef<str>,
-    S4: AsRef<str>,
-    S5: AsRef<str>,
-    S6: AsRef<str>,
->(
-    conn: &mut PoolConnection<DBType>,
-    session: Option<u32>,
-    ip: S1,
-    domain: S2,
-    path: S3,
-    status: u32,
-    page: Option<S4>,
-    refer: Option<S5>,
-    ua: S6,
-    ntags: u32,
-) -> Result<u32, Error> {
-    let q = sqlx::query::<DBType>("INSERT INTO RequestLog (session, ip, atime, domain, path, status, pagename, refer, ua, ntags) VALUES (?, ?, strftime('%s', 'now'), ?, ?, ?, ?, ?, ?, ?) RETURNING id");
-    let q = q
-        .bind(session)
-        .bind(ip.as_ref())
-        .bind(domain.as_ref())
-        .bind(path.as_ref())
-        .bind(status)
-        .bind(page.map(|s| s.as_ref().to_string()))
-        .bind(refer.map(|s| s.as_ref().to_string()))
-        .bind(ua.as_ref())
-        .bind(ntags);
-    let row: u32 = q.fetch_one(conn).await?.get(0);
-    Ok(row)
-}
+struct LogsDBSession<'p, DB: sqlx::Database>(&'p mut sqlx::pool::PoolConnection<DB>);
 
-async fn log_tags<S: AsRef<str>, I: Iterator<Item = S>>(
-    conn: &mut PoolConnection<DBType>,
-    requestid: u32,
-    tags: I,
-) -> Result<usize, Error> {
-    let mut ntags = 0;
-    for tag in tags {
-        let (k, v) = tag.as_ref().split_once("=").unwrap_or((tag.as_ref(), ""));
-        sqlx::query("INSERT INTO RequestTags (requestid, key, value) VALUES (?, ?, ?)")
-            .bind(requestid)
-            .bind(k)
-            .bind(v)
-            .execute(&mut *conn)
-            .await
-            .map_err(|e| error!("Couldn't insert tag {}={}: {}", k, v, e))
-            .unwrap();
-        ntags += 1;
+impl<'p> LogsDBSession<'p, Sqlite> {
+    async fn log_request<
+        S1: AsRef<str>,
+        S2: AsRef<str>,
+        S3: AsRef<str>,
+        S4: AsRef<str>,
+        S5: AsRef<str>,
+        S6: AsRef<str>,
+    >(
+        &mut self,
+        session: Option<u32>,
+        ip: S1,
+        domain: S2,
+        path: S3,
+        status: u32,
+        page: Option<S4>,
+        refer: Option<S5>,
+        ua: S6,
+        ntags: u32,
+    ) -> Result<u32, Error> {
+        let q = sqlx::query::<DBType>("INSERT INTO RequestLog (session, ip, atime, domain, path, status, pagename, refer, ua, ntags) VALUES (?, ?, strftime('%s', 'now'), ?, ?, ?, ?, ?, ?, ?) RETURNING id");
+        let q = q
+            .bind(session)
+            .bind(ip.as_ref())
+            .bind(domain.as_ref())
+            .bind(path.as_ref())
+            .bind(status)
+            .bind(page.map(|s| s.as_ref().to_string()))
+            .bind(refer.map(|s| s.as_ref().to_string()))
+            .bind(ua.as_ref())
+            .bind(ntags);
+        let row: u32 = q.fetch_one(&mut *self.0).await?.get(0);
+        Ok(row)
     }
 
-    Ok(ntags)
-}
+    async fn log_tags<S: AsRef<str>, I: Iterator<Item = S>>(
+        &mut self,
+        requestid: u32,
+        tags: I,
+    ) -> Result<usize, Error> {
+        let mut ntags = 0;
+        for tag in tags {
+            let (k, v) = tag.as_ref().split_once("=").unwrap_or((tag.as_ref(), ""));
+            sqlx::query("INSERT INTO RequestTags (requestid, key, value) VALUES (?, ?, ?)")
+                .bind(requestid)
+                .bind(k)
+                .bind(v)
+                .execute(&mut *self.0)
+                .await
+                .map_err(|e| error!("Couldn't insert tag {}={}: {}", k, v, e))
+                .unwrap();
+            ntags += 1;
+        }
 
-async fn start_session(conn: &mut PoolConnection<DBType>) -> Result<u32, Error> {
-    Ok(sqlx::query("INSERT INTO Sessions (start, last) VALUES (strftime('%s', 'now'), strftime('%s', 'now')) RETURNING id")
-        .fetch_one(conn)
+        Ok(ntags)
+    }
+
+    async fn start_session(&mut self) -> Result<u32, Error> {
+        Ok(sqlx::query("INSERT INTO Sessions (start, last) VALUES (strftime('%s', 'now'), strftime('%s', 'now')) RETURNING id")
+        .fetch_one(&mut *self.0)
         .await?.get(0))
-}
+    }
 
-async fn update_session_time(conn: &mut PoolConnection<DBType>, id: u32) -> Result<(), Error> {
-    sqlx::query("UPDATE Sessions SET last = strftime('%s', 'now') WHERE id = ?")
-        .bind(id)
-        .execute(conn)
-        .await?;
-    Ok(())
+    async fn update_session_time(&mut self, id: u32) -> Result<(), Error> {
+        sqlx::query("UPDATE Sessions SET last = strftime('%s', 'now') WHERE id = ?")
+            .bind(id)
+            .execute(&mut *self.0)
+            .await?;
+        Ok(())
+    }
 }
 
 const USER_ID_COOKIE_KEY: &str = "user_id";
@@ -209,13 +217,16 @@
 
 #[rocket::post("/login", data = "<login>")]
 async fn route_login_post(
-    db: Connection<ConfigDB>,
+    mut db: Connection<ConfigDB>,
     cookies: &CookieJar<'_>,
     login: Form<LoginForm>,
 ) -> Flash<Redirect> {
     // TO DO: look up user in database.
-    let mut db = db.into_inner();
-    match check_user_password(&mut db, &login.username, &login.password).await {
+    let mut conn = ConfigDBSession(&mut db);
+    match conn
+        .check_user_password(&login.username, &login.password)
+        .await
+    {
         Ok(true) => {
             let c = Cookie::new(USER_ID_COOKIE_KEY, login.username.clone());
             cookies.add_private(c);
@@ -269,19 +280,20 @@
     ip: IpAddr,
     headers: HeadersGuard<'_>,
 ) -> (Status, Either<NamedFile, &'static str>) {
+    let mut conn = LogsDBSession(&mut conn);
     let mut session_id = None;
     // Get session ID from cookie, or start new session.
     if let Some(sessioncookie) = cookies.get_private("analyrics_session") {
         if let Ok(id) = u32::from_str_radix(sessioncookie.value(), 10) {
             session_id = Some(id);
-            match update_session_time(&mut conn, id).await {
+            match conn.update_session_time(id).await {
                 Ok(()) => {}
                 Err(e) => error!("Couldn't update session time: {}", e),
             }
         }
     }
     if session_id.is_none() {
-        match start_session(&mut conn).await {
+        match conn.start_session().await {
             Ok(id) => {
                 session_id = Some(id);
                 let c = Cookie::build("analyrics_session", id.to_string())
@@ -293,9 +305,10 @@
         }
     }
 
-    let mut conn = conn.into_inner();
     let ntags = tags.len() as u32;
     let ua = headers.0.get_one("user-agent").unwrap_or("");
+    let ip = ip.to_string();
+    let ip = headers.0.get_one("x-real-ip").unwrap_or(&ip);
     let host: String = host.unwrap_or(
         headers
             .0
@@ -306,23 +319,23 @@
             .pop()
             .unwrap_or(String::new()),
     );
-    match log_request(
-        &mut conn,
-        session_id,
-        ip.to_string(),
-        host,
-        path.unwrap_or(String::new()),
-        status.unwrap_or(200),
-        pagename,
-        referer,
-        ua,
-        ntags,
-    )
-    .await
+    match conn
+        .log_request(
+            session_id,
+            ip,
+            host,
+            path.unwrap_or(String::new()),
+            status.unwrap_or(200),
+            pagename,
+            referer,
+            ua,
+            ntags,
+        )
+        .await
     {
         Err(e) => error!("Couldn't log request: {}", e),
         Ok(id) => {
-            log_tags(&mut conn, id, tags.iter()).await.ok();
+            conn.log_tags(id, tags.iter()).await.ok();
         }
     }