Skip to content

Commit 599c71d

Browse files
committed
Add a comment, verify the first magic string instead of the last one, verify all parquet parts in parquet folder.
1 parent aa507a4 commit 599c71d

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetEncryptionSuite.scala

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,28 @@ class ParquetEncryptionSuite extends QueryTest with TestHiveSingleton {
6464
}
6565
}
6666

67-
private def verifyParquetEncrypted(parquetDir: String) = {
67+
/**
68+
* Verify that the directory contains an encrypted parquet in
69+
* encrypted footer mode by means of checking for all the parquet part files
70+
* in the parquet directory that their magic string is PARE, as defined in the spec:
71+
* https://github.com/apache/parquet-format/blob/master/Encryption.md#54-encrypted-footer-mode
72+
*/
73+
private def verifyParquetEncrypted(parquetDir: String): Unit = {
6874
val parquetPartitionFiles = getListOfParquetFiles(new File(parquetDir))
6975
assert(parquetPartitionFiles.size >= 1)
70-
val parquetFile = parquetPartitionFiles.maxBy(_.lastModified)
71-
72-
val magicString = "PARE"
73-
val magicStringLength = magicString.length()
74-
val byteArray = new Array[Byte](magicStringLength)
75-
val randomAccessFile = new RandomAccessFile(parquetFile, "r")
76-
try {
77-
randomAccessFile.seek(randomAccessFile.length() - magicStringLength);
78-
randomAccessFile.read(byteArray, 0, magicStringLength)
79-
} finally {
80-
randomAccessFile.close()
76+
parquetPartitionFiles.foreach { parquetFile =>
77+
val magicString = "PARE"
78+
val magicStringLength = magicString.length()
79+
val byteArray = new Array[Byte](magicStringLength)
80+
val randomAccessFile = new RandomAccessFile(parquetFile, "r")
81+
try {
82+
randomAccessFile.read(byteArray, 0, magicStringLength)
83+
} finally {
84+
randomAccessFile.close()
85+
}
86+
val stringRead = new String(byteArray, StandardCharsets.UTF_8)
87+
assert(magicString == stringRead)
8188
}
82-
val stringRead = new String(byteArray, StandardCharsets.UTF_8)
83-
assert(magicString == stringRead)
8489
}
8590

8691
private def getListOfParquetFiles(dir: File): List[File] = {

0 commit comments

Comments
 (0)