changeset 27:524180b9fa0f

Make live updates more efficient
author Lewin Bormann <lbo@spheniscida.de>
date Wed, 02 Dec 2020 23:14:01 +0100
parents d3a30e219b9b
children bf2b762e8584
files src/main.rs
diffstat 1 files changed, 94 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/src/main.rs	Wed Dec 02 22:08:44 2020 +0100
+++ b/src/main.rs	Wed Dec 02 23:14:01 2020 +0100
@@ -41,21 +41,21 @@
 
 /// Fetch geodata as JSON.
 ///
-#[derive(serde::Serialize, Debug)]
+#[derive(serde::Serialize, Debug, Clone)]
 struct GeoProperties {
     time: chrono::DateTime<chrono::Utc>,
     altitude: Option<f64>,
     speed: Option<f64>,
 }
 
-#[derive(serde::Serialize, Debug)]
+#[derive(serde::Serialize, Debug, Clone)]
 struct GeoGeometry {
     #[serde(rename = "type")]
     typ: String, // always "Point"
     coordinates: Vec<f64>, // always [long, lat]
 }
 
-#[derive(serde::Serialize, Debug)]
+#[derive(serde::Serialize, Debug, Clone)]
 struct GeoFeature {
     #[serde(rename = "type")]
     typ: String, // always "Feature"
@@ -84,7 +84,7 @@
     }
 }
 
-#[derive(serde::Serialize, Debug)]
+#[derive(serde::Serialize, Debug, Clone)]
 struct GeoJSON {
     #[serde(rename = "type")]
     typ: String, // always "FeatureCollection"
@@ -97,13 +97,14 @@
     typ: String, // always "GeoHubUpdate"
     last: Option<i32>, // page token -- send in next request!
     geo: Option<GeoJSON>,
+    error: Option<String>,
 }
 
 /// Queries for at most `limit` rows since entry ID `last`.
 fn check_for_new_rows(
     db: &postgres::Connection,
-    name: &String,
-    secret: &Option<String>,
+    name: &str,
+    secret: Option<&str>,
     last: &Option<i32>,
     limit: &Option<i64>,
 ) -> Option<(GeoJSON, i32)> {
@@ -171,17 +172,25 @@
     last: Option<i32>,
     limit: Option<i64>,
 ) -> rocket_contrib::json::Json<LiveUpdate> {
-    if let Some((geojson, newlast)) = check_for_new_rows(&db.0, &name, &secret, &last, &limit) {
+    if let Some((geojson, newlast)) = check_for_new_rows(
+        &db.0,
+        &name,
+        secret.as_ref().map(|s| s.as_str()),
+        &last,
+        &limit,
+    ) {
         return rocket_contrib::json::Json(LiveUpdate {
             typ: "GeoHubUpdate".into(),
             last: Some(newlast),
             geo: Some(geojson),
+            error: None,
         });
     }
     return rocket_contrib::json::Json(LiveUpdate {
         typ: "GeoHubUpdate".into(),
         last: last,
         geo: None,
+        error: Some("No new rows returned".into()),
     });
 }
 /// Wait for an update.
@@ -193,6 +202,14 @@
     secret: Option<String>,
     timeout: Option<u64>,
 ) -> rocket_contrib::json::Json<LiveUpdate> {
+    if !name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
+        return rocket_contrib::json::Json(LiveUpdate {
+            typ: "GeoHubUpdate".into(),
+            last: None,
+            geo: None,
+            error: Some("You have supplied an invalid secret or name. Both must be ASCII alphanumeric strings.".into()),
+        });
+    }
     let (send, recv) = mpsc::channel();
     let send = SendableSender {
         sender: Arc::new(Mutex::new(send)),
@@ -211,12 +228,14 @@
             typ: "GeoHubUpdate".into(),
             last: response.last,
             geo: response.geo,
+            error: None,
         });
     }
     return rocket_contrib::json::Json(LiveUpdate {
         typ: "GeoHubUpdate".into(),
         last: None,
         geo: None,
+        error: Some("No new rows returned".into()),
     });
 }
 
@@ -286,7 +305,8 @@
     s: Option<f64>,
     ele: Option<f64>,
 ) -> rocket::http::Status {
-    if name.chars().any(|c| !c.is_alphanumeric()) {
+    // Check that secret and client name are legal.
+    if !name_and_secret_acceptable(name.as_str(), secret.as_ref().map(|s| s.as_str())) {
         return rocket::http::Status::NotAcceptable;
     }
     let mut ts = chrono::Utc::now();
@@ -294,9 +314,13 @@
         ts = flexible_timestamp_parse(time).unwrap_or(ts);
     }
     let stmt = db.0.prepare_cached("INSERT INTO geohub.geodata (client, lat, long, spd, t, ele, secret) VALUES ($1, $2, $3, $4, $5, $6, public.digest($7, 'sha256'))").unwrap();
-    let notify =
-        db.0.prepare_cached(format!("NOTIFY geohubclient_update_{}, '{}'", name, name).as_str())
-            .unwrap();
+    let channel = format!(
+        "NOTIFY {}, '{}'",
+        channel_name(name.as_str(), secret.as_ref().unwrap_or(&"".into())),
+        name
+    );
+    eprintln!("notifying channel {}", channel);
+    let notify = db.0.prepare_cached(channel.as_str()).unwrap();
     stmt.execute(&[&name, &lat, &longitude, &s, &ts, &ele, &secret])
         .unwrap();
     notify.execute(&[]).unwrap();
@@ -339,16 +363,45 @@
     }
 }
 
+fn name_and_secret_acceptable(client: &str, secret: Option<&str>) -> bool {
+    !(client.chars().any(|c| !c.is_ascii_alphanumeric())
+        || secret
+            .unwrap_or("")
+            .chars()
+            .any(|c| !c.is_ascii_alphanumeric()))
+}
+
+fn channel_name(client: &str, secret: &str) -> String {
+    // The log handler should check this.
+    assert!(secret.find('_').is_none());
+    format!("geohubclient_update_{}_{}", client, secret)
+}
+fn client_secret(channel_name: &str) -> (&str, &str) {
+    // Channel name is like geohubclient_update_<client>_<secret>
+    let parts = channel_name.split('_').collect::<Vec<&str>>();
+    assert!(parts.len() == 4);
+    return (parts[2], parts[3]);
+}
+
 fn live_notifier_thread(rx: mpsc::Receiver<NotifyRequest>, db: postgres::Connection) {
     const TICK_MILLIS: u32 = 500;
 
     let mut clients: HashMap<String, Vec<NotifyRequest>> = HashMap::new();
 
-    fn listen(db: &postgres::Connection, client: &str) -> postgres::Result<u64> {
-        db.execute(&format!("LISTEN geohubclient_update_{}", client), &[])
+    fn listen(db: &postgres::Connection, client: &str, secret: &str) -> postgres::Result<u64> {
+        eprintln!("listening on channel {}", channel_name(client, secret));
+        let n = db
+            .execute(
+                &format!("LISTEN {}", channel_name(client, secret).as_str()),
+                &[],
+            )
+            .unwrap();
+        Ok(n)
     }
-    fn unlisten(db: &postgres::Connection, client: &str) -> postgres::Result<u64> {
-        db.execute(&format!("UNLISTEN geohubclient_update_{}", client), &[])
+    fn unlisten(db: &postgres::Connection, chan: &str) -> postgres::Result<u64> {
+        eprintln!("unlistening on channel {}", chan);
+        let n = db.execute(&format!("UNLISTEN {}", chan), &[]).unwrap();
+        Ok(n)
     }
 
     eprintln!("Notification thread running.");
@@ -356,15 +409,15 @@
         // This loop checks for new messages on rx, then checks for new database notifications, etc.
 
         // Drain notification requests (clients asking to watch for notifications).
+        // We listen per client and secret to separate clients with different sessions (by secret).
         loop {
             if let Ok(nrq) = rx.try_recv() {
-                if !clients.contains_key(&nrq.client) {
-                    listen(&db, &nrq.client).ok();
+                let secret = nrq.secret.as_ref().map(|s| s.as_str()).unwrap_or("");
+                let chan_name = channel_name(nrq.client.as_str(), secret);
+                if !clients.contains_key(&chan_name) {
+                    listen(&db, &nrq.client, secret).ok();
                 }
-                clients
-                    .entry(nrq.client.clone())
-                    .or_insert(vec![])
-                    .push(nrq);
+                clients.entry(chan_name).or_insert(vec![]).push(nrq);
             } else {
                 break;
             }
@@ -375,29 +428,37 @@
         let notifications = db.notifications();
         let mut iter = notifications.timeout_iter(time::Duration::new(0, TICK_MILLIS * 1_000_000));
         let mut count = 0;
-        while let Ok(Some(notification)) = iter.next() {
-            let payload = notification.payload;
-            unlisten(&db, &payload).ok();
 
-            // One query per listening client as secrets may be different.
+        while let Ok(Some(notification)) = iter.next() {
+            let chan = notification.channel;
+            let (client, secret) = client_secret(chan.as_str());
+            unlisten(&db, &chan).ok();
+            eprintln!(
+                "looking at channel {} client {} secret {}",
+                chan, client, secret
+            );
+
             // These queries use the primary key index returning one row only and will be quite fast.
-            for request in clients.remove(&payload).unwrap_or(vec![]) {
-                if let Some((geo, last)) =
-                    check_for_new_rows(&db, &payload, &request.secret, &None, &Some(1))
-                {
+            // Still: One query per client.
+            let rows = check_for_new_rows(&db, client, Some(secret), &None, &Some(1));
+            println!("{:?}", rows);
+            if let Some((geo, last)) = rows {
+                for request in clients.remove(&chan).unwrap_or(vec![]) {
                     request
                         .respond
                         .send(NotifyResponse {
-                            client: payload.clone(),
-                            geo: Some(geo),
+                            client: client.into(),
+                            geo: Some(geo.clone()),
                             last: Some(last),
                         })
                         .ok();
-                } else {
+                }
+            } else {
+                for request in clients.remove(&chan).unwrap_or(vec![]) {
                     request
                         .respond
                         .send(NotifyResponse {
-                            client: payload.clone(),
+                            client: client.into(),
                             geo: None,
                             last: None,
                         })