#![allow(clippy::doc_markdown)] #![cfg(feature = "websocket")]
use crate::prelude::*;
use sha1::Digest;
pub use futures_util::{Sink, SinkExt, Stream, StreamExt};
pub use tokio_tungstenite;
pub use tokio_tungstenite::tungstenite;
/// GUID appended to the client's `Sec-WebSocket-Key` before hashing, when computing the
/// `Sec-WebSocket-Accept` handshake value (RFC 6455, section 1.3).
const SEC_MAGIC_STRING: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
pub async fn response(req: &FatRequest, host: &Host, future: ResponsePipeFuture) -> FatResponse {
use base64::Engine;
const DEFAULT_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::STANDARD,
base64::engine::GeneralPurposeConfig::new().with_encode_padding(false),
);
if req.headers().get("connection").map_or(true, |conn| {
conn.to_str().map_or(true, |s| {
!s.split(',')
.any(|s| s.trim().eq_ignore_ascii_case("upgrade"))
})
}) || req
.headers()
.get("upgrade")
.map_or(true, |upg| upg != "websocket")
{
let mut response = default_error(StatusCode::UPGRADE_REQUIRED, Some(host), None).await;
response
.headers_mut()
.insert("connection", HeaderValue::from_static("upgrade"));
response
.headers_mut()
.insert("upgrade", HeaderValue::from_static("websocket"));
return FatResponse::cache(response)
.with_server_cache(comprash::ServerCachePreference::None);
}
if let Body::Http1(_) = req.body() {
} else {
return default_error_response(
StatusCode::HTTP_VERSION_NOT_SUPPORTED,
host,
Some("You must use HTTP/1.1 for WebSocket requests"),
)
.await;
}
let key = if let Some(k) = req.headers().get("sec-websocket-key") {
k
} else {
return default_error_response(
StatusCode::BAD_REQUEST,
host,
Some("No Sec-WebSocket-Key header was sent"),
)
.await;
};
let mut hasher = sha1::Sha1::new();
hasher.update(key);
hasher.update(SEC_MAGIC_STRING);
let hash = hasher.finalize();
let mut bytes = BytesMut::with_capacity(28);
unsafe { bytes.set_len(28) };
DEFAULT_ENGINE
.encode_slice(hash, &mut bytes)
.expect("base64 encoding failed");
let response = Response::builder()
.header(
"sec-websocket-accept",
HeaderValue::from_maybe_shared(bytes.freeze()).expect(
"the base64 encoded sec-websocket-accept \
response contains illegal header value bytes",
),
)
.header("upgrade", "websocket")
.header("connection", "upgrade")
.status(StatusCode::SWITCHING_PROTOCOLS)
.body(Bytes::new())
.expect("building the WebSocket response body failed");
FatResponse::new(response, comprash::ServerCachePreference::None)
.with_compress(comprash::CompressPreference::None)
.with_future(future)
}
/// Errors returned by [`wrap`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The response pipe's HTTP version cannot carry a WebSocket; only the HTTP/1 pipe is
    /// supported (HTTP/2 and HTTP/3 pipes yield this error).
    WebSocketUnsupported,
}
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::WebSocketUnsupported => f.write_str("WebSocket is only supported over HTTP/1"),
        }
    }
}
impl std::error::Error for Error {}
/// Byte stream underneath a server-side WebSocket created by [`wrap`].
///
/// It implements `AsyncRead` / `AsyncWrite` by locking the borrowed connection for each
/// individual poll, so the connection stays shared with its owner.
#[derive(Debug)]
pub enum WSStream<'conn> {
    /// A borrowed, mutex-protected HTTP/1 connection.
    Http1(&'conn Arc<Mutex<Encryption>>),
}
impl<'a> AsyncRead for WSStream<'a> {
    /// Reads from the shared connection if its lock can be taken right now.
    ///
    /// `try_lock` registers no waker, so when the lock is contended the task is explicitly
    /// re-scheduled before returning `Pending`. Otherwise nothing would ever wake it and the
    /// WebSocket could stall forever.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_read(cx, buf),
                Err(_) => {
                    // Yield, but guarantee we are polled again to retry the lock.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
}
impl<'a> AsyncWrite for WSStream<'a> {
    /// Writes to the shared connection if its lock can be taken right now.
    ///
    /// On lock contention the task re-schedules itself before returning `Pending`, because
    /// `try_lock` does not register a waker.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_write(cx, buf),
                Err(_) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
    /// Always succeeds immediately without touching the underlying connection.
    ///
    /// NOTE(review): this neither flushes nor shuts down the shared connection, presumably
    /// because the connection is owned and closed elsewhere. Confirm that is intended.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
    /// Flushes the shared connection if its lock can be taken right now, re-scheduling the
    /// task on contention (see `poll_write`).
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.get_mut() {
            Self::Http1(s) => match s.try_lock() {
                Ok(mut s) => Pin::new(&mut *s).poll_flush(cx),
                Err(_) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
}
pub async fn wrap(
pipe: &mut ResponseBodyPipe,
) -> Result<tokio_tungstenite::WebSocketStream<WSStream>, Error> {
match pipe {
ResponseBodyPipe::Http1(s) => Ok(tokio_tungstenite::WebSocketStream::from_raw_socket(
WSStream::Http1(s),
tungstenite::protocol::Role::Server,
None,
)
.await),
#[cfg(feature = "http2")]
ResponseBodyPipe::Http2(_, _) => Err(Error::WebSocketUnsupported),
#[cfg(feature = "http3")]
ResponseBodyPipe::Http3(_) => Err(Error::WebSocketUnsupported),
}
}