diff --git a/server/odc-common/src/main/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptor.java b/server/odc-common/src/main/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptor.java index 281c140a8f..87d34fcb1d 100644 --- a/server/odc-common/src/main/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptor.java +++ b/server/odc-common/src/main/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptor.java @@ -54,22 +54,30 @@ public class RsaBytesEncryptor implements BytesEncryptor { private final RsaEncryptorType type; private final Cipher encryptor; private final Cipher decryptor; + private final RSAPublicKey publicKey; + private final RSAPrivateKey privateKey; public RsaBytesEncryptor(@NonNull RsaEncryptorType type, String publicKeyBase64, String privateKeyBase64) { this.type = type; if (type == RsaEncryptorType.ENCRYPT_MODE) { Validate.notBlank(publicKeyBase64, "The public key is required"); - this.encryptor = createCipher(Cipher.ENCRYPT_MODE, transformToPublicKey(publicKeyBase64)); + this.encryptor = createCipher(); this.decryptor = null; + this.publicKey = transformToPublicKey(publicKeyBase64); + this.privateKey = null; } else if (type == RsaEncryptorType.DECRYPT_MODE) { Validate.notBlank(privateKeyBase64, "The private key is required"); this.encryptor = null; - this.decryptor = createCipher(Cipher.DECRYPT_MODE, transformToPrivateKey(privateKeyBase64)); + this.decryptor = createCipher(); + this.privateKey = transformToPrivateKey(privateKeyBase64); + this.publicKey = null; } else { Validate.notBlank(publicKeyBase64, "The public key is required"); - this.encryptor = createCipher(Cipher.ENCRYPT_MODE, transformToPublicKey(publicKeyBase64)); + this.encryptor = createCipher(); + this.publicKey = transformToPublicKey(publicKeyBase64); Validate.notBlank(privateKeyBase64, "The private key is required"); - this.decryptor = createCipher(Cipher.DECRYPT_MODE, transformToPrivateKey(privateKeyBase64)); + this.decryptor = createCipher(); + this.privateKey = transformToPrivateKey(privateKeyBase64); } } @@ -81,6 +89,7 @@ public synchronized byte[] encrypt(byte[] origin) { if (encryptor == null) { throw new IllegalStateException("The encryptor is required but null"); } + initCipher(encryptor, Cipher.ENCRYPT_MODE, publicKey); return doFinal(encryptor, origin); } @@ -92,6 +101,7 @@ public synchronized byte[] decrypt(byte[] encrypted) { if (decryptor == null) { throw new IllegalStateException("The decryptor is required but null"); } + initCipher(decryptor, Cipher.DECRYPT_MODE, privateKey); return doFinal(decryptor, encrypted); } @@ -115,17 +125,21 @@ public static Pair generateBase64EncodeKeyPair(int keySize) { return new Pair<>(publicKeyBase64, privateKeyBase64); } - private Cipher createCipher(int mode, Key key) { + private Cipher createCipher() { try { - Cipher cipher = Cipher.getInstance(ALGORITHM_NAME_WITH_PADDING); - cipher.init(mode, key); - return cipher; + return Cipher.getInstance(ALGORITHM_NAME_WITH_PADDING); } catch (NoSuchAlgorithmException e) { throw new IllegalArgumentException("Not a valid encryption algorithm", e); } catch (NoSuchPaddingException e) { throw new IllegalStateException("Should not happen", e); + } + } + + private void initCipher(Cipher cipher, int mode, Key key) { + try { + cipher.init(mode, key); } catch (InvalidKeyException e) { - throw new IllegalArgumentException("Not a valid secret key", e); + throw new IllegalArgumentException("Not a valid key", e); } } diff --git a/server/odc-common/src/test/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptorTest.java b/server/odc-common/src/test/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptorTest.java index 125a7f31e9..d081ebb37d 100644 --- a/server/odc-common/src/test/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptorTest.java +++ b/server/odc-common/src/test/java/com/oceanbase/odc/common/crypto/RsaBytesEncryptorTest.java @@ -24,10 +24,13 @@ import com.oceanbase.odc.common.crypto.RsaBytesEncryptor.RsaEncryptorType; import com.oceanbase.odc.common.lang.Pair; +import lombok.extern.slf4j.Slf4j; + /** * @author gaoda.xy * @date 2023/8/23 14:57 */ +@Slf4j public class RsaBytesEncryptorTest extends EncryptorTest { private static final Pair keyPair = RsaBytesEncryptor.generateBase64EncodeKeyPair(); @@ -94,4 +97,23 @@ public void test_concurrentDecrypt() { } } + @Test + public void test_decryptInvalidInput_andThenDecryptValidInput() { + RsaBytesEncryptor wrongEncryptor = new RsaBytesEncryptor(RsaEncryptorType.ENCRYPT_MODE, + RsaBytesEncryptor.generateBase64EncodeKeyPair(4096).left, null); + Pair keyPair = RsaBytesEncryptor.generateBase64EncodeKeyPair(); + RsaBytesEncryptor encryptor = new RsaBytesEncryptor(RsaEncryptorType.ENCRYPT_MODE, keyPair.left, null); + RsaBytesEncryptor decryptor = new RsaBytesEncryptor(RsaEncryptorType.DECRYPT_MODE, null, keyPair.right); + String origin = "This is the origin string"; + byte[] encrypted = wrongEncryptor.encrypt(origin.getBytes(StandardCharsets.UTF_8)); + try { + decryptor.decrypt(encrypted); + } catch (Exception e) { + log.info("Decrypt failed when encrypt content using wrong public key"); + } + encrypted = encryptor.encrypt(origin.getBytes(StandardCharsets.UTF_8)); + String decrypted = new String(decryptor.decrypt(encrypted), StandardCharsets.UTF_8); + Assert.assertEquals(origin, decrypted); + } + }