changeset 128:5ce43ba206c7

agac: Refactor the way downloads are done.
author Lewin Bormann <lbo@spheniscida.de>
date Mon, 26 Oct 2020 10:44:05 +0100
parents d8e96592590d
children 0c1e788bda0c
files async-google-apis-common/src/error.rs async-google-apis-common/src/http.rs
diffstat 2 files changed, 164 insertions(+), 94 deletions(-) [+]
line wrap: on
line diff
--- a/async-google-apis-common/src/error.rs	Mon Oct 26 09:21:25 2020 +0100
+++ b/async-google-apis-common/src/error.rs	Mon Oct 26 10:44:05 2020 +0100
@@ -1,8 +1,14 @@
 #[derive(Debug)]
 pub enum ApiError {
+    /// The API returned a non-OK HTTP response.
     HTTPResponseError(hyper::StatusCode, String),
+    /// Returned after being redirected more than five times.
+    HTTPTooManyRedirectsError,
+    /// E.g. a redirect was issued without a Location: header.
     RedirectError(String),
+    /// Invalid data was supplied to the library.
     InputDataError(String),
+    /// Data for download is available, but the caller hasn't supplied a destination to write to.
     DataAvailableError(String),
 }
 
--- a/async-google-apis-common/src/http.rs	Mon Oct 26 09:21:25 2020 +0100
+++ b/async-google-apis-common/src/http.rs	Mon Oct 26 10:44:05 2020 +0100
@@ -15,8 +15,8 @@
 pub struct EmptyResponse {}
 
 /// Result of a method that can (but doesn't always) download data.
-#[derive(Debug)]
-pub enum DownloadResponse<T: DeserializeOwned + std::fmt::Debug> {
+#[derive(Debug, PartialEq)]
+pub enum DownloadResult<T: DeserializeOwned + std::fmt::Debug> {
     /// Downloaded data has been written to the supplied Writer.
     Downloaded,
     /// A structured response has been returned.
@@ -145,103 +145,167 @@
     }
 }
 
+/// An ongoing download.
+///
+/// Note that this does not necessarily result in a download. It is returned by all API methods
+/// that are capable of downloading data. Whether a download takes place is determined by the
+/// `Content-Type` sent by the server; frequently, the parameters sent in the request determine
+/// whether the server starts a download (`Content-Type: whatever`) or sends a response
+/// (`Content-Type: application/json`).
+pub struct Download<'a, Request, Response> {
+    cl: &'a TlsClient,
+    http_method: String,
+    uri: hyper::Uri,
+    rq: Option<&'a Request>,
+    headers: Vec<(hyper::header::HeaderName, String)>,
+
+    _marker: std::marker::PhantomData<Response>,
+}
+
+impl<'a, Request: Serialize + std::fmt::Debug, Response: DeserializeOwned + std::fmt::Debug>
+    Download<'a, Request, Response>
+{
+    /// Trivial adapter for `download()`: Store downloaded data into a `Vec<u8>`.
+    pub async fn do_it_to_buf(&mut self, buf: &mut Vec<u8>) -> Result<DownloadResult<Response>> {
+        self.do_it(Some(buf)).await
+    }
+
+    /// Run the actual download, streaming the response into the supplied `dst`. If the server
+    /// responded with a `Response` object, no download is started; the response is wrapped in the
+    /// `DownloadResult<Response>` object.
+    ///
+    /// Whether a download takes place or you receive a structured `Response` (i.e. a JSON object)
+    /// depends on the `Content-Type` sent by the server. It is an error to attempt a download
+    /// without specifying `dst`. Often, whether a download takes place is influenced by the
+    /// request parameters. For example, `alt = media` is frequently used in Google APIs to
+    /// indicate that a download is expected.
+    pub async fn do_it(
+        &mut self,
+        dst: Option<&mut (dyn tokio::io::AsyncWrite + std::marker::Unpin)>,
+    ) -> Result<DownloadResult<Response>> {
+        use std::str::FromStr;
+
+        let mut http_response;
+        let mut n_redirects = 0;
+        let mut uri = self.uri.clone();
+
+        // Follow redirects.
+        loop {
+            let mut reqb = hyper::Request::builder()
+                .uri(&uri)
+                .method(self.http_method.as_str());
+            for (k, v) in self.headers.iter() {
+                reqb = reqb.header(k, v);
+            }
+
+            let body;
+            if let Some(rq) = self.rq.take() {
+                body = hyper::Body::from(
+                    serde_json::to_string(&rq).context(format!("{:?}", self.rq))?,
+                );
+            } else {
+                body = hyper::Body::from("");
+            }
+
+            let http_request = reqb.body(body)?;
+            debug!(
+                "do_download: Redirect {}, Launching HTTP request: {:?}",
+                n_redirects, http_request
+            );
+
+            http_response = Some(self.cl.request(http_request).await?);
+            let status = http_response.as_ref().unwrap().status();
+            debug!(
+                "do_download: Redirect {}, HTTP response with status {} received: {:?}",
+                n_redirects, status, http_response
+            );
+
+            // Server returns data - either download or structured response (JSON).
+            if status.is_success() {
+                let headers = http_response.as_ref().unwrap().headers();
+
+                // Check if an object was returned.
+                if let Some(ct) = headers.get(hyper::header::CONTENT_TYPE) {
+                    if ct.to_str()?.contains("application/json") {
+                        let response_body =
+                            hyper::body::to_bytes(http_response.unwrap().into_body()).await?;
+                        return serde_json::from_reader(response_body.as_ref())
+                            .map_err(|e| anyhow::Error::from(e).context(body_to_str(response_body)))
+                            .map(DownloadResult::Response);
+                    }
+                }
+
+                if let Some(dst) = dst {
+                    use tokio::io::AsyncWriteExt;
+                    let mut response_body = http_response.unwrap().into_body();
+                    while let Some(chunk) = tokio::stream::StreamExt::next(&mut response_body).await
+                    {
+                        let chunk = chunk?;
+                        // Chunks often contain just a few kilobytes.
+                        // info!("received chunk with size {}", chunk.as_ref().len());
+                        dst.write(chunk.as_ref()).await?;
+                    }
+                    return Ok(DownloadResult::Downloaded);
+                } else {
+                    return Err(ApiError::DataAvailableError(format!(
+                        "No `dst` was supplied to download data to. Content-Type: {:?}",
+                        headers.get(hyper::header::CONTENT_TYPE)
+                    ))
+                    .into());
+                }
+
+            // Server redirects us.
+            } else if status.is_redirection() {
+                n_redirects += 1;
+                let new_location = http_response
+                    .as_ref()
+                    .unwrap()
+                    .headers()
+                    .get(hyper::header::LOCATION);
+                if new_location.is_none() {
+                    return Err(ApiError::RedirectError(format!(
+                        "Redirect doesn't contain a Location: header"
+                    ))
+                    .into());
+                }
+                uri = hyper::Uri::from_str(new_location.unwrap().to_str()?)?;
+                continue;
+            } else if !status.is_success() {
+                return Err(ApiError::HTTPResponseError(
+                    status,
+                    body_to_str(hyper::body::to_bytes(http_response.unwrap().into_body()).await?),
+                )
+                .into());
+            }
+
+            // Too many redirects.
+            if n_redirects > 5 {
+                return Err(ApiError::HTTPTooManyRedirectsError.into());
+            }
+        }
+    }
+}
+
 pub async fn do_download<
+    'a,
     Req: Serialize + std::fmt::Debug,
     Resp: DeserializeOwned + std::fmt::Debug,
 >(
-    cl: &TlsClient,
+    cl: &'a TlsClient,
     path: &str,
-    headers: &[(hyper::header::HeaderName, String)],
-    http_method: &str,
-    rq: Option<Req>,
-    dst: Option<&mut (dyn tokio::io::AsyncWrite + std::marker::Unpin)>,
-) -> Result<DownloadResponse<Resp>> {
-    let mut path = path.to_string();
-    let mut http_response;
-    let mut i = 0;
-
-    // Follow redirects.
-    loop {
-        let mut reqb = hyper::Request::builder().uri(&path).method(http_method);
-        for (k, v) in headers {
-            reqb = reqb.header(k, v);
-        }
-        let body_str = serde_json::to_string(&rq).context(format!("{:?}", rq))?;
-        let body;
-        if body_str == "null" {
-            body = hyper::Body::from("");
-        } else {
-            body = hyper::Body::from(body_str);
-        }
-
-        let http_request = reqb.body(body)?;
-        debug!(
-            "do_download: Redirect {}, Launching HTTP request: {:?}",
-            i, http_request
-        );
-
-        http_response = Some(cl.request(http_request).await?);
-        let status = http_response.as_ref().unwrap().status();
-        debug!(
-            "do_download: Redirect {}, HTTP response with status {} received: {:?}",
-            i, status, http_response
-        );
-
-        if status.is_success() {
-            break;
-        } else if status.is_redirection() {
-            i += 1;
-            let new_location = http_response
-                .as_ref()
-                .unwrap()
-                .headers()
-                .get(hyper::header::LOCATION);
-            if new_location.is_none() {
-                return Err(ApiError::RedirectError(format!(
-                    "Redirect doesn't contain a Location: header"
-                ))
-                .into());
-            }
-            path = new_location.unwrap().to_str()?.to_string();
-            continue;
-        } else if !status.is_success() {
-            return Err(ApiError::HTTPResponseError(
-                status,
-                body_to_str(hyper::body::to_bytes(http_response.unwrap().into_body()).await?),
-            )
-            .into());
-        }
-    }
-
-    let headers = http_response.as_ref().unwrap().headers();
-    if let Some(ct) = headers.get(hyper::header::CONTENT_TYPE) {
-        if ct.to_str()?.contains("application/json") {
-            let status = http_response.as_ref().unwrap().status();
-            let response_body = hyper::body::to_bytes(http_response.unwrap().into_body()).await?;
-
-            return if !status.is_success() {
-                Err(ApiError::HTTPResponseError(status, body_to_str(response_body)).into())
-            } else {
-                serde_json::from_reader(response_body.as_ref())
-                    .map_err(|e| anyhow::Error::from(e).context(body_to_str(response_body)))
-                    .map(DownloadResponse::Response)
-            };
-        }
-    }
-
-    use tokio::io::AsyncWriteExt;
-    let mut response_body = http_response.unwrap().into_body();
-    if let Some(dst) = dst {
-        while let Some(chunk) = tokio::stream::StreamExt::next(&mut response_body).await {
-            dst.write(chunk?.as_ref()).await?;
-        }
-        Ok(DownloadResponse::Downloaded)
-    } else {
-        Err(ApiError::DataAvailableError(
-            "do_download: No destination for downloaded data was specified".into(),
-        )
-        .into())
-    }
+    headers: Vec<(hyper::header::HeaderName, String)>,
+    http_method: String,
+    rq: Option<&'a Req>,
+) -> Result<Download<'a, Req, Resp>> {
+    use std::str::FromStr;
+    Ok(Download {
+        cl: cl,
+        http_method: http_method,
+        uri: hyper::Uri::from_str(path)?,
+        rq: rq,
+        headers: headers,
+        _marker: Default::default(),
+    })
 }
 
 /// A resumable upload in progress, useful for sending large objects.