changeset 11:593822c857b7 default tip

Enable true correlation in SocketWrapper This allows applications to send a batch of messages at once and later still receive the correct responses to every request.
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 25 Sep 2016 15:19:20 +0200
parents b99a9821115c
children
files src/main/java/net/borgac/clusterrpc/client/ClientChannel.java src/main/java/net/borgac/clusterrpc/client/OutgoingFilter.java src/main/java/net/borgac/clusterrpc/client/SocketWrapper.java src/test/java/net/borgac/clusterrpc/client/SocketWrapperTest.java
diffstat 4 files changed, 152 insertions(+), 41 deletions(-) [+]
line wrap: on
line diff
--- a/src/main/java/net/borgac/clusterrpc/client/ClientChannel.java	Sun Sep 25 15:18:18 2016 +0200
+++ b/src/main/java/net/borgac/clusterrpc/client/ClientChannel.java	Sun Sep 25 15:19:20 2016 +0200
@@ -20,7 +20,7 @@
 
     // Simple seed is good enough here.
     private static final Random ID_GENERATOR = new Random(
-            Instant.now().getEpochSecond());
+            Instant.now().getEpochSecond() + Instant.now().getNano());
 
     private Logger logger;
 
@@ -99,22 +99,48 @@
         return false;
     }
 
-    boolean send(Rpc.RPCRequest request) {
+    /**
+     * Send an RPC request.
+     *
+     * The returned object has to be used in order to receive the response.
+     * While not thread-safe, it allows for sending multiple messages at once
+     * and then receive all responses.
+     *
+     * @param request
+     * @return
+     * @throws RpcException
+     */
+    SocketWrapper.RequestID send(Rpc.RPCRequest request) throws RpcException {
         byte[] serialized = request.toByteArray();
 
-        return sock.send(serialized);
+        SocketWrapper.RequestID id = sock.send(serialized);
+
+        if (id == null) {
+            throw new RpcException(RpcException.Reason.IO_ERROR, "Could not send message over socket!");
+        } else {
+            return id;
+        }
     }
 
-    Rpc.RPCResponse receive() {
-        byte[] response = sock.recv();
+    /**
+     * Receive a response to a previous request.
+     *
+     * @param id The token returned by send()
+     * @return
+     * @throws RpcException
+     */
+    Rpc.RPCResponse receive(SocketWrapper.RequestID id) throws RpcException {
+        byte[] response = sock.recv(id);
         Rpc.RPCResponse parsed = null;
 
         try {
             parsed = Rpc.RPCResponse.parseFrom(response);
         } catch (InvalidProtocolBufferException e) {
             logger.log(Logger.Loglevel.ERROR, "Exception parsing RPCResponse:", e.toString());
+            throw new RpcException(e, RpcException.Reason.DECODING_ERROR, "Parsing RPCResponse failed");
         } catch (Exception e) {
             logger.log(Logger.Loglevel.FATAL, "Unhandled exception; this is a bug:", e.toString());
+            throw new RpcException(e, RpcException.Reason.UNKNOWN, "Caught unknown exception from parseFrom() method");
         }
 
         return parsed;
--- a/src/main/java/net/borgac/clusterrpc/client/OutgoingFilter.java	Sun Sep 25 15:18:18 2016 +0200
+++ b/src/main/java/net/borgac/clusterrpc/client/OutgoingFilter.java	Sun Sep 25 15:19:20 2016 +0200
@@ -22,5 +22,5 @@
      */
     void setInner(OutgoingFilter inner);
 
-    Response go(Request request);
+    Response go(Request request) throws RpcException;
 }
--- a/src/main/java/net/borgac/clusterrpc/client/SocketWrapper.java	Sun Sep 25 15:18:18 2016 +0200
+++ b/src/main/java/net/borgac/clusterrpc/client/SocketWrapper.java	Sun Sep 25 15:19:20 2016 +0200
@@ -3,6 +3,7 @@
 import java.io.Closeable;
 import java.time.Instant;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Random;
 import org.zeromq.ZMQ;
 import org.zeromq.ZMsg;
@@ -14,6 +15,8 @@
  * ZMTP 3 (yet), so we handle the REQ_RELAXED/REQ_CORRELATE functionality
  * ourselves. Thus, the inner socket is a DEALER socket, and not a REQ socket.
  *
+ * This class is not threadsafe.
+ *
  * @author lbo
  */
 class SocketWrapper implements Closeable {
@@ -24,12 +27,13 @@
 
     private final ZMQ.Socket sock;
     // For correlation of requests
-    private byte[] outstandingRequestId;
+    private final HashMap<RequestID, byte[]> outstandingRequests;
+
     private final Logger logger;
 
     SocketWrapper(Logger l) {
         this.sock = GlobalContextProvider.clientSocket();
-        this.outstandingRequestId = null;
+        this.outstandingRequests = new HashMap<>();
         this.logger = l;
 
         byte[] clientId = new byte[5];
@@ -55,59 +59,104 @@
         sock.close();
     }
 
-    boolean send(String payload) {
+    /**
+     * Returns an (opaque) Request ID that should be used in recv()
+     *
+     * @param payload
+     * @return
+     */
+    RequestID send(String payload) {
         return send(payload.getBytes());
     }
 
-    boolean send(byte[] payload) {
+    RequestID send(byte[] payload) {
         ZMsg message = new ZMsg();
-
-        // If a request that we haven't received yet, we will simulate
-        // REQ_RELAXED/REQ_CORRELATE and will just ignore the previous request.
-        byte[] requestId = new byte[5];
-        ID_GENERATOR.nextBytes(requestId);
-
-        outstandingRequestId = requestId;
+        RequestID id = new RequestID();
 
         // The wire format is
         // [request ID, empty frame, payload]
         // ...simulating what a REQ socket would send.
-        message.add(requestId);
+        message.add(id.getSerializedId());
         message.add(new byte[0]);
         message.add(payload);
 
         assert message.size() == EXPECTED_REQUEST_SIZE;
 
-        return message.send(sock, true);
+        if (message.send(sock, true)) {
+            return id;
+        } else {
+            return null;
+        }
     }
 
-    byte[] recv() {
-        ZMsg message;
-        byte[] response = null;
+    byte[] recv(RequestID id) {
+        do {
+            // Check if our response is already here.
+            if (outstandingRequests.containsKey(id)) {
+                return outstandingRequests.remove(id);
+            }
 
-        do {
-            message = ZMsg.recvMsg(sock);
+            ZMsg message = ZMsg.recvMsg(sock);
 
             if (message.size() != EXPECTED_RESPONSE_SIZE) {
-                logger.log(Logger.Loglevel.ERROR, "Received response with bad length:", message.size());
+                logger.log(Logger.Loglevel.ERROR, "Received response with bad length:",
+                        message.size());
             }
 
             // Check if the response is to our last request; otherwise throw away
             byte[] requestId = message.pop().getData();
             // empty frame
             message.pop().getData();
-            response = message.pop().getData();
+            byte[] response = message.pop().getData();
 
-            if (Arrays.equals(requestId, outstandingRequestId)) {
-                break;
+            if (Arrays.equals(requestId, id.getSerializedId())) {
+                return response;
             } else {
-                logger.log(Logger.Loglevel.WARNING, "Received response with unknown request ID:",
-                        Arrays.toString(requestId), "vs", Arrays.toString(outstandingRequestId));
+                outstandingRequests.put(new RequestID(requestId), response);
             }
         } while (true);
+    }
 
-        outstandingRequestId = null;
+    static class RequestID {
+
+        private static final Random ID_GENERATOR = new Random(Instant.now().getEpochSecond()
+                + Instant.now().getNano());
+        private final byte[] id;
+
+        RequestID() {
+            this.id = new byte[6];
+            ID_GENERATOR.nextBytes(this.id);
+        }
+
+        RequestID(byte[] id) {
+            this.id = id;
+        }
+
+        byte[] getSerializedId() {
+            return id;
+        }
 
-        return response;
+        @Override
+        public int hashCode() {
+            return id.hashCode();
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
+            }
+            if (obj == null) {
+                return false;
+            }
+            if (getClass() != obj.getClass()) {
+                return false;
+            }
+            final RequestID other = (RequestID) obj;
+            if (!Arrays.equals(this.id, other.id)) {
+                return false;
+            }
+            return true;
+        }
     }
 }
--- a/src/test/java/net/borgac/clusterrpc/client/SocketWrapperTest.java	Sun Sep 25 15:18:18 2016 +0200
+++ b/src/test/java/net/borgac/clusterrpc/client/SocketWrapperTest.java	Sun Sep 25 15:19:20 2016 +0200
@@ -64,7 +64,7 @@
 
         byte[] requestId = msgs.pop().getData();
 
-        Assert.assertEquals(5, requestId.length);
+        Assert.assertEquals(6, requestId.length);
 
         byte[] empty = msgs.pop().getData();
 
@@ -75,10 +75,13 @@
         Assert.assertArrayEquals("TestMessage".getBytes(), payload);
     }
 
-    @Test
-    public void testSendReceive() {
-        sw.send("TestMessage");
-
+    /**
+     * Echoes responseMessage or the client's payload if responseMessage is
+     * null.
+     *
+     * @param responseMessage
+     */
+    void serverEcho(String responseMessage) {
         ZMsg msgs = ZMsg.recvMsg(server);
         Assert.assertEquals(4, msgs.size());
 
@@ -86,18 +89,30 @@
         response.add(msgs.pop());
         response.add(msgs.pop());
         response.add(msgs.pop());
-        response.add("ServerResponse");
+
+        if (responseMessage != null) {
+            response.add(responseMessage);
+        } else {
+            response.add(msgs.pop());
+        }
 
         response.send(server);
+    }
 
-        byte[] fromServer = sw.recv();
+    @Test
+    public void testSendReceive() {
+        SocketWrapper.RequestID id = sw.send("TestMessage");
+
+        serverEcho("ServerResponse");
+
+        byte[] fromServer = sw.recv(id);
 
         Assert.assertArrayEquals("ServerResponse".getBytes(), fromServer);
     }
 
     @Test
     public void testSendReceiveWithBadRequestId() {
-        sw.send("TestMessage");
+        SocketWrapper.RequestID id = sw.send("TestMessage");
 
         byte[] clientId, requestId;
 
@@ -124,7 +139,28 @@
         response.send(server);
 
         // Expectation: Client discards first response.
-        byte[] fromServer = sw.recv();
+        byte[] fromServer = sw.recv(id);
         Assert.assertArrayEquals("GoodServerResponse".getBytes(), fromServer);
     }
+
+    @Test
+    public void testCorrelation() {
+        SocketWrapper.RequestID id1 = sw.send("Test1");
+        serverEcho(null);
+
+        SocketWrapper.RequestID id2 = sw.send("Test2");
+        serverEcho(null);
+
+        SocketWrapper.RequestID id3 = sw.send("Test3");
+        serverEcho(null);
+
+        // Expected: Although sent in reverse order, the requests should be correlated correctly.
+        byte[] response2 = sw.recv(id2);
+        byte[] response3 = sw.recv(id3);
+        byte[] response1 = sw.recv(id1);
+
+        Assert.assertArrayEquals(response3, "Test3".getBytes());
+        Assert.assertArrayEquals(response2, "Test2".getBytes());
+        Assert.assertArrayEquals(response1, "Test1".getBytes());
+    }
 }