@@ -622,13 +622,22 @@ impl<B: MysqlShim<W>, R: Read, W: Write> MysqlIntermediary<B, R, W> {
622622 self . writer . flush ( ) ?;
623623 return Err ( io:: Error :: new ( io:: ErrorKind :: PermissionDenied , err_msg) . into ( ) ) ;
624624 }
625+
626+ if let Some ( Ok ( db) ) = handshake. db . as_ref ( ) . map ( |x| std:: str:: from_utf8 ( x) ) {
627+ let w = InitWriter {
628+ client_capabilities : self . client_capabilities ,
629+ writer : & mut self . writer ,
630+ } ;
631+ self . shim . on_init ( db, w) ?;
632+ } else {
633+ writers:: write_ok_packet (
634+ & mut self . writer ,
635+ self . client_capabilities ,
636+ OkResponse :: default ( ) ,
637+ ) ?;
638+ }
625639 }
626640
627- writers:: write_ok_packet (
628- & mut self . writer ,
629- self . client_capabilities ,
630- OkResponse :: default ( ) ,
631- ) ?;
632641 self . writer . flush ( ) ?;
633642
634643 Ok ( ( ) )
@@ -999,13 +1008,22 @@ impl<B: AsyncMysqlShim<Cursor<Vec<u8>>> + Send + Sync, S: AsyncRead + AsyncWrite
9991008 self . writer_flush ( ) . await ?;
10001009 return Err ( io:: Error :: new ( io:: ErrorKind :: PermissionDenied , err_msg) . into ( ) ) ;
10011010 }
1011+
1012+ if let Some ( Ok ( db) ) = handshake. db . as_ref ( ) . map ( |x| std:: str:: from_utf8 ( x) ) {
1013+ let w = InitWriter {
1014+ client_capabilities : self . client_capabilities ,
1015+ writer : & mut self . writer ,
1016+ } ;
1017+ self . shim . on_init ( db, w) . await ?;
1018+ } else {
1019+ writers:: write_ok_packet (
1020+ & mut self . writer ,
1021+ self . client_capabilities ,
1022+ OkResponse :: default ( ) ,
1023+ ) ?;
1024+ }
10021025 }
10031026
1004- writers:: write_ok_packet (
1005- & mut self . writer ,
1006- self . client_capabilities ,
1007- OkResponse :: default ( ) ,
1008- ) ?;
10091027 self . writer_flush ( ) . await ?;
10101028
10111029 Ok ( ( ) )
0 commit comments