use std::io;
use std::fmt;
use std::net::SocketAddr;
use std::pin::pin;
use s2n_quic as quic;
use s2n_quic_h3 as quic_h3;
use quic_h3::h3 as h3;
use bytes::Bytes;
use futures::Stream;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use crate::tls::{TlsConfig, Error};
use crate::listener::Endpoint;
/// Shorthand for the `h3` server-side connection type used throughout this file:
/// an `h3::server::Connection` over the `s2n_quic_h3` connection adapter,
/// parameterized with `bytes::Bytes`.
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;
/// A started QUIC server together with the address it reports and the TLS
/// configuration it was created from.
pub struct QuicListener {
/// Local address read back from the server right after it was started.
endpoint: SocketAddr,
/// The s2n-quic server; kept behind an async mutex so that connections can be
/// accepted through a shared reference.
listener: Mutex<quic::Server>,
/// TLS configuration supplied at bind time; attached to the reported endpoint.
tls: TlsConfig,
}
/// An established HTTP/3 connection from which individual requests are accepted.
///
/// Field `0` is the `h3` server connection; field `1` is the result of asking the
/// underlying QUIC connection for the peer's address, stored as-is (it may be an
/// error).
pub struct H3Stream(H3Conn, quic::connection::Result<SocketAddr>);
/// A single HTTP/3 request accepted from an [`H3Stream`], split into its head and
/// the two halves of its request stream.
pub struct H3Connection {
/// Peer address lookup result copied from the owning `H3Stream`.
pub(crate) remote: quic::connection::Result<SocketAddr>,
/// Head of the incoming request (the `http::request::Parts` of the request).
pub(crate) parts: http::request::Parts,
/// Sending half of the request stream.
pub(crate) tx: QuicTx,
/// Receiving half of the request stream.
pub(crate) rx: QuicRx,
}
/// Receiving half of an HTTP/3 request stream: a newtype over the `h3`
/// `RequestStream` built on the `s2n_quic_h3` receive stream.
#[doc(hidden)]
pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);
/// Sending half of an HTTP/3 request stream: a newtype over the `h3`
/// `RequestStream` built on the `s2n_quic_h3` send stream.
#[doc(hidden)]
pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);
impl QuicListener {
pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result<Self, Error> {
use quic::provider::tls::rustls::Server as H3TlsServer;
let cert_chain = tls.load_certs()?
.into_iter()
.map(|v| v.to_vec())
.collect::<Vec<_>>();
let h3tls = H3TlsServer::builder()
.with_application_protocols(["h3"].into_iter())
.map_err(|e| Error::Bind(e))?
.with_certificate(cert_chain, tls.load_key()?.secret_der())
.map_err(|e| Error::Bind(e))?
.with_prefer_server_cipher_suite_order(tls.prefer_server_cipher_order)
.map_err(|e| Error::Bind(e))?
.build()
.map_err(|e| Error::Bind(e))?;
let listener = quic::Server::builder()
.with_tls(h3tls)?
.with_io(address)?
.start()
.map_err(|e| Error::Bind(Box::new(e)))?;
Ok(QuicListener {
tls,
endpoint: listener.local_addr()?,
listener: Mutex::new(listener),
})
}
}
impl QuicListener {
    /// Waits for the next incoming QUIC connection.
    ///
    /// The server mutex is held for the whole wait. Returns whatever the
    /// underlying server's `accept` yields, including `None`.
    pub async fn accept(&self) -> Option<quic::Connection> {
        let mut server = self.listener.lock().await;
        server.accept().await
    }

    /// Wraps an accepted QUIC connection in an HTTP/3 server connection.
    ///
    /// The peer address lookup is performed up front and its result, success or
    /// failure, is stored alongside the connection.
    ///
    /// # Errors
    ///
    /// Any failure while setting up the HTTP/3 connection is returned as an
    /// `io::Error` created with `io::Error::other`.
    pub async fn connect(&self, accept: quic::Connection) -> io::Result<H3Stream> {
        let peer = accept.remote_addr();
        let transport = quic_h3::Connection::new(accept);
        match H3Conn::new(transport).await {
            Ok(h3_conn) => Ok(H3Stream(h3_conn, peer)),
            Err(err) => Err(io::Error::other(err)),
        }
    }

    /// Returns the local QUIC endpoint with this listener's TLS configuration
    /// attached. This implementation always returns `Ok`.
    pub fn endpoint(&self) -> io::Result<Endpoint> {
        let Self { endpoint, tls, .. } = self;
        Ok(Endpoint::Quic(*endpoint).with_tls(tls))
    }
}
impl H3Stream {
    /// Accepts the next request on this HTTP/3 connection.
    ///
    /// Returns `Ok(None)` when the connection yields no further request, or when
    /// accepting fails with an error whose code value is `0` or `0x100`; those
    /// codes are treated as the end of the connection rather than as a failure.
    ///
    /// # Errors
    ///
    /// Any other accept error is returned as an `io::Error` created with
    /// `io::Error::other`.
    pub async fn accept(&mut self) -> io::Result<Option<H3Connection>> {
        let remote = self.1.clone();

        let (request, stream) = match self.0.accept().await {
            Ok(Some(accepted)) => accepted,
            Ok(None) => return Ok(None),
            Err(err) => {
                let code = err.try_get_code().map(|c| c.value());
                return match code {
                    Some(0 | 0x100) => Ok(None),
                    _ => Err(io::Error::other(err)),
                };
            }
        };

        // Only the request head is kept; the body arrives through `rx`.
        let (parts, _) = request.into_parts();
        let (tx, rx) = stream.split();
        Ok(Some(H3Connection { remote, parts, tx: QuicTx(tx), rx: QuicRx(rx) }))
    }
}
impl QuicTx {
    /// Sends `response` over this stream: first the head, then every chunk the
    /// body stream yields, and finally the end-of-stream marker.
    ///
    /// # Errors
    ///
    /// * An error yielded by the body stream is returned unchanged, so its
    ///   `io::ErrorKind` is preserved. Sending stops at that point and the stream
    ///   is not finished.
    /// * A failure reported by the HTTP/3 layer while sending the head, a chunk,
    ///   or the end-of-stream marker is returned as an `io::Error` created with
    ///   `io::Error::other`.
    pub async fn send_response<S>(&mut self, response: http::Response<S>) -> io::Result<()>
        where S: Stream<Item = io::Result<Bytes>>
    {
        let (parts, body) = response.into_parts();
        let head = http::Response::from_parts(parts, ());
        self.0.send_response(head).await.map_err(io::Error::other)?;

        let mut body = pin!(body);
        while let Some(chunk) = body.next().await {
            // `chunk` is already an `io::Result`; wrapping its error again with
            // `io::Error::other` would reset the kind to `Other` for no benefit.
            self.0.send_data(chunk?).await.map_err(io::Error::other)?;
        }

        self.0.finish().await.map_err(io::Error::other)
    }

    /// Stops the sending side of the stream, signalling `H3_NO_ERROR` to the peer.
    pub fn cancel(&mut self) {
        self.0.stop_stream(h3::error::Code::H3_NO_ERROR);
    }
}
impl H3Connection {
    /// Returns the peer's QUIC endpoint, marked as TLS-secured.
    ///
    /// # Errors
    ///
    /// If the peer address lookup stored in `remote` failed, that error is
    /// converted into an `io::Error` and returned.
    pub fn endpoint(&self) -> io::Result<Endpoint> {
        let peer = self.remote?;
        Ok(Endpoint::Quic(peer).assume_tls())
    }
}
/// Async trait implementations for the receiving half of a request stream.
mod async_traits {
    use std::io;
    use std::pin::Pin;
    use std::task::{ready, Context, Poll};

    use futures::Stream;
    use s2n_quic_h3::h3;

    use super::{Bytes, QuicRx};
    use crate::listener::AsyncCancel;

    /// Yields the request body as `Bytes` chunks until the peer finishes sending.
    impl Stream for QuicRx {
        type Item = io::Result<Bytes>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            use bytes::Buf;

            let next = match ready!(self.0.poll_recv_data(cx)) {
                Ok(Some(mut chunk)) => {
                    // Drain everything the buffer currently holds into one `Bytes`.
                    let len = chunk.remaining();
                    Some(Ok(chunk.copy_to_bytes(len)))
                }
                Ok(None) => None,
                Err(err) => Some(Err(io::Error::other(err))),
            };

            Poll::Ready(next)
        }
    }

    /// Cancelling asks the peer to stop sending and completes immediately.
    impl AsyncCancel for QuicRx {
        fn poll_cancel(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.0.stop_sending(h3::error::Code::H3_NO_ERROR);
            Poll::Ready(Ok(()))
        }
    }
}
/// Opaque `Debug` output: only the type name is printed, none of the fields.
impl fmt::Debug for H3Stream {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("H3Stream")
    }
}
/// Opaque `Debug` output: only the type name is printed, none of the fields.
impl fmt::Debug for H3Connection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("H3Connection")
    }
}