本エントリーではKey Value Storeをtokio で作るうえで学んだことを書いていきます。 RustのLT会 Shinjuku.rs #13 で話させていただいた内容です。
発表時のスライド
source code
TCP(TLS)でremoteに接続してKey/ValueをCRUDするだけのserverです。
❯ export KVSD_HOST=kvsd.ymgyt.io
❯ kvsd set hello rust
OK
❯ kvsd get hello
rust
❯ kvsd set hello ' rust!!!'
OK
❯ kvsd delete hello
OK old value: rust!!!
概要 ClientからのTCP接続ごとにtaskをtokio::spawn()
して専用のHandlerを生成します。 Key/Valueを保存するFileごとにもtaskを生成しておき、共有resourceの処理でlock等が必要ないようにします。 (図の四角がtokio::spawn()
したtaskを表しています)
tokioを利用することで、clientごとにthreadを生成する必要がなくなったのですが無制限にtaskを生成するわけにもいかないので最大connection数を制御したいです。 最大connectionに達したときはTcpListener.accept()
をよびださないように実装します。
struct SemaphoreListener {
inner : TcpListener,
max_connections : Arc< Semaphore> ,
}
impl SemaphoreListener {
fn new ( listener : TcpListener, max_connections : u32 ) -> Self {
Self {
inner: listener,
max_connections: Arc:: new( Semaphore:: new( max_connections as usize ) ) ,
}
}
async fn accept ( & mut self ) -> std:: io:: Result < ( TcpStream, std:: net:: SocketAddr) > {
self . max_connections. acquire ( ) . await. forget ( ) ;
self . inner. accept ( ) . await
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L494
tokio::sync::Semaphore
でTcpListener
をwrapしてaccept()
時にacquire()
をよぶことで最大connectionに達したときはblockするようにします。 所有権の関係でforget()
を呼んでいるので、Clientとの接続が終了したときのHandlerのdrop処理で帳尻をあわせる必要があります。
struct Handler {
max_connections : Arc< Semaphore> ,
}
impl Drop for Handler {
fn drop ( & mut self ) {
self . max_connections. add_permits ( 1 ) ;
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L488
Shutdownの通知 Shutdownの待機 終了処理時(SIGINT等)にすべてのClientとの接続が切れるまで待機する必要があります。 Graceful shutdownに必要な処理はそれぞれ以下のtokioのAPIを利用することで実現できました。
Signal handling => tokio::signal::ctrl_c()
Shutdownの通知 => tokio::sync::broadcast::channel()
Shutdownの待機 =>tokio::sync::mpsc::channel()
impl Server {
pub ( crate ) async fn run (
mut self ,
listener : TcpListener,
shutdown : impl Future,
) -> Result < ( ) > {
tokio:: select! {
result = self . serve ( listener) => {
if let Err ( err) = result {
error! ( cause = % err, " Failed to accept" ) ;
}
}
_ = shutdown => {
info! ( " Shutdown signal received" ) ;
}
}
info! ( " Notify shutdown to all handlers" ) ;
self . graceful_shutdown. shutdown ( ) . await;
info! ( " Shutdown successfully completed" ) ;
Ok ( ( ) )
}
}
struct GracefulShutdown {
notify_shutdown : broadcast:: Sender< ShutdownSignal> ,
shutdown_complete_tx : mpsc:: Sender< ShutdownCompleteSignal> ,
shutdown_complete_rx : mpsc:: Receiver< ShutdownCompleteSignal> ,
}
impl GracefulShutdown {
fn new ( ) -> Self {
let ( notify_shutdown, _ ) = broadcast:: channel( 1 ) ;
let ( shutdown_complete_tx, shutdown_complete_rx) = mpsc:: channel( 1 ) ;
Self {
notify_shutdown,
shutdown_complete_tx,
shutdown_complete_rx,
}
}
async fn shutdown ( mut self ) {
drop ( self . notify_shutdown) ;
drop ( self . shutdown_complete_tx) ;
let _ = self . shutdown_complete_rx. recv ( ) . await;
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L204 https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L159
tokio::select!
を利用することで、task内で2つのfutureを同時にawait
することができるようになります。 impl Future
としておくことでtest時にはtokio::sync::Notify
を利用することでできます。 https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/tests/integration_test.rs#L38
TCP上で自分で定義したデータ構造での通信を目指します。 Message
はClient-Server間でやりとりする単位です(Authenticate
, Set
, Success
,...) Frame
はMessage
の構成要素です(String
, Bytes
,Null
,...) 実体はredis serialization protocolの劣化版みたいなものです。
TcpStream
/TlsStream
をwrapした構造体を定義して、client/serverにprotocol実装を提供します。
use bytes:: { Buf, BytesMut} ;
use tokio:: io:: { AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter} ;
pub struct Connection < T = TcpStream> {
stream : BufWriter< T> ,
buffer : BytesMut,
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/protocol/connection/mod.rs#L12
buffer
はtcpからreadしてMessageを構成するために利用します。
client/serverから通信したいMessage
をうけとるとそれをFrames
に分解してそれぞれserializationしていきます。
use tokio:: io:: { AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter} ;
impl < T> Connection < T>
where
T: AsyncWrite + AsyncRead + Unpin,
{
pub ( crate ) async fn write_message ( & mut self , message : impl Into < MessageFrames> ) -> Result < ( ) > {
let frames = message. into ( ) ;
self . stream. write_u8 ( frameprefix:: MESSAGE_FRAMES ) . await? ;
self . write_decimal ( frames. len ( ) ) . await? ;
for frame in frames {
self . write_frame ( frame) . await?
}
self . stream. flush ( ) . await? ;
Ok ( ( ) )
}
async fn write_frame ( & mut self , frame : Frame) -> Result < ( ) > {
match frame {
Frame:: MessageType( mt) => {
self . stream. write_u8 ( frameprefix:: MESSAGE_TYPE ) . await? ;
self . stream. write_u8 ( mt. into ( ) ) . await? ;
}
Frame:: String( val) => {
self . stream. write_u8 ( frameprefix:: STRING ) . await? ;
self . stream. write_all ( val. as_bytes ( ) ) . await? ;
self . stream. write_all ( DELIMITER ) . await? ;
}
Frame:: Bytes( val) => {
self . stream. write_u8 ( frameprefix:: BYTES ) . await? ;
self . write_decimal ( val. len ( ) as u64 ) . await? ;
self . stream. write_all ( val. as_ref ( ) ) . await? ;
self . stream. write_all ( DELIMITER ) . await? ;
}
Frame:: Time( val) => {
self . stream. write_u8 ( frameprefix:: TIME ) . await? ;
self . stream. write_all ( val. to_rfc3339 ( ) . as_bytes ( ) ) . await? ;
self . stream. write_all ( DELIMITER ) . await? ;
}
Frame:: Null => {
self . stream. write_u8 ( frameprefix:: NULL ) . await? ;
}
}
Ok ( ( ) )
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/protocol/connection/mod.rs#L29
AsyncWriteExt
にはendianを考慮したwrite methodが定義されているので便利です。
[f:id:yamaguchi7073xtt:20201224185344j:plain]readの概要 impl < T> Connection < T>
where
T: AsyncWrite + AsyncRead + Unpin,
{
pub ( crate ) async fn read_message_with_timeout (
& mut self ,
duration : Duration,
) -> Result < Option < Message> > {
match tokio:: time:: timeout( duration, self . read_message ( ) ) . await {
Ok ( read_result) => read_result,
Err ( elapsed) => Err ( Error:: from( elapsed) ) ,
}
}
pub ( crate ) async fn read_message ( & mut self ) -> Result < Option < Message> > {
match self . read_message_frames ( ) . await? {
Some ( message_frames) => Ok ( Some ( Message:: from_frames( message_frames) ? ) ) ,
None => Ok ( None ) ,
}
}
async fn read_message_frames ( & mut self ) -> Result < Option < MessageFrames> > {
loop {
if let Some ( message_frames) = self . parse_message_frames ( ) ? {
return Ok ( Some ( message_frames) ) ;
}
if 0 == self . stream. read_buf ( & mut self . buffer) . await? {
return if self . buffer. is_empty ( ) {
Ok ( None )
} else {
Err ( ErrorKind:: ConnectionResetByPeer. into ( ) )
} ;
}
}
}
fn parse_message_frames ( & mut self ) -> Result < Option < MessageFrames> > {
use FrameError:: Incomplete;
let mut buf = Cursor:: new( & self . buffer[ .. ] ) ;
match MessageFrames:: check_parse( & mut buf) {
Ok ( _ ) => {
let len = buf. position ( ) as usize ;
buf. set_position ( 0 ) ;
let message_frames = MessageFrames:: parse( & mut buf) ? ;
self . buffer. advance ( len) ;
Ok ( Some ( message_frames) )
}
Err ( Incomplete) => Ok ( None ) ,
Err ( e) => Err ( e. into ( ) ) ,
}
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/protocol/connection/mod.rs#L73
futureにtimeoutを設定するにはtokio::time::timeout
を利用します。 tokio::time::timeout()
でwrapするとtimeoutつきのfutureを作れ、timeoutするとtokio::time::Elapsed
を返してくれます。
tcpからreadしたときに意味ある単位でreadできているとは限らないのでbufferに読み込んだタイミングで一度parseを試みます。 parse時の実装をsimpleにするためにstd::io::Cursor
でwrapしておくことで、parse側では必要な単位でbufferを読み込めます。 bufferにMessage
を構成するに十分なFrame
がない場合はloopでtcpのreadを繰り返します。 bufferに十分なFrames
がある場合は、呼び出し元にMessage
を返して、bufferをclearします。
TcpStream
をwrapしたような型のtestではtokio::io::duplex()
が便利でした。
# [ test ]
fn message_frames ( ) {
tokio_test:: block_on( async move {
let ( client, server) = tokio:: io:: duplex( 1024 ) ;
let mut client_conn = Connection:: new( client, None ) ;
let mut server_conn = Connection:: new( server, None ) ;
let messages: Vec < Message> = vec! [
Message:: Authenticate( Authenticate:: new( " user" , " pass" ) ) ,
Message:: Ping( Ping:: new( ) . record_client_time ( ) ) ,
] ;
let messages_clone = messages. clone ( ) ;
let write_handle = tokio:: spawn( async move {
for message in messages {
client_conn. write_message ( message) . await. unwrap ( ) ;
}
} ) ;
let read_handle = tokio:: spawn( async move {
for want in messages_clone {
let got = server_conn. read_message ( ) . await. unwrap ( ) . unwrap ( ) ;
assert_eq! ( want, got) ;
}
} ) ;
write_handle. await. unwrap ( ) ;
read_handle. await. unwrap ( ) ;
} )
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/protocol/connection/mod.rs#L147
このような感じでtcp listenerを用意せずにメモリ上でtestが完結できました。 async codeのtestではruntimeを起動しておく必要があるので、tokio_test
を利用しました。
TCPをTLSで保護するにはtokio_rustls
を利用できます。
use tokio_rustls:: rustls;
use tokio_rustls:: server:: TlsStream;
use tokio_rustls:: TlsAcceptor;
let mut tls_config = rustls:: ServerConfig:: new( rustls:: NoClientAuth:: new( ) ) ;
let certs: Vec < rustls:: Certificate> = self . config. load_certs ( ) ;
let mut keys: Vec < rustls:: PrivateKey> = self . config. load_keys ( ) ;
tls_config. set_single_cert ( certs, keys. remove ( 0 ) ) . unwrap ( ) ;
let tls_acceptor = TlsAcceptor:: from( Arc:: new( tls_config) ) ;
let tls_stream: TlsStream< TcpStream> = acceptor. accept ( stream) . await? ;
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L251
設定を生成して、鍵と証明書をloadして、TcpStream
をwrapします。
use tokio_rustls:: client:: TlsStream;
use tokio_rustls:: { rustls, webpki, TlsConnector} ;
let mut tls_config = rustls:: ClientConfig:: new( ) ;
tls_config
. dangerous ( )
. set_certificate_verifier ( Arc:: new( DangerousServerCertVerifier:: new( ) ) ) ;
let connector = TlsConnector:: from( Arc:: new( tls_config) ) ;
let domain = webpki:: DNSNameRef:: try_from_ascii_str( host)
. map_err ( | _| io:: Error:: new( io:: ErrorKind:: InvalidInput, " invalid host" ) ) ? ;
let stream = tokio:: net:: TcpStream:: connect( addr) . await? ;
let tls_stream: TlsStream< TcpStream> = connector. connect ( domain, stream) . await? ;
clientも同様に設定を生成してTcpStream
をwrapしてやります。 local環境ではオレオレ証明書を使いたかったのでdangerous_configuration
featureを有効にして証明書検証処理を自分で定義することができます。
struct DangerousServerCertVerifier { }
impl DangerousServerCertVerifier {
fn new ( ) -> Self {
Self { }
}
}
impl rustls:: ServerCertVerifier for DangerousServerCertVerifier {
fn verify_server_cert (
& self ,
_roots : & rustls:: RootCertStore,
_presented_certs : & [rustls:: Certificate],
_dns_name : webpki:: DNSNameRef< '_ > ,
_oscp_response : & [u8 ],
) -> Result < rustls:: ServerCertVerified, rustls:: TLSError> {
Ok ( ServerCertVerified:: assertion( ) )
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/client/tcp.rs#L168
[f:id:yamaguchi7073xtt:20201224185418j:plain]
Shared state(file)を排他的に利用したい場合にMutex等でlockするのではなく専用のtaskを生成してchannel経由で処理を依頼します。 処理の依頼をchannelに書き込むのはいいのですが、結果をどう受け取るのかが問題になってきます。 ここではtokio::sync::oneshot
を利用しました。
use tokio:: sync:: oneshot;
pub ( crate ) struct Work < Req, Res> {
pub ( crate ) request : Req,
pub ( crate ) response_sender : Option < oneshot:: Sender< Result < Res> > > ,
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/core/uow.rs#L27
処理を依頼する型のfieldにoneshot::Sender<T>
を定義しておき呼び出し元で保持しているReceiver
で結果をうけとります。 serverがclientからMessage
をうけとって、Key/Valueを取得する処理は以下のようになりました。
Message:: Get( get) => {
let ( work: Uow, rx: Receiver< Result < _ > > ) = UnitOfWork:: new_get( get) ;
self . sender. send ( work) . await? ;
match rx. await? {
Ok ( Some ( value) ) => {
connection. write_message ( Success:: with_value( value) ) . await?
}
Ok ( None ) => connection. write_message ( Success:: new( ) ) . await? ,
_ => unreachable! ( ) ,
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/server/tcp.rs#L444
[f:id:yamaguchi7073xtt:20201224185446j:plain]Key Valueの持ち方 ようやくKey/Valueの話に。 永続化するKey/Valueはfileにappend onlyで保存していきます。メモリ上ではKeyとoffsetを保持しておきます。 Key/Valueをfileから読み込む処理はいかのようになりました。
use tokio:: sync:: mpsc:: Receiver;
use tokio:: io:: { AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, SeekFrom} ;
pub ( crate ) struct Table < File = fs:: File> {
file : File,
index : Index, receiver : Receiver< UnitOfWork> ,
}
impl < File> Table < File>
where
File: AsyncWrite + AsyncRead + AsyncSeek + Unpin,
{
async fn lookup_entry ( & mut self , key : & Key) -> Result < Option < Entry> > {
let offset = match self . index. lookup_offset ( key) {
Some ( offset) => offset,
None => return Ok ( None ) ,
} ;
let current = self . file. seek ( SeekFrom:: Current( 0 ) ) . await? ;
self . file. seek ( SeekFrom:: Start( offset as u64 ) ) . await? ;
let ( _ , entry) = Entry:: decode_from( & mut self . file) . await? ;
self . file. seek ( SeekFrom:: Start( current) ) . await? ;
Ok ( Some ( entry) )
}
}
https://github.com/ymgyt/kvsd/blob/6b40c6a3edad0f416631f7ae66674fb5d7922cd7/src/core/table/table.rs#L135
tokio::io::AsyncSeek
があるのでoffsetにseekする処理も同期と同じように書けました。
はじめてtokioを利用しようと思ったときmini redis の実装が非常に参考になりました。コメントもたくさんあり親切でした。 tokioを利用してremoteと通信してデータを保存するまでを自分で作れて楽しかったです tokioのAPIを利用してGoに近い形でconcurrentな処理書けそうと思いました。 tokio::select!
(select
)tokio::spawn()
(goroutine)sync::{mpsc,oneshot,broadcast}
(chan
) この記事を書いている日にtokio v1.0.0
がreleaseされました。 最低でも5年はメンテしていく旨 も発表されておりtokioを利用しはじめるにはちょうどいい時期なのではないでしょうか。