本エントリーではKey Value Storeをtokioで作るうえで学んだことを書いていきます。
RustのLT会 Shinjuku.rs #13で話させていただいた内容です。
発表時のスライド
作ったもの
TCP(TLS)でremoteに接続してKey/ValueをCRUDするだけのserverです。
❯ export KVSD_HOST=kvsd.ymgyt.io
❯ kvsd set hello rust
OK
❯ kvsd get hello
rust
❯ kvsd set hello 'rust!!!'
OK
❯ kvsd delete hello
OK old value: rust!!!TCP Server

ClientからのTCP接続ごとにtaskをtokio::spawn()して専用のHandlerを生成します。
Key/Valueを保存するFileごとにもtaskを生成しておき、共有resourceの処理でlock等が必要ないようにします。 (図の四角がtokio::spawn()したtaskを表しています)
Max Connection
tokioを利用することで、clientごとにthreadを生成する必要がなくなったのですが無制限にtaskを生成するわけにもいかないので最大connection数を制御したいです。
最大connectionに達したときはTcpListener.accept()をよびださないように実装します。
struct SemaphoreListener {
inner: TcpListener,
max_connections: Arc<Semaphore>,
}
impl SemaphoreListener {
fn new(listener: TcpListener, max_connections: u32) -> Self {
Self {
inner: listener,
max_connections: Arc::new(Semaphore::new(max_connections as usize)),
}
}
async fn accept(&mut self) -> std::io::Result<(TcpStream, std::net::SocketAddr)> {
self.max_connections.acquire().await.forget();
self.inner.accept().await
}
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L494
tokio::sync::SemaphoreでTcpListenerをwrapしてaccept()時にacquire()をよぶことで最大connectionに達したときはblockするようにします。
所有権の関係でforget()を呼んでいるので、Clientとの接続が終了したときのHandlerのdrop処理で帳尻をあわせる必要があります。
struct Handler {
// ...
max_connections: Arc<Semaphore>,
}
impl Drop for Handler {
fn drop(&mut self) {
self.max_connections.add_permits(1);
}
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L488
Graceful shutdown


終了処理時(SIGINT等)にすべてのClientとの接続が切れるまで待機する必要があります。
Graceful shutdownに必要な処理はそれぞれ以下のtokioのAPIを利用することで実現できました。
- Signal handling =>
tokio::signal::ctrl_c() - Shutdownの通知 =>
tokio::sync::broadcast::channel() - Shutdownの待機 =>
tokio::sync::mpsc::channel()
impl Server {
pub(crate) async fn run(
mut self,
listener: TcpListener,
shutdown: impl Future,
) -> Result<()> {
tokio::select! {
result = self.serve(listener) => {
if let Err(err) = result {
error!(cause = %err, "Failed to accept");
}
}
_ = shutdown => {
info!("Shutdown signal received");
}
}
info!("Notify shutdown to all handlers");
self.graceful_shutdown.shutdown().await;
info!("Shutdown successfully completed");
Ok(())
}
}
struct GracefulShutdown {
notify_shutdown: broadcast::Sender<ShutdownSignal>,
shutdown_complete_tx: mpsc::Sender<ShutdownCompleteSignal>,
shutdown_complete_rx: mpsc::Receiver<ShutdownCompleteSignal>,
}
impl GracefulShutdown {
fn new() -> Self {
let (notify_shutdown, _) = broadcast::channel(1);
let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);
Self {
notify_shutdown,
shutdown_complete_tx,
shutdown_complete_rx,
}
}
// Notify handlers of the shutdown and wait for it to be completed.
async fn shutdown(mut self) {
// Notify shutdown to all handler.
drop(self.notify_shutdown);
// Drop final Sender so the Receiver below can complete.
drop(self.shutdown_complete_tx);
// Wait for all handler to finish.
let _ = self.shutdown_complete_rx.recv().await;
}
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L204
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L159
tokio::select!を利用することで、task内で2つのfutureを同時にawaitすることができるようになります。
impl Futureとしておくことでtest時にはtokio::sync::Notifyを利用することでできます。
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/tests/integration_test.rs#L38
Stream read/write
TCP上で自分で定義したデータ構造での通信を目指します。
MessageはClient-Server間でやりとりする単位です(Authenticate, Set, Success,...)
FrameはMessageの構成要素です(String, Bytes,Null,...)
実体はredis serialization protocolの劣化版みたいなものです。
TcpStream/TlsStreamをwrapした構造体を定義して、client/serverにprotocol実装を提供します。
use bytes::{Buf, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
pub struct Connection<T = TcpStream> {
stream: BufWriter<T>,
// The buffer for reading frames.
buffer: BytesMut,
}bufferはtcpからreadしてMessageを構成するために利用します。
write
client/serverから通信したいMessageをうけとるとそれをFramesに分解してそれぞれserializationしていきます。
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
impl<T> Connection<T>
where
T: AsyncWrite + AsyncRead + Unpin,
{
pub(crate) async fn write_message(&mut self, message: impl Into<MessageFrames>) -> Result<()> {
let frames = message.into();
self.stream.write_u8(frameprefix::MESSAGE_FRAMES).await?;
self.write_decimal(frames.len()).await?;
for frame in frames {
self.write_frame(frame).await?
}
self.stream.flush().await?;
Ok(())
}
async fn write_frame(&mut self, frame: Frame) -> Result<()> {
match frame {
Frame::MessageType(mt) => {
self.stream.write_u8(frameprefix::MESSAGE_TYPE).await?;
self.stream.write_u8(mt.into()).await?;
}
Frame::String(val) => {
self.stream.write_u8(frameprefix::STRING).await?;
self.stream.write_all(val.as_bytes()).await?;
self.stream.write_all(DELIMITER).await?;
}
Frame::Bytes(val) => {
self.stream.write_u8(frameprefix::BYTES).await?;
self.write_decimal(val.len() as u64).await?;
self.stream.write_all(val.as_ref()).await?;
self.stream.write_all(DELIMITER).await?;
}
Frame::Time(val) => {
self.stream.write_u8(frameprefix::TIME).await?;
self.stream.write_all(val.to_rfc3339().as_bytes()).await?;
self.stream.write_all(DELIMITER).await?;
}
Frame::Null => {
self.stream.write_u8(frameprefix::NULL).await?;
}
}
Ok(())
}
}AsyncWriteExtにはendianを考慮したwrite methodが定義されているので便利です。
read
impl<T> Connection<T>
where
T: AsyncWrite + AsyncRead + Unpin,
{
pub(crate) async fn read_message_with_timeout(
&mut self,
duration: Duration,
) -> Result<Option<Message>> {
match tokio::time::timeout(duration, self.read_message()).await {
Ok(read_result) => read_result,
Err(elapsed) => Err(Error::from(elapsed)),
}
}
pub(crate) async fn read_message(&mut self) -> Result<Option<Message>> {
match self.read_message_frames().await? {
Some(message_frames) => Ok(Some(Message::from_frames(message_frames)?)),
None => Ok(None),
}
}
async fn read_message_frames(&mut self) -> Result<Option<MessageFrames>> {
loop {
if let Some(message_frames) = self.parse_message_frames()? {
return Ok(Some(message_frames));
}
if 0 == self.stream.read_buf(&mut self.buffer).await? {
return if self.buffer.is_empty() {
Ok(None)
} else {
Err(ErrorKind::ConnectionResetByPeer.into())
};
}
}
}
fn parse_message_frames(&mut self) -> Result<Option<MessageFrames>> {
use FrameError::Incomplete;
let mut buf = Cursor::new(&self.buffer[..]);
match MessageFrames::check_parse(&mut buf) {
Ok(_) => {
let len = buf.position() as usize;
buf.set_position(0);
let message_frames = MessageFrames::parse(&mut buf)?;
self.buffer.advance(len);
Ok(Some(message_frames))
}
Err(Incomplete) => Ok(None),
Err(e) => Err(e.into()),
}
}
}futureにtimeoutを設定するにはtokio::time::timeoutを利用します。
tokio::time::timeout()でwrapするとtimeoutつきのfutureを作れ、timeoutするとtokio::time::Elapsedを返してくれます。
tcpからreadしたときに意味ある単位でreadできているとは限らないのでbufferに読み込んだタイミングで一度parseを試みます。
parse時の実装をsimpleにするためにstd::io::Cursorでwrapしておくことで、parse側では必要な単位でbufferを読み込めます。
bufferにMessageを構成するに十分なFrameがない場合はloopでtcpのreadを繰り返します。
bufferに十分なFramesがある場合は、呼び出し元にMessageを返して、bufferをclearします。
TcpStreamのtestにはtokio::io::duplex()が便利
TcpStreamをwrapしたような型のtestではtokio::io::duplex()が便利でした。
#[test]
fn message_frames() {
tokio_test::block_on(async move {
let (client, server) = tokio::io::duplex(1024);
let mut client_conn = Connection::new(client, None);
let mut server_conn = Connection::new(server, None);
let messages: Vec<Message> = vec![
Message::Authenticate(Authenticate::new("user", "pass")),
Message::Ping(Ping::new().record_client_time()),
// ...
];
let messages_clone = messages.clone();
let write_handle = tokio::spawn(async move {
for message in messages {
client_conn.write_message(message).await.unwrap();
}
});
let read_handle = tokio::spawn(async move {
for want in messages_clone {
let got = server_conn.read_message().await.unwrap().unwrap();
assert_eq!(want, got);
}
});
write_handle.await.unwrap();
read_handle.await.unwrap();
})
}このような感じでtcp listenerを用意せずにメモリ上でtestが完結できました。
async codeのtestではruntimeを起動しておく必要があるので、tokio_testを利用しました。
TLS
TCPをTLSで保護するにはtokio_rustlsを利用できます。
server
use tokio_rustls::rustls;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
let mut tls_config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
let certs: Vec<rustls::Certificate> = self.config.load_certs();
let mut keys: Vec<rustls::PrivateKey> = self.config.load_keys();
tls_config.set_single_cert(certs, keys.remove(0)).unwrap();
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
let tls_stream: TlsStream<TcpStream> = acceptor.accept(stream).await?;https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L251
設定を生成して、鍵と証明書をloadして、TcpStreamをwrapします。
client
use tokio_rustls::client::TlsStream;
use tokio_rustls::{rustls, webpki, TlsConnector};
// tokio-rustls = { version = "0.20.0", features = ["dangerous_configuration"] }
let mut tls_config = rustls::ClientConfig::new();
tls_config
.dangerous()
.set_certificate_verifier(Arc::new(DangerousServerCertVerifier::new()));
let connector = TlsConnector::from(Arc::new(tls_config));
let domain = webpki::DNSNameRef::try_from_ascii_str(host)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid host"))?;
let stream = tokio::net::TcpStream::connect(addr).await?;
let tls_stream: TlsStream<TcpStream> = connector.connect(domain, stream).await?;clientも同様に設定を生成してTcpStreamをwrapしてやります。
local環境ではオレオレ証明書を使いたかったのでdangerous_configuration featureを有効にして証明書検証処理を自分で定義することができます。
struct DangerousServerCertVerifier {}
impl DangerousServerCertVerifier {
fn new() -> Self {
Self {}
}
}
impl rustls::ServerCertVerifier for DangerousServerCertVerifier {
fn verify_server_cert(
&self,
_roots: &rustls::RootCertStore,
_presented_certs: &[rustls::Certificate],
_dns_name: webpki::DNSNameRef<'_>,
_oscp_response: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
Ok(ServerCertVerified::assertion())
}
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/client/tcp.rs#L168
Shared stateをMutexからchannelで管理
[f:id:yamaguchi7073xtt:20201224185418j:plain]
Shared state(file)を排他的に利用したい場合にMutex等でlockするのではなく専用のtaskを生成してchannel経由で処理を依頼します。
処理の依頼をchannelに書き込むのはいいのですが、結果をどう受け取るのかが問題になってきます。
ここではtokio::sync::oneshotを利用しました。
use tokio::sync::oneshot;
pub(crate) struct Work<Req, Res> {
// ...
pub(crate) request: Req,
// Wrap with option so that response can be sent via mut reference.
pub(crate) response_sender: Option<oneshot::Sender<Result<Res>>>,
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/core/uow.rs#L27
処理を依頼する型のfieldにoneshot::Sender<T>を定義しておき呼び出し元で保持しているReceiverで結果をうけとります。
serverがclientからMessageをうけとって、Key/Valueを取得する処理は以下のようになりました。
Message::Get(get) => {
let (work: Uow, rx: Receiver<Result<_>>) = UnitOfWork::new_get(get);
self.sender.send(work).await?;
match rx.await? {
Ok(Some(value)) => {
connection.write_message(Success::with_value(value)).await?
}
Ok(None) => connection.write_message(Success::new()).await?,
_ => unreachable!(),
}
}https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L444
Fileへのappend onlyな保存
ようやくKey/Valueの話に。
永続化するKey/Valueはfileにappend onlyで保存していきます。メモリ上ではKeyとoffsetを保持しておきます。
Key/Valueをfileから読み込む処理はいかのようになりました。
use tokio::sync::mpsc::Receiver;
use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, SeekFrom};
pub(crate) struct Table<File = fs::File> {
file: File,
index: Index, // HashMap<String,usize>,
receiver: Receiver<UnitOfWork>,
}
impl<File> Table<File>
where
File: AsyncWrite + AsyncRead + AsyncSeek + Unpin,
{
async fn lookup_entry(&mut self, key: &Key) -> Result<Option<Entry>> {
let offset = match self.index.lookup_offset(key){
Some(offset) => offset,
None => return Ok(None),
};
let current = self.file.seek(SeekFrom::Current(0)).await?;
self.file.seek(SeekFrom::Start(offset as u64)).await?;
let (_, entry) = Entry::decode_from(&mut self.file).await?;
self.file.seek(SeekFrom::Start(current)).await?;
Ok(Some(entry))
}
}tokio::io::AsyncSeekがあるのでoffsetにseekする処理も同期と同じように書けました。
まとめ
- はじめてtokioを利用しようと思ったときmini redisの実装が非常に参考になりました。コメントもたくさんあり親切でした。
- tokioを利用してremoteと通信してデータを保存するまでを自分で作れて楽しかったです
- tokioのAPIを利用してGoに近い形でconcurrentな処理書けそうと思いました。
tokio::select!(select)tokio::spawn()(goroutine)sync::{mpsc,oneshot,broadcast}(chan)
この記事を書いている日にtokio v1.0.0がreleaseされました。
最低でも5年はメンテしていく旨も発表されておりtokioを利用しはじめるにはちょうどいい時期なのではないでしょうか。