changeset 20:b60eb381eec8

BIG: Make live updates efficient by using only one database connection
author Lewin Bormann <lbo@spheniscida.de>
date Wed, 02 Dec 2020 21:04:16 +0100
parents 19910bc79239
children 643e00f58060
files src/main.rs
diffstat 1 files changed, 109 insertions(+), 14 deletions(-)
--- a/src/main.rs	Wed Dec 02 20:09:52 2020 +0100
+++ b/src/main.rs	Wed Dec 02 21:04:16 2020 +0100
@@ -1,5 +1,10 @@
 #![feature(proc_macro_hygiene, decl_macro)]
 
+use std::collections::HashMap;
+use std::sync::{mpsc, Arc, Mutex};
+use std::thread;
+use std::time;
+
 use postgres;
 use rocket;
 
@@ -162,10 +167,11 @@
 #[rocket::get("/geo/<name>/retrieve/live?<secret>&<last>&<timeout>")]
 fn retrieve_live(
     db: DBConn,
+    notify_manager: rocket::State<SendableSender<NotifyRequest>>,
     name: String,
     secret: Option<String>,
     last: Option<i32>,
-    timeout: Option<i32>,
+    timeout: Option<u64>,
 ) -> rocket_contrib::json::Json<LiveUpdate> {
     // Only if the client supplied a paging token do we check for new rows before waiting. This is
     // an optimization.
@@ -179,20 +185,19 @@
         }
     }
 
-    // Otherwise we will wait for the next update.
-    //
-    let listen =
-        db.0.prepare_cached(format!("LISTEN geohubclient_update_{}", name).as_str())
-            .unwrap();
-    let unlisten =
-        db.0.prepare_cached(format!("UNLISTEN geohubclient_update_{}", name).as_str())
-            .unwrap();
+    let (send, recv) = mpsc::channel();
+    let send = SendableSender {
+        sender: Arc::new(Mutex::new(send)),
+    };
 
-    listen.execute(&[]).ok();
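+    // Register with the shared notifier thread instead of LISTENing on this request's own
+    // connection; that thread watches the database for all waiting clients.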
+    let req = NotifyRequest {
+        client: name.clone(),
+        respond: send,
+    };
+    notify_manager.send(req).unwrap();
 
-    let timeout = std::time::Duration::new(timeout.unwrap_or(30) as u64, 0);
-    if let Ok(_) = db.0.notifications().timeout_iter(timeout).next() {
-        unlisten.execute(&[]).ok();
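+    // Block until the notifier thread signals an update for this client or the timeout
+    // (default 30 seconds) expires, then re-check the table for new rows.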
+    if let Ok(response) = recv.recv_timeout(time::Duration::new(timeout.unwrap_or(30), 0)) {
+        eprintln!("Worker received response for {}", response.client);
         if let Some((geojson, last)) = check_for_new_rows(&db, &name, &secret, &last, &Some(1)) {
             return rocket_contrib::json::Json(LiveUpdate {
                 typ: "GeoHubUpdate".into(),
@@ -201,7 +206,6 @@
             });
         }
     }
-    unlisten.execute(&[]).ok();
     return rocket_contrib::json::Json(LiveUpdate {
         typ: "GeoHubUpdate".into(),
         last: last,
@@ -302,9 +306,100 @@
         .map_err(|e| rocket::response::status::NotFound(e.to_string()))
 }
 
+// Notify all waiters using just one DB connection.
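+// A waiting handler's registration: when new data for `client` arrives, a NotifyResponse is
+// sent on `respond`.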
+struct NotifyRequest {
+    client: String,
+    respond: SendableSender<NotifyResponse>,
+}
+
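+// Wake-up message sent back to a waiting handler.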
+struct NotifyResponse {
+    client: String,
+}
+
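+// std's mpsc::Sender is not Sync, so it is wrapped in Arc<Mutex<..>> to satisfy Rocket's
+// managed-state bounds; cloning just clones the Arc.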
+#[derive(Clone)]
+struct SendableSender<T> {
+    sender: Arc<Mutex<mpsc::Sender<T>>>,
+}
+
+impl<T> SendableSender<T> {
+    fn send(&self, arg: T) -> Result<(), mpsc::SendError<T>> {
+        let s = self.sender.lock().unwrap();
+        s.send(arg)
+    }
+}
+
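+// Runs on a dedicated thread with its own connection: LISTENs on behalf of every waiting client
+// and relays database notifications back to them over their response channels.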
+fn live_notifier_thread(rx: mpsc::Receiver<NotifyRequest>, db: postgres::Connection) {
+    const TICK_MILLIS: u32 = 500;
+
+    let mut clients: HashMap<String, NotifyRequest> = HashMap::new();
+
+    fn listen(db: &postgres::Connection, client: &str) -> postgres::Result<u64> {
+        db.execute(&format!("LISTEN geohubclient_update_{}", client), &[])
+    }
+    fn unlisten(db: &postgres::Connection, client: &str) -> postgres::Result<u64> {
+        db.execute(&format!("UNLISTEN geohubclient_update_{}", client), &[])
+    }
+
+    eprintln!("Notification thread running.");
+    loop {
+        // Each iteration first drains pending registration requests from rx, then polls the
+        // database for notifications (waiting up to TICK_MILLIS), and repeats.
+
+        // Drain notification requests (clients asking to watch for notifications).
+        loop {
+            if let Ok(nrq) = rx.try_recv() {
+                if !clients.contains_key(&nrq.client) {
+                    listen(&db, &nrq.client).ok();
+                }
+                clients.insert(nrq.client.clone(), nrq);
+            } else {
+                break;
+            }
+        }
+
+        // Drain notifications from the database.
+        let notifications = db.notifications();
+        let mut iter = notifications.timeout_iter(time::Duration::new(0, TICK_MILLIS * 1_000_000));
+        let mut count = 0;
+        while let Ok(Some(notification)) = iter.next() {
+            let payload = notification.payload;
+            unlisten(&db, &payload).ok();
+
+            if let Some(request) = clients.remove(&payload) {
+                request
+                    .respond
+                    .send(NotifyResponse { client: payload })
+                    .ok();
+            }
+
+            // Break out periodically so the outer loop can also pick up new registration
+            // requests from rx.
+            count += 1;
+            if count > 3 {
+                break;
+            }
+        }
+    }
+}
+
 fn main() {
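+    // Channel through which request handlers register their notification interest with the
+    // notifier thread: handlers get the send half via managed state, the thread owns `recv`.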
+    let (send, recv) = mpsc::channel();
+    let send = SendableSender {
+        sender: Arc::new(Mutex::new(send)),
+    };
+
     rocket::ignite()
         .attach(DBConn::fairing())
+        .manage(send)
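+        // Open a second Postgres connection (separate from the request pool) for the notifier
+        // thread, reusing the `geohub` database configuration.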
+        .attach(rocket::fairing::AdHoc::on_attach(
+            "Database Notifications",
+            |rocket| {
+                let dbconfig =
+                    rocket_contrib::databases::database_config("geohub", &rocket.config()).unwrap();
+                let url = dbconfig.url;
+                let conn = postgres::Connection::connect(url, postgres::TlsMode::None).unwrap();
+                thread::spawn(move || live_notifier_thread(recv, conn));
+                Ok(rocket)
+            },
+        ))
         .mount(
             "/",
             rocket::routes![log, retrieve_json, retrieve_live, assets],