changeset 35:097f1c1c5f2b

Refactor database code part I
author Lewin Bormann <lbo@spheniscida.de>
date Thu, 03 Dec 2020 08:12:25 +0100
parents 9966460e2930
children ebdb9c50adb1
files src/db.rs src/main.rs src/notifier.rs
diffstat 3 files changed, 93 insertions(+), 73 deletions(-) [+]
line wrap: on
line diff
--- a/src/db.rs	Thu Dec 03 07:51:38 2020 +0100
+++ b/src/db.rs	Thu Dec 03 08:12:25 2020 +0100
@@ -1,51 +1,85 @@
 use crate::types;
 
-/// Queries for at most `limit` rows since entry ID `last`.
-pub fn check_for_new_rows(
-    db: &postgres::Connection,
-    name: &str,
-    secret: Option<&str>,
-    last: &Option<i32>,
-    limit: &Option<i64>,
-) -> Option<(types::GeoJSON, i32)> {
-    let mut returnable = types::GeoJSON::new();
-    let check_for_new = db.prepare_cached(
-        r"SELECT id, t, lat, long, spd, ele FROM geohub.geodata
-        WHERE (client = $1) and (id > $2) AND (secret = public.digest($3, 'sha256') or secret is null)
-        ORDER BY id DESC
-        LIMIT $4").unwrap(); // Must succeed.
+/// Managed by Rocket.
+#[rocket_contrib::database("geohub")]
+pub struct DBConn(postgres::Connection);
+
+/// For requests from in- or outside a request handler.
+pub struct DBQuery<'a>(pub &'a postgres::Connection);
 
-    let last = last.unwrap_or(0);
-    let limit = limit.unwrap_or(256);
+impl<'a> DBQuery<'a> {
+    /// Fetch records and format as JSON
+    pub fn retrieve_json(
+        &self,
+        name: &str,
+        from_ts: chrono::DateTime<chrono::Utc>,
+        to_ts: chrono::DateTime<chrono::Utc>,
+        secret: &str,
+        limit: i64,
+    ) -> Result<types::GeoJSON, postgres::Error> {
+        let mut returnable = types::GeoJSON::new();
+        let stmt = self.0.prepare_cached(
+            r"SELECT t, lat, long, spd, ele FROM geohub.geodata
+        WHERE (client = $1) and (t between $2 and $3) AND (secret = public.digest($4, 'sha256') or secret is null)
+        ORDER BY t ASC
+        LIMIT $5").unwrap(); // Must succeed.
+        let rows = stmt.query(&[&name, &from_ts, &to_ts, &secret, &limit])?;
+        returnable.reserve_features(rows.len());
+        for row in rows.iter() {
+            let (ts, lat, long, spd, ele) =
+                (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
+            returnable.push_feature(types::geofeature_from_row(ts, lat, long, spd, ele));
+        }
+        Ok(returnable)
+    }
 
-    let rows = check_for_new.query(&[&name, &last, &secret, &limit]);
-    if let Ok(rows) = rows {
-        // If there are unknown entries, return those.
-        if rows.len() > 0 {
-            returnable.reserve_features(rows.len());
-            let mut last = 0;
+    /// Queries for at most `limit` rows since entry ID `last`.
+    pub fn check_for_new_rows(
+        &self,
+        name: &str,
+        secret: Option<&str>,
+        last: &Option<i32>,
+        limit: &Option<i64>,
+    ) -> Option<(types::GeoJSON, i32)> {
+        let mut returnable = types::GeoJSON::new();
+        let check_for_new = self.0.prepare_cached(
+            r"SELECT id, t, lat, long, spd, ele FROM geohub.geodata
+            WHERE (client = $1) and (id > $2) AND (secret = public.digest($3, 'sha256') or secret is null)
+            ORDER BY id DESC
+            LIMIT $4").unwrap(); // Must succeed.
+
+        let last = last.unwrap_or(0);
+        let limit = limit.unwrap_or(256);
 
-            for row in rows.iter() {
-                let (id, ts, lat, long, spd, ele) = (
-                    row.get(0),
-                    row.get(1),
-                    row.get(2),
-                    row.get(3),
-                    row.get(4),
-                    row.get(5),
-                );
-                returnable.push_feature(types::geofeature_from_row(ts, lat, long, spd, ele));
-                if id > last {
-                    last = id;
+        let rows = check_for_new.query(&[&name, &last, &secret, &limit]);
+        if let Ok(rows) = rows {
+            // If there are unknown entries, return those.
+            if rows.len() > 0 {
+                returnable.reserve_features(rows.len());
+                let mut last = 0;
+
+                for row in rows.iter() {
+                    let (id, ts, lat, long, spd, ele) = (
+                        row.get(0),
+                        row.get(1),
+                        row.get(2),
+                        row.get(3),
+                        row.get(4),
+                        row.get(5),
+                    );
+                    returnable.push_feature(types::geofeature_from_row(ts, lat, long, spd, ele));
+                    if id > last {
+                        last = id;
+                    }
                 }
+
+                return Some((returnable, last));
             }
-
-            return Some((returnable, last));
+            return None;
+        } else {
+            // For debugging.
+            rows.unwrap();
         }
         return None;
-    } else {
-        // For debugging.
-        rows.unwrap();
     }
-    return None;
 }
--- a/src/main.rs	Thu Dec 03 07:51:38 2020 +0100
+++ b/src/main.rs	Thu Dec 03 08:12:25 2020 +0100
@@ -12,26 +12,20 @@
 use postgres;
 use rocket;
 
-#[rocket_contrib::database("geohub")]
-struct DBConn(postgres::Connection);
-
 /// Almost like retrieve/json, but sorts in descending order and doesn't work with intervals (only
 /// limit). Used for backfilling recent points in the UI.
 #[rocket::get("/geo/<name>/retrieve/last?<secret>&<last>&<limit>")]
 fn retrieve_last(
-    db: DBConn,
+    db: db::DBConn,
     name: String,
     secret: Option<String>,
     last: Option<i32>,
     limit: Option<i64>,
 ) -> rocket_contrib::json::Json<LiveUpdate> {
-    if let Some((geojson, newlast)) = db::check_for_new_rows(
-        &db.0,
-        &name,
-        secret.as_ref().map(|s| s.as_str()),
-        &last,
-        &limit,
-    ) {
+    let db = db::DBQuery(&db.0);
+    if let Some((geojson, newlast)) =
+        db.check_for_new_rows(&name, secret.as_ref().map(|s| s.as_str()), &last, &limit)
+    {
         return rocket_contrib::json::Json(LiveUpdate {
             typ: "GeoHubUpdate".into(),
             last: Some(newlast),
@@ -106,13 +100,14 @@
 /// Retrieve GeoJSON data.
 #[rocket::get("/geo/<name>/retrieve/json?<secret>&<from>&<to>&<limit>")]
 fn retrieve_json(
-    db: DBConn,
+    db: db::DBConn,
     name: String,
     secret: Option<String>,
     from: Option<String>,
     to: Option<String>,
     limit: Option<i64>,
 ) -> rocket_contrib::json::Json<types::GeoJSON> {
+    let db = db::DBQuery(&db.0);
     let from_ts =
         from.and_then(util::flexible_timestamp_parse)
             .unwrap_or(chrono::DateTime::from_utc(
@@ -123,24 +118,14 @@
         .and_then(util::flexible_timestamp_parse)
         .unwrap_or(chrono::Utc::now());
     let limit = limit.unwrap_or(16384);
+    let secret = secret.as_ref().map(|s| s.as_str()).unwrap_or("");
 
-    let mut returnable = types::GeoJSON::new();
-    let stmt = db.0.prepare_cached(
-        r"SELECT t, lat, long, spd, ele FROM geohub.geodata
-        WHERE (client = $1) and (t between $2 and $3) AND (secret = public.digest($4, 'sha256') or secret is null)
-        ORDER BY t ASC
-        LIMIT $5").unwrap(); // Must succeed.
-    let rows = stmt.query(&[&name, &from_ts, &to_ts, &secret, &limit]);
-    if let Ok(rows) = rows {
-        returnable.reserve_features(rows.len());
-        for row in rows.iter() {
-            let (ts, lat, long, spd, ele) =
-                (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
-            returnable.push_feature(types::geofeature_from_row(ts, lat, long, spd, ele));
-        }
+    if let Ok(json) = db.retrieve_json(name.as_str(), from_ts, to_ts, secret, limit) {
+        return rocket_contrib::json::Json(json);
     }
 
-    rocket_contrib::json::Json(returnable)
+    // Todo: Use custom database error return
+    rocket_contrib::json::Json(types::GeoJSON::new())
 }
 
 /// Ingest geo data.
@@ -149,7 +134,7 @@
 /// secret can be used to protect points.
 #[rocket::post("/geo/<name>/log?<lat>&<longitude>&<time>&<s>&<ele>&<secret>")]
 fn log(
-    db: DBConn,
+    db: db::DBConn,
     name: String,
     lat: f64,
     longitude: f64,
@@ -196,7 +181,7 @@
     };
 
     rocket::ignite()
-        .attach(DBConn::fairing())
+        .attach(db::DBConn::fairing())
         .manage(send)
         .attach(rocket::fairing::AdHoc::on_attach(
             "Database Notifications",
--- a/src/notifier.rs	Thu Dec 03 07:51:38 2020 +0100
+++ b/src/notifier.rs	Thu Dec 03 08:12:25 2020 +0100
@@ -40,6 +40,7 @@
     const TICK_MILLIS: u32 = 500;
 
     let mut clients: HashMap<String, Vec<NotifyRequest>> = HashMap::new();
+    let db = db::DBQuery(&db);
 
     fn listen(db: &postgres::Connection, client: &str, secret: &str) -> postgres::Result<u64> {
         let n = db
@@ -65,7 +66,7 @@
                 let secret = nrq.secret.as_ref().map(|s| s.as_str()).unwrap_or("");
                 let chan_name = ids::channel_name(nrq.client.as_str(), secret);
                 if !clients.contains_key(&chan_name) {
-                    listen(&db, &nrq.client, secret).ok();
+                    listen(db.0, &nrq.client, secret).ok();
                 }
                 clients.entry(chan_name).or_insert(vec![]).push(nrq);
             } else {
@@ -75,18 +76,18 @@
 
         // Drain notifications from the database.
         // Also provide updated rows to the client.
-        let notifications = db.notifications();
+        let notifications = db.0.notifications();
         let mut iter = notifications.timeout_iter(time::Duration::new(0, TICK_MILLIS * 1_000_000));
         let mut count = 0;
 
         while let Ok(Some(notification)) = iter.next() {
             let chan = notification.channel;
             let (client, secret) = ids::client_secret(chan.as_str());
-            unlisten(&db, &chan).ok();
+            unlisten(db.0, &chan).ok();
 
             // These queries use the primary key index returning one row only and will be quite fast.
             // Still: One query per client.
-            let rows = db::check_for_new_rows(&db, client, Some(secret), &None, &Some(1));
+            let rows = db.check_for_new_rows(client, Some(secret), &None, &Some(1));
             if let Some((geo, last)) = rows {
                 for request in clients.remove(&chan).unwrap_or(vec![]) {
                     request