// rocket/listener/quic.rs
1use std::io;
27use std::fmt;
28use std::net::SocketAddr;
29use std::pin::pin;
30
31use s2n_quic as quic;
32use s2n_quic_h3 as quic_h3;
33use quic_h3::h3 as h3;
34
35use bytes::Bytes;
36use futures::Stream;
37use tokio::sync::Mutex;
38use tokio_stream::StreamExt;
39
40use crate::tls::{TlsConfig, Error};
41use crate::listener::Endpoint;
42
/// Server side of an HTTP/3 connection carried over an s2n-quic connection,
/// exchanging data as `bytes::Bytes` buffers.
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;
44
/// A bound QUIC server whose connections are upgraded to HTTP/3.
pub struct QuicListener {
    /// Local address of the started server, read back via `local_addr()` in `bind`.
    endpoint: SocketAddr,
    /// The running s2n-quic server. `accept` locks this to call the server's
    /// `accept` through `&self` (presumably because that call needs `&mut`
    /// access — confirm against s2n-quic), so concurrent accepts are serialized.
    listener: Mutex<quic::Server>,
    /// TLS settings used to build the rustls provider in `bind`; also attached
    /// to the endpoint reported by `endpoint()`.
    tls: TlsConfig,
}
50
/// An established HTTP/3 connection from which individual requests are accepted.
///
/// `.0` is the HTTP/3 connection created in `QuicListener::connect`; `.1` is the
/// peer address as returned by `remote_addr()` on the accepted QUIC connection,
/// kept as a `Result` because that lookup can itself fail.
pub struct H3Stream(H3Conn, quic::connection::Result<SocketAddr>);
52
/// A single HTTP/3 request accepted from an `H3Stream`, with both halves of its
/// request stream.
pub struct H3Connection {
    /// Peer address copied from the owning `H3Stream` (may hold the lookup error).
    pub(crate) remote: quic::connection::Result<SocketAddr>,
    /// Head of the received request; the body is read separately through `rx`.
    pub(crate) parts: http::request::Parts,
    /// Send half of the request stream, used to write the response.
    pub(crate) tx: QuicTx,
    /// Receive half of the request stream, used to read the request body.
    pub(crate) rx: QuicRx,
}
59
/// Receive half of an HTTP/3 request stream. It implements `Stream` (yielding
/// `io::Result<Bytes>` body chunks) and `AsyncCancel` in `async_traits`.
#[doc(hidden)]
pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);
62
/// Send half of an HTTP/3 request stream, used to write the response head and
/// body (`send_response`) or to abort the stream (`cancel`).
#[doc(hidden)]
pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);
65
66impl QuicListener {
67 pub async fn bind(address: SocketAddr, tls: TlsConfig) -> Result<Self, Error> {
68 use quic::provider::tls::rustls::Server as H3TlsServer;
69
70 let cert_chain = tls.load_certs()?
71 .into_iter()
72 .map(|v| v.to_vec())
73 .collect::<Vec<_>>();
74
75 let h3tls = H3TlsServer::builder()
76 .with_application_protocols(["h3"].into_iter())
77 .map_err(|e| Error::Bind(e))?
78 .with_certificate(cert_chain, tls.load_key()?.secret_der())
79 .map_err(|e| Error::Bind(e))?
80 .with_prefer_server_cipher_suite_order(tls.prefer_server_cipher_order)
81 .map_err(|e| Error::Bind(e))?
82 .build()
83 .map_err(|e| Error::Bind(e))?;
84
85 let listener = quic::Server::builder()
86 .with_tls(h3tls)?
87 .with_io(address)?
88 .start()
89 .map_err(|e| Error::Bind(Box::new(e)))?;
90
91 Ok(QuicListener {
92 tls,
93 endpoint: listener.local_addr()?,
94 listener: Mutex::new(listener),
95 })
96 }
97}
98
99impl QuicListener {
100 pub async fn accept(&self) -> Option<quic::Connection> {
101 self.listener
102 .lock().await
103 .accept().await
104 }
105
106 pub async fn connect(&self, accept: quic::Connection) -> io::Result<H3Stream> {
107 let remote = accept.remote_addr();
108 let quic_conn = quic_h3::Connection::new(accept);
109 let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?;
110 Ok(H3Stream(conn, remote))
111 }
112
113 pub fn endpoint(&self) -> io::Result<Endpoint> {
114 Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls))
115 }
116}
117
118impl H3Stream {
119 pub async fn accept(&mut self) -> io::Result<Option<H3Connection>> {
120 let remote = self.1.clone();
121 let ((parts, _), (tx, rx)) = match self.0.accept().await {
122 Ok(Some((req, stream))) => (req.into_parts(), stream.split()),
123 Ok(None) => return Ok(None),
124 Err(e) => {
125 if matches!(e.try_get_code().map(|c| c.value()), Some(0 | 0x100)) {
126 return Ok(None)
127 }
128
129 return Err(io::Error::other(e));
130 }
131 };
132
133 Ok(Some(H3Connection { remote, parts, tx: QuicTx(tx), rx: QuicRx(rx) }))
134 }
135}
136
137impl QuicTx {
138 pub async fn send_response<S>(&mut self, response: http::Response<S>) -> io::Result<()>
139 where S: Stream<Item = io::Result<Bytes>>
140 {
141 let (parts, body) = response.into_parts();
142 let response = http::Response::from_parts(parts, ());
143 self.0.send_response(response).await.map_err(io::Error::other)?;
144
145 let mut body = pin!(body);
146 while let Some(bytes) = body.next().await {
147 let bytes = bytes.map_err(io::Error::other)?;
148 self.0.send_data(bytes).await.map_err(io::Error::other)?;
149 }
150
151 self.0.finish().await.map_err(io::Error::other)
152 }
153
154 pub fn cancel(&mut self) {
155 self.0.stop_stream(h3::error::Code::H3_NO_ERROR);
156 }
157}
158
159impl H3Connection {
161 pub fn endpoint(&self) -> io::Result<Endpoint> {
162 Ok(Endpoint::Quic(self.remote?).assume_tls())
163 }
164}
165
166mod async_traits {
167 use std::io;
168 use std::pin::Pin;
169 use std::task::{ready, Context, Poll};
170
171 use super::{Bytes, QuicRx};
172 use crate::listener::AsyncCancel;
173
174 use futures::Stream;
175 use s2n_quic_h3::h3;
176
177 impl Stream for QuicRx {
178 type Item = io::Result<Bytes>;
179
180 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181 use bytes::Buf;
182
183 match ready!(self.0.poll_recv_data(cx)) {
184 Ok(Some(mut buf)) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))),
185 Ok(None) => Poll::Ready(None),
186 Err(e) => Poll::Ready(Some(Err(io::Error::other(e)))),
187 }
188 }
189 }
190
191 impl AsyncCancel for QuicRx {
192 fn poll_cancel(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
193 self.0.stop_sending(h3::error::Code::H3_NO_ERROR);
194 Poll::Ready(Ok(()))
195 }
196 }
197}
198
199impl fmt::Debug for H3Stream {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_tuple("H3Stream").finish()
202 }
203}
204
205impl fmt::Debug for H3Connection {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 f.debug_struct("H3Connection").finish()
208 }
209}