Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,30 @@ public class RsaBytesEncryptor implements BytesEncryptor {
private final RsaEncryptorType type;
private final Cipher encryptor;
private final Cipher decryptor;
private final RSAPublicKey publicKey;
private final RSAPrivateKey privateKey;

public RsaBytesEncryptor(@NonNull RsaEncryptorType type, String publicKeyBase64, String privateKeyBase64) {
this.type = type;
if (type == RsaEncryptorType.ENCRYPT_MODE) {
Validate.notBlank(publicKeyBase64, "The public key is required");
this.encryptor = createCipher(Cipher.ENCRYPT_MODE, transformToPublicKey(publicKeyBase64));
this.encryptor = createCipher();
this.decryptor = null;
this.publicKey = transformToPublicKey(publicKeyBase64);
this.privateKey = null;
} else if (type == RsaEncryptorType.DECRYPT_MODE) {
Validate.notBlank(privateKeyBase64, "The private key is required");
this.encryptor = null;
this.decryptor = createCipher(Cipher.DECRYPT_MODE, transformToPrivateKey(privateKeyBase64));
this.decryptor = createCipher();
this.privateKey = transformToPrivateKey(privateKeyBase64);
this.publicKey = null;
} else {
Validate.notBlank(publicKeyBase64, "The public key is required");
this.encryptor = createCipher(Cipher.ENCRYPT_MODE, transformToPublicKey(publicKeyBase64));
this.encryptor = createCipher();
this.publicKey = transformToPublicKey(publicKeyBase64);
Validate.notBlank(privateKeyBase64, "The private key is required");
this.decryptor = createCipher(Cipher.DECRYPT_MODE, transformToPrivateKey(privateKeyBase64));
this.decryptor = createCipher();
this.privateKey = transformToPrivateKey(privateKeyBase64);
}
}

Expand All @@ -81,6 +89,7 @@ public synchronized byte[] encrypt(byte[] origin) {
if (encryptor == null) {
throw new IllegalStateException("The encryptor is required but null");
}
initCipher(encryptor, Cipher.ENCRYPT_MODE, publicKey);
return doFinal(encryptor, origin);
}

Expand All @@ -92,6 +101,7 @@ public synchronized byte[] decrypt(byte[] encrypted) {
if (decryptor == null) {
throw new IllegalStateException("The decryptor is required but null");
}
initCipher(decryptor, Cipher.DECRYPT_MODE, privateKey);
return doFinal(decryptor, encrypted);
}

Expand All @@ -115,17 +125,21 @@ public static Pair<String, String> generateBase64EncodeKeyPair(int keySize) {
return new Pair<>(publicKeyBase64, privateKeyBase64);
}

private Cipher createCipher(int mode, Key key) {
private Cipher createCipher() {
try {
Cipher cipher = Cipher.getInstance(ALGORITHM_NAME_WITH_PADDING);
cipher.init(mode, key);
return cipher;
return Cipher.getInstance(ALGORITHM_NAME_WITH_PADDING);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException("Not a valid encryption algorithm", e);
} catch (NoSuchPaddingException e) {
throw new IllegalStateException("Should not happen", e);
}
}

private void initCipher(Cipher cipher, int mode, Key key) {
try {
cipher.init(mode, key);
} catch (InvalidKeyException e) {
throw new IllegalArgumentException("Not a valid secret key", e);
throw new IllegalArgumentException("Not a valid key", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import com.oceanbase.odc.common.crypto.RsaBytesEncryptor.RsaEncryptorType;
import com.oceanbase.odc.common.lang.Pair;

import lombok.extern.slf4j.Slf4j;

/**
* @author gaoda.xy
* @date 2023/8/23 14:57
*/
@Slf4j
public class RsaBytesEncryptorTest extends EncryptorTest {

private static final Pair<String, String> keyPair = RsaBytesEncryptor.generateBase64EncodeKeyPair();
Expand Down Expand Up @@ -94,4 +97,23 @@ public void test_concurrentDecrypt() {
}
}

@Test
public void test_decryptInvalidInput_andThenDecryptValidInput() {
RsaBytesEncryptor wrongEncryptor = new RsaBytesEncryptor(RsaEncryptorType.ENCRYPT_MODE,
RsaBytesEncryptor.generateBase64EncodeKeyPair(4096).left, null);
Pair<String, String> keyPair = RsaBytesEncryptor.generateBase64EncodeKeyPair();
RsaBytesEncryptor encryptor = new RsaBytesEncryptor(RsaEncryptorType.ENCRYPT_MODE, keyPair.left, null);
RsaBytesEncryptor decryptor = new RsaBytesEncryptor(RsaEncryptorType.DECRYPT_MODE, null, keyPair.right);
String origin = "This is the origin string";
byte[] encrypted = wrongEncryptor.encrypt(origin.getBytes(StandardCharsets.UTF_8));
try {
decryptor.decrypt(encrypted);
} catch (Exception e) {
log.info("Decrypt failed when encrypt content using wrong public key");
}
encrypted = encryptor.encrypt(origin.getBytes(StandardCharsets.UTF_8));
String decrypted = new String(decryptor.decrypt(encrypted), StandardCharsets.UTF_8);
Assert.assertEquals(origin, decrypted);
}

}