changeset 30:f697356d93ae

Split up code into modules
author Lewin Bormann <lbo@spheniscida.de>
date Thu, 03 Dec 2020 00:01:48 +0100
parents 86704f1e624d
children 7559c75bb43c
files TODO src/db.rs src/ids.rs src/main.rs src/notifier.rs src/types.rs
diffstat 6 files changed, 282 insertions(+), 267 deletions(-) [+]
line wrap: on
line diff
--- a/TODO	Wed Dec 02 23:25:34 2020 +0100
+++ b/TODO	Thu Dec 03 00:01:48 2020 +0100
@@ -1,5 +1,7 @@
 GENERAL
 
+* Proper HTTP status for invalid secret/client
+
 FEATURES
 
 * GPX/json export (with UI + API)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/db.rs	Thu Dec 03 00:01:48 2020 +0100
@@ -0,0 +1,64 @@
+use crate::types;
+
+/// Queries for at most `limit` rows since entry ID `last`.
+pub fn check_for_new_rows(
+    db: &postgres::Connection,
+    name: &str,
+    secret: Option<&str>,
+    last: &Option<i32>,
+    limit: &Option<i64>,
+) -> Option<(types::GeoJSON, i32)> {
+    let mut returnable = types::GeoJSON {
+        typ: "FeatureCollection".into(),
+        features: vec![],
+    };
+    let check_for_new = db.prepare_cached(
+        r"SELECT id, t, lat, long, spd, ele FROM geohub.geodata
+        WHERE (client = $1) and (id > $2) AND (secret = public.digest($3, 'sha256') or secret is null)
+        ORDER BY id DESC
+        LIMIT $4").unwrap(); // Must succeed.
+
+    let last = last.unwrap_or(0);
+    let limit = limit.unwrap_or(256);
+
+    let rows = check_for_new.query(&[&name, &last, &secret, &limit]);
+    if let Ok(rows) = rows {
+        // If there are unknown entries, return those.
+        if rows.len() > 0 {
+            returnable.features = Vec::with_capacity(rows.len());
+            let mut last = 0;
+
+            for row in rows.iter() {
+                let (id, ts, lat, long, spd, ele): (
+                    i32,
+                    chrono::DateTime<chrono::Utc>,
+                    Option<f64>,
+                    Option<f64>,
+                    Option<f64>,
+                    Option<f64>,
+                ) = (
+                    row.get(0),
+                    row.get(1),
+                    row.get(2),
+                    row.get(3),
+                    row.get(4),
+                    row.get(5),
+                );
+                returnable
+                    .features
+                    .push(types::geofeature_from_row(ts, lat, long, spd, ele));
+                if id > last {
+                    last = id;
+                }
+            }
+
+            return Some((returnable, last));
+        }
+        return None;
+    } else {
+        // For debugging.
+        rows.unwrap();
+    }
+    return None;
+}
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/ids.rs	Thu Dec 03 00:01:48 2020 +0100
@@ -0,0 +1,24 @@
+/// Check if client name and secret are acceptable.
+pub fn name_and_secret_acceptable(client: &str, secret: Option<&str>) -> bool {
+    !(client.chars().any(|c| !c.is_ascii_alphanumeric())
+        || secret
+            .unwrap_or("")
+            .chars()
+            .any(|c| !c.is_ascii_alphanumeric()))
+}
+
+/// Build a channel name from a client name and secret.
+pub fn channel_name(client: &str, secret: &str) -> String {
+    // The log handler should check this.
+    assert!(secret.find('_').is_none());
+    format!("geohubclient_update_{}_{}", client, secret)
+}
+
+/// Extract client name and secret from the database channel name.
+pub fn client_secret(channel_name: &str) -> (&str, &str) {
+    // Channel name is like geohubclient_update_<client>_<secret>
+    let parts = channel_name.split('_').collect::<Vec<&str>>();
+    assert!(parts.len() == 4);
+    return (parts[2], parts[3]);
+}
+
--- a/src/main.rs	Wed Dec 02 23:25:34 2020 +0100
+++ b/src/main.rs	Thu Dec 03 00:01:48 2020 +0100
@@ -1,17 +1,18 @@
 #![feature(proc_macro_hygiene, decl_macro)]
 
-use std::collections::HashMap;
+mod db;
+mod ids;
+mod notifier;
+mod types;
+
+use std::time;
 use std::sync::{mpsc, Arc, Mutex};
-use std::thread;
-use std::time;
 
 use postgres;
 use rocket;
 
 use chrono::TimeZone;
 
-use fallible_iterator::FallibleIterator;
-use std::iter::Iterator;
 
 #[rocket_contrib::database("geohub")]
 struct DBConn(postgres::Connection);
@@ -39,129 +40,15 @@
     None
 }
 
-/// Fetch geodata as JSON.
-///
-#[derive(serde::Serialize, Debug, Clone)]
-struct GeoProperties {
-    time: chrono::DateTime<chrono::Utc>,
-    altitude: Option<f64>,
-    speed: Option<f64>,
-}
-
-#[derive(serde::Serialize, Debug, Clone)]
-struct GeoGeometry {
-    #[serde(rename = "type")]
-    typ: String, // always "Point"
-    coordinates: Vec<f64>, // always [long, lat]
-}
-
-#[derive(serde::Serialize, Debug, Clone)]
-struct GeoFeature {
-    #[serde(rename = "type")]
-    typ: String, // always "Feature"
-    properties: GeoProperties,
-    geometry: GeoGeometry,
-}
-
-fn geofeature_from_row(
-    ts: chrono::DateTime<chrono::Utc>,
-    lat: Option<f64>,
-    long: Option<f64>,
-    spd: Option<f64>,
-    ele: Option<f64>,
-) -> GeoFeature {
-    GeoFeature {
-        typ: "Feature".into(),
-        properties: GeoProperties {
-            time: ts,
-            altitude: ele,
-            speed: spd,
-        },
-        geometry: GeoGeometry {
-            typ: "Point".into(),
-            coordinates: vec![long.unwrap_or(0.), lat.unwrap_or(0.)],
-        },
-    }
-}
-
-#[derive(serde::Serialize, Debug, Clone)]
-struct GeoJSON {
-    #[serde(rename = "type")]
-    typ: String, // always "FeatureCollection"
-    features: Vec<GeoFeature>,
-}
-
 #[derive(serde::Serialize, Debug)]
 struct LiveUpdate {
     #[serde(rename = "type")]
     typ: String, // always "GeoHubUpdate"
     last: Option<i32>, // page token -- send in next request!
-    geo: Option<GeoJSON>,
+    geo: Option<types::GeoJSON>,
     error: Option<String>,
 }
 
-/// Queries for at most `limit` rows since entry ID `last`.
-fn check_for_new_rows(
-    db: &postgres::Connection,
-    name: &str,
-    secret: Option<&str>,
-    last: &Option<i32>,
-    limit: &Option<i64>,
-) -> Option<(GeoJSON, i32)> {
-    let mut returnable = GeoJSON {
-        typ: "FeatureCollection".into(),
-        features: vec![],
-    };
-    let check_for_new = db.prepare_cached(
-        r"SELECT id, t, lat, long, spd, ele FROM geohub.geodata
-        WHERE (client = $1) and (id > $2) AND (secret = public.digest($3, 'sha256') or secret is null)
-        ORDER BY id DESC
-        LIMIT $4").unwrap(); // Must succeed.
-
-    let last = last.unwrap_or(0);
-    let limit = limit.unwrap_or(256);
-
-    let rows = check_for_new.query(&[&name, &last, &secret, &limit]);
-    if let Ok(rows) = rows {
-        // If there are unknown entries, return those.
-        if rows.len() > 0 {
-            returnable.features = Vec::with_capacity(rows.len());
-            let mut last = 0;
-
-            for row in rows.iter() {
-                let (id, ts, lat, long, spd, ele): (
-                    i32,
-                    chrono::DateTime<chrono::Utc>,
-                    Option<f64>,
-                    Option<f64>,
-                    Option<f64>,
-                    Option<f64>,
-                ) = (
-                    row.get(0),
-                    row.get(1),
-                    row.get(2),
-                    row.get(3),
-                    row.get(4),
-                    row.get(5),
-                );
-                returnable
-                    .features
-                    .push(geofeature_from_row(ts, lat, long, spd, ele));
-                if id > last {
-                    last = id;
-                }
-            }
-
-            return Some((returnable, last));
-        }
-        return None;
-    } else {
-        // For debugging.
-        rows.unwrap();
-    }
-    return None;
-}
-
 /// Almost like retrieve/json, but sorts in descending order and doesn't work with intervals (only
 /// limit). Used for backfilling recent points in the UI.
 #[rocket::get("/geo/<name>/retrieve/last?<secret>&<last>&<limit>")]
@@ -172,7 +59,7 @@
     last: Option<i32>,
     limit: Option<i64>,
 ) -> rocket_contrib::json::Json<LiveUpdate> {
-    if let Some((geojson, newlast)) = check_for_new_rows(
+    if let Some((geojson, newlast)) = db::check_for_new_rows(
         &db.0,
         &name,
         secret.as_ref().map(|s| s.as_str()),
@@ -198,12 +85,12 @@
 /// Only one point is returned. To retrieve a history of points, call retrieve_last.
 #[rocket::get("/geo/<name>/retrieve/live?<secret>&<timeout>")]
 fn retrieve_live(
-    notify_manager: rocket::State<SendableSender<NotifyRequest>>,
+    notify_manager: rocket::State<notifier::SendableSender<notifier::NotifyRequest>>,
     name: String,
     secret: Option<String>,
     timeout: Option<u64>,
-) -> rocket_contrib::json::Json<LiveUpdate> {
-    if !name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
+) -> rocket_contrib::json::Json<LiveUpdate>  {
+    if !ids::name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
         return rocket_contrib::json::Json(LiveUpdate {
             typ: "GeoHubUpdate".into(),
             last: None,
@@ -214,11 +101,11 @@
 
     // Ask the notify thread to tell us when there is an update for this client name and secret.
     let (send, recv) = mpsc::channel();
-    let send = SendableSender {
+    let send = notifier::SendableSender {
         sender: Arc::new(Mutex::new(send)),
     };
 
-    let req = NotifyRequest {
+    let req = notifier::NotifyRequest {
         client: name.clone(),
         secret: secret,
         respond: send,
@@ -250,8 +137,8 @@
     from: Option<String>,
     to: Option<String>,
     limit: Option<i64>,
-) -> rocket_contrib::json::Json<GeoJSON> {
-    let mut returnable = GeoJSON {
+) -> rocket_contrib::json::Json<types::GeoJSON> {
+    let mut returnable = types::GeoJSON {
         typ: "FeatureCollection".into(),
         features: vec![],
     };
@@ -285,7 +172,7 @@
             ) = (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
             returnable
                 .features
-                .push(geofeature_from_row(ts, lat, long, spd, ele));
+                .push(types::geofeature_from_row(ts, lat, long, spd, ele));
         }
     }
 
@@ -308,7 +195,7 @@
     ele: Option<f64>,
 ) -> rocket::http::Status {
     // Check that secret and client name are legal.
-    if !name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
+    if !ids::name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
         return rocket::http::Status::NotAcceptable;
     }
     let mut ts = chrono::Utc::now();
@@ -318,7 +205,7 @@
     let stmt = db.0.prepare_cached("INSERT INTO geohub.geodata (client, lat, long, spd, t, ele, secret) VALUES ($1, $2, $3, $4, $5, $6, public.digest($7, 'sha256'))").unwrap();
     let channel = format!(
         "NOTIFY {}, '{}'",
-        channel_name(name.as_str(), secret.as_ref().unwrap_or(&"".into())),
+        ids::channel_name(name.as_str(), secret.as_ref().unwrap_or(&"".into())),
         name
     );
     let notify = db.0.prepare_cached(channel.as_str()).unwrap();
@@ -338,143 +225,9 @@
         .map_err(|e| rocket::response::status::NotFound(e.to_string()))
 }
 
-/// Request of a web client thread to the notifier thread.
-struct NotifyRequest {
-    client: String,
-    secret: Option<String>,
-    respond: SendableSender<NotifyResponse>,
-}
-
-/// Response from the notifier thread to a web client thread.
-struct NotifyResponse {
-    // The GeoJSON object containing the update and the `last` page token.
-    geo: Option<GeoJSON>,
-    last: Option<i32>,
-}
-
-/// A `Send` sender.
-#[derive(Clone)]
-struct SendableSender<T> {
-    sender: Arc<Mutex<mpsc::Sender<T>>>,
-}
-
-impl<T> SendableSender<T> {
-    fn send(&self, arg: T) -> Result<(), mpsc::SendError<T>> {
-        let s = self.sender.lock().unwrap();
-        s.send(arg)
-    }
-}
-
-/// Check if client name and secret are acceptable.
-fn name_and_secret_acceptable(client: &str, secret: Option<&str>) -> bool {
-    !(client.chars().any(|c| !c.is_ascii_alphanumeric())
-        || secret
-            .unwrap_or("")
-            .chars()
-            .any(|c| !c.is_ascii_alphanumeric()))
-}
-
-/// Build a channel name from a client name and secret.
-fn channel_name(client: &str, secret: &str) -> String {
-    // The log handler should check this.
-    assert!(secret.find('_').is_none());
-    format!("geohubclient_update_{}_{}", client, secret)
-}
-
-/// Extract client name and secret from the database channel name.
-fn client_secret(channel_name: &str) -> (&str, &str) {
-    // Channel name is like geohubclient_update_<client>_<secret>
-    let parts = channel_name.split('_').collect::<Vec<&str>>();
-    assert!(parts.len() == 4);
-    return (parts[2], parts[3]);
-}
-
-/// Listen for notifications in the database and dispatch to waiting clients.
-fn live_notifier_thread(rx: mpsc::Receiver<NotifyRequest>, db: postgres::Connection) {
-    const TICK_MILLIS: u32 = 500;
-
-    let mut clients: HashMap<String, Vec<NotifyRequest>> = HashMap::new();
-
-    fn listen(db: &postgres::Connection, client: &str, secret: &str) -> postgres::Result<u64> {
-        let n = db
-            .execute(
-                &format!("LISTEN {}", channel_name(client, secret).as_str()),
-                &[],
-            )
-            .unwrap();
-        Ok(n)
-    }
-    fn unlisten(db: &postgres::Connection, chan: &str) -> postgres::Result<u64> {
-        let n = db.execute(&format!("UNLISTEN {}", chan), &[]).unwrap();
-        Ok(n)
-    }
-
-    loop {
-        // This loop checks for new messages on rx, then checks for new database notifications, etc.
-
-        // Drain notification requests (clients asking to watch for notifications).
-        // We listen per client and secret to separate clients with different sessions (by secret).
-        loop {
-            if let Ok(nrq) = rx.try_recv() {
-                let secret = nrq.secret.as_ref().map(|s| s.as_str()).unwrap_or("");
-                let chan_name = channel_name(nrq.client.as_str(), secret);
-                if !clients.contains_key(&chan_name) {
-                    listen(&db, &nrq.client, secret).ok();
-                }
-                clients.entry(chan_name).or_insert(vec![]).push(nrq);
-            } else {
-                break;
-            }
-        }
-
-        // Drain notifications from the database.
-        // Also provide updated rows to the client.
-        let notifications = db.notifications();
-        let mut iter = notifications.timeout_iter(time::Duration::new(0, TICK_MILLIS * 1_000_000));
-        let mut count = 0;
-
-        while let Ok(Some(notification)) = iter.next() {
-            let chan = notification.channel;
-            let (client, secret) = client_secret(chan.as_str());
-            unlisten(&db, &chan).ok();
-
-            // These queries use the primary key index returning one row only and will be quite fast.
-            // Still: One query per client.
-            let rows = check_for_new_rows(&db, client, Some(secret), &None, &Some(1));
-            if let Some((geo, last)) = rows {
-                for request in clients.remove(&chan).unwrap_or(vec![]) {
-                    request
-                        .respond
-                        .send(NotifyResponse {
-                            geo: Some(geo.clone()),
-                            last: Some(last),
-                        })
-                        .ok();
-                }
-            } else {
-                for request in clients.remove(&chan).unwrap_or(vec![]) {
-                    request
-                        .respond
-                        .send(NotifyResponse {
-                            geo: None,
-                            last: None,
-                        })
-                        .ok();
-                }
-            }
-
-            // We also need to receive new notification requests.
-            count += 1;
-            if count > 3 {
-                break;
-            }
-        }
-    }
-}
-
 fn main() {
     let (send, recv) = mpsc::channel();
-    let send = SendableSender {
+    let send = notifier::SendableSender {
         sender: Arc::new(Mutex::new(send)),
     };
 
@@ -488,7 +241,7 @@
                     rocket_contrib::databases::database_config("geohub", &rocket.config()).unwrap();
                 let url = dbconfig.url;
                 let conn = postgres::Connection::connect(url, postgres::TlsMode::None).unwrap();
-                thread::spawn(move || live_notifier_thread(recv, conn));
+                std::thread::spawn(move || notifier::live_notifier_thread(recv, conn));
                 Ok(rocket)
             },
         ))
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/notifier.rs	Thu Dec 03 00:01:48 2020 +0100
@@ -0,0 +1,120 @@
+
+use crate::db;
+use crate::ids;
+use crate::types;
+
+use std::collections::HashMap;
+use std::sync::{mpsc, Arc, Mutex};
+use std::time;
+use fallible_iterator::FallibleIterator;
+
+/// Request of a web client thread to the notifier thread.
+pub struct NotifyRequest {
+    pub client: String,
+    pub secret: Option<String>,
+    pub respond: SendableSender<NotifyResponse>,
+}
+
+/// Response from the notifier thread to a web client thread.
+pub struct NotifyResponse {
+    // The GeoJSON object containing the update and the `last` page token.
+    pub geo: Option<types::GeoJSON>,
+    pub last: Option<i32>,
+}
+
+/// A `Send` sender.
+#[derive(Clone)]
+pub struct SendableSender<T> {
+    pub sender: Arc<Mutex<mpsc::Sender<T>>>,
+}
+
+impl<T> SendableSender<T> {
+    pub fn send(&self, arg: T) -> Result<(), mpsc::SendError<T>> {
+        let s = self.sender.lock().unwrap();
+        s.send(arg)
+    }
+}
+
+/// Listen for notifications in the database and dispatch to waiting clients.
+pub fn live_notifier_thread(rx: mpsc::Receiver<NotifyRequest>, db: postgres::Connection) {
+    const TICK_MILLIS: u32 = 500;
+
+    let mut clients: HashMap<String, Vec<NotifyRequest>> = HashMap::new();
+
+    fn listen(db: &postgres::Connection, client: &str, secret: &str) -> postgres::Result<u64> {
+        let n = db
+            .execute(
+                &format!("LISTEN {}", ids::channel_name(client, secret).as_str()),
+                &[],
+            )
+            .unwrap();
+        Ok(n)
+    }
+    fn unlisten(db: &postgres::Connection, chan: &str) -> postgres::Result<u64> {
+        let n = db.execute(&format!("UNLISTEN {}", chan), &[]).unwrap();
+        Ok(n)
+    }
+
+    loop {
+        // This loop checks for new messages on rx, then checks for new database notifications, etc.
+
+        // Drain notification requests (clients asking to watch for notifications).
+        // We listen per client and secret to separate clients with different sessions (by secret).
+        loop {
+            if let Ok(nrq) = rx.try_recv() {
+                let secret = nrq.secret.as_ref().map(|s| s.as_str()).unwrap_or("");
+                let chan_name = ids::channel_name(nrq.client.as_str(), secret);
+                if !clients.contains_key(&chan_name) {
+                    listen(&db, &nrq.client, secret).ok();
+                }
+                clients.entry(chan_name).or_insert(vec![]).push(nrq);
+            } else {
+                break;
+            }
+        }
+
+        // Drain notifications from the database.
+        // Also provide updated rows to the client.
+        let notifications = db.notifications();
+        let mut iter = notifications.timeout_iter(time::Duration::new(0, TICK_MILLIS * 1_000_000));
+        let mut count = 0;
+
+        while let Ok(Some(notification)) = iter.next() {
+            let chan = notification.channel;
+            let (client, secret) = ids::client_secret(chan.as_str());
+            unlisten(&db, &chan).ok();
+
+            // These queries use the primary key index returning one row only and will be quite fast.
+            // Still: One query per client.
+            let rows = db::check_for_new_rows(&db, client, Some(secret), &None, &Some(1));
+            if let Some((geo, last)) = rows {
+                for request in clients.remove(&chan).unwrap_or(vec![]) {
+                    request
+                        .respond
+                        .send(NotifyResponse {
+                            geo: Some(geo.clone()),
+                            last: Some(last),
+                        })
+                        .ok();
+                }
+            } else {
+                for request in clients.remove(&chan).unwrap_or(vec![]) {
+                    request
+                        .respond
+                        .send(NotifyResponse {
+                            geo: None,
+                            last: None,
+                        })
+                        .ok();
+                }
+            }
+
+            // We also need to receive new notification requests.
+            count += 1;
+            if count > 3 {
+                break;
+            }
+        }
+    }
+}
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/types.rs	Thu Dec 03 00:01:48 2020 +0100
@@ -0,0 +1,52 @@
+
+/// Fetch geodata as JSON.
+///
+#[derive(serde::Serialize, Debug, Clone)]
+pub struct GeoProperties {
+    pub time: chrono::DateTime<chrono::Utc>,
+    pub altitude: Option<f64>,
+    pub speed: Option<f64>,
+}
+
+#[derive(serde::Serialize, Debug, Clone)]
+pub struct GeoGeometry {
+    #[serde(rename = "type")]
+    pub typ: String, // always "Point"
+    pub coordinates: Vec<f64>, // always [long, lat]
+}
+
+#[derive(serde::Serialize, Debug, Clone)]
+pub struct GeoFeature {
+    #[serde(rename = "type")]
+    pub typ: String, // always "Feature"
+    pub properties: GeoProperties,
+    pub geometry: GeoGeometry,
+}
+
+pub fn geofeature_from_row(
+    ts: chrono::DateTime<chrono::Utc>,
+    lat: Option<f64>,
+    long: Option<f64>,
+    spd: Option<f64>,
+    ele: Option<f64>,
+) -> GeoFeature {
+    GeoFeature {
+        typ: "Feature".into(),
+        properties: GeoProperties {
+            time: ts,
+            altitude: ele,
+            speed: spd,
+        },
+        geometry: GeoGeometry {
+            typ: "Point".into(),
+            coordinates: vec![long.unwrap_or(0.), lat.unwrap_or(0.)],
+        },
+    }
+}
+
+#[derive(serde::Serialize, Debug, Clone)]
+pub struct GeoJSON {
+    #[serde(rename = "type")]
+    pub typ: String, // always "FeatureCollection"
+    pub features: Vec<GeoFeature>,
+}