1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#![allow(clippy::doc_markdown)] // WebSocket is the name of the protocol
#![cfg(feature = "websocket")]
//! Easy and fast WebSockets for Kvarn.
//!
//! See [`response()`] for an example.
use crate::prelude::*;
use sha1::Digest;

pub use futures_util::{Sink, SinkExt, Stream, StreamExt};
pub use tokio_tungstenite;
pub use tokio_tungstenite::tungstenite;

/// GUID defined by the WebSocket protocol (RFC 6455). [`response()`] feeds it to the SHA-1
/// hasher right after the client's `Sec-WebSocket-Key` when computing `Sec-WebSocket-Accept`.
static SEC_MAGIC_STRING: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Form a response to a request on a WebSocket route.
///
/// Checks the `Connection` and `Upgrade` headers and the HTTP version, then derives the
/// `Sec-WebSocket-Accept` header from the request's `Sec-WebSocket-Key`.
///
/// The possible outcomes are:
/// - `426 Upgrade Required` (with `connection: upgrade` and `upgrade: websocket` set) if the
///   request doesn't ask for an upgrade to `websocket`,
/// - `505 HTTP Version Not Supported` if the request body isn't an HTTP/1 body,
/// - `400 Bad Request` if no `Sec-WebSocket-Key` header was sent,
/// - `101 Switching Protocols` otherwise, with `future` attached to the response so it gets
///   access to the connection.
///
/// No `Origin` check is visible in this function's body; cross-origin filtering is presumably
/// handled elsewhere in Kvarn (TODO confirm).
///
/// # Examples
/// ```
/// use kvarn::prelude::*;
/// use kvarn::websocket::{SinkExt, StreamExt};
///
/// let mut extensions = Extensions::new();
///
/// extensions.add_prepare_single(
///     "/ws-ping",
///     prepare!(req, host, _path, _addr, {
///         kvarn::websocket::response(
///             req,
///             host,
///             response_pipe_fut!(response_pipe, _host, {
///                 let mut ws = kvarn::websocket::wrap(response_pipe).await.unwrap();
///                 while let Some(Ok(message)) = ws.next().await {
///                     let _ = ws.send(message).await;
///                 }
///             }),
///         ).await
///     }),
/// );
/// ```
pub async fn response(req: &FatRequest, host: &Host, future: ResponsePipeFuture) -> FatResponse {
    use base64::Engine;
    // A SHA-1 digest is 20 bytes, which is 28 characters of *padded* base64 (27 + one `=`).
    const ACCEPT_LEN: usize = 28;
    // RFC 6455 requires the padded form, so padding is requested explicitly.
    const DEFAULT_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
        &base64::alphabet::STANDARD,
        base64::engine::GeneralPurposeConfig::new().with_encode_padding(true),
    );

    // `Connection` is a comma separated list of tokens; one of them has to be `upgrade`.
    let connection_upgrades = req
        .headers()
        .get("connection")
        .and_then(|conn| conn.to_str().ok())
        .map_or(false, |list| {
            list.split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
        });
    // The `websocket` token is matched ASCII case-insensitively, as RFC 6455 prescribes.
    let upgrade_is_websocket = req
        .headers()
        .get("upgrade")
        .map_or(false, |upg| upg.as_bytes().eq_ignore_ascii_case(b"websocket"));

    if !(connection_upgrades && upgrade_is_websocket) {
        // Tell the client which upgrade this route expects.
        let mut response = default_error(StatusCode::UPGRADE_REQUIRED, Some(host), None).await;
        response
            .headers_mut()
            .insert("connection", HeaderValue::from_static("upgrade"));
        response
            .headers_mut()
            .insert("upgrade", HeaderValue::from_static("websocket"));
        return FatResponse::cache(response)
            .with_server_cache(comprash::ServerCachePreference::None);
    }
    if let Body::Http1(_) = req.body() {
        // allowed
    } else {
        return default_error_response(
            StatusCode::HTTP_VERSION_NOT_SUPPORTED,
            host,
            Some("You must use HTTP/1.1 for WebSocket requests"),
        )
        .await;
    }
    let key = match req.headers().get("sec-websocket-key") {
        Some(key) => key,
        None => {
            return default_error_response(
                StatusCode::BAD_REQUEST,
                host,
                Some("No Sec-WebSocket-Key header was sent"),
            )
            .await;
        }
    };

    // accept = base64(sha1(key ++ magic GUID))
    let mut hasher = sha1::Sha1::new();
    hasher.update(key);
    hasher.update(SEC_MAGIC_STRING);
    let hash = hasher.finalize();

    // Zero-fill the buffer instead of exposing uninitialised memory, and keep only what the
    // encoder reports as written so no stray byte can ever end up in the header.
    let mut bytes = BytesMut::with_capacity(ACCEPT_LEN);
    bytes.resize(ACCEPT_LEN, 0);
    let written = DEFAULT_ENGINE
        .encode_slice(hash, &mut bytes)
        .expect("base64 encoding failed");
    bytes.truncate(written);

    let response = Response::builder()
        .header(
            "sec-websocket-accept",
            HeaderValue::from_maybe_shared(bytes.freeze()).expect(
                "the base64 encoded sec-websocket-accept \
                response contains illegal header value bytes",
            ),
        )
        .header("upgrade", "websocket")
        .header("connection", "upgrade")
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .body(Bytes::new())
        .expect("building the WebSocket response body failed");
    // The handshake response is neither cached nor compressed.
    FatResponse::new(response, comprash::ServerCachePreference::None)
        .with_compress(comprash::CompressPreference::None)
        .with_future(future)
}

/// Error from WebSocket operations.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be propagated with `?`
/// and boxed like any other error.
#[derive(Debug)]
pub enum Error {
    /// WebSocket is currently supported for neither HTTP/2 nor HTTP/3.
    WebSocketUnsupported,
}
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::WebSocketUnsupported => {
                f.write_str("WebSocket is only supported over HTTP/1 connections")
            }
        }
    }
}
impl std::error::Error for Error {}
/// Variants of WebSocket streams.
///
/// Each variant borrows the connection it reads from and writes to.
#[derive(Debug)]
pub enum WSStream<'a> {
    /// HTTP/1 version.
    ///
    /// WebTransport was introduced in HTTP/3 and WebSocket support is sketchy in HTTP/2,
    /// so this is unlikely to get more versions.
    Http1(&'a Arc<Mutex<Encryption>>),
}
// Reads are forwarded to the connection behind the shared lock.
impl<'a> AsyncRead for WSStream<'a> {
    // Tries to take the lock without waiting and, on success, polls the inner stream for data.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_read(cx, buf),
                Err(_) => {
                    // `try_lock` registers no wake-up. Returning `Pending` without one would
                    // leave this task parked forever, so ask to be polled again and retry.
                    // NOTE(review): assumes contention on this lock is short-lived — confirm.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
}
// Writes are forwarded to the connection behind the shared lock.
impl<'a> AsyncWrite for WSStream<'a> {
    // Tries to take the lock without waiting and, on success, writes `buf` to the inner stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_write(cx, buf),
                Err(_) => {
                    // `try_lock` registers no wake-up. Returning `Pending` without one would
                    // leave this task parked forever, so ask to be polled again and retry.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
    // Reports success immediately and leaves the borrowed connection untouched.
    // NOTE(review): presumably the owner of the connection closes it — confirm.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
    // Tries to take the lock without waiting and, on success, flushes the inner stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_flush(cx),
                Err(_) => {
                    // Same reasoning as in `poll_write`: never return `Pending` without a
                    // scheduled wake-up.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
}

/// Turn the `pipe` handed out by [`response_pipe_fut!`] into a
/// [`tokio_tungstenite::WebSocketStream`] acting as the server side of the connection.
///
/// # Examples
///
/// See [`response()`].
///
/// # Errors
///
/// Returns [`Error::WebSocketUnsupported`] if `pipe` isn't an HTTP/1 pipe.
pub async fn wrap(
    pipe: &mut ResponseBodyPipe,
) -> Result<tokio_tungstenite::WebSocketStream<WSStream>, Error> {
    // Only the HTTP/1 pipe can carry a WebSocket; every other variant bails out early.
    let socket = match pipe {
        ResponseBodyPipe::Http1(connection) => WSStream::Http1(connection),
        #[cfg(feature = "http2")]
        ResponseBodyPipe::Http2(_, _) => return Err(Error::WebSocketUnsupported),
        #[cfg(feature = "http3")]
        ResponseBodyPipe::Http3(_) => return Err(Error::WebSocketUnsupported),
    };

    // No custom configuration is passed (`None`).
    let stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
        socket,
        tungstenite::protocol::Role::Server,
        None,
    )
    .await;
    Ok(stream)
}