80 changes: 60 additions & 20 deletions src/main/scala/com/johnsnowlabs/client/CloudResources.scala
@@ -79,9 +79,9 @@ object CloudResources {
if (!modelExists) {
val destination =
unzipInExternalCloudStorage(sourceS3URI, cachePath, azureClient, zippedModel)
modelPath = Some(transformURIToWASB(destination))
modelPath = buildAzureModelPath(cachePath, destination)
} else {
modelPath = Some(transformURIToWASB(cachePath + "/" + modelName))
modelPath = buildAzureModelPathWhenExists(cachePath, modelName)
}

modelPath
@@ -90,6 +90,24 @@ object CloudResources {

}

private def buildAzureModelPath(cachePath: String, destination: String): Option[String] = {
if (CloudHelper.isFabricAbfss(cachePath)) {
Some(destination)
} else {
Some(transformURIToWASB(destination))
}
}

private def buildAzureModelPathWhenExists(
cachePath: String,
modelName: String): Option[String] = {
if (CloudHelper.isFabricAbfss(cachePath)) {
Some(cachePath + "/" + modelName)
} else {
Some(transformURIToWASB(cachePath + "/" + modelName))
}
}
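The net effect of the two helpers is that on Fabric the ABFSS destination is handed back to Spark untouched, while everywhere else the existing WASB rewrite is kept. A small illustration with hypothetical URIs (not taken from this PR):

// Hypothetical Fabric cache path: the ABFSS destination is returned as-is, no WASB rewrite.
//   buildAzureModelPath(
//     "abfss://ws@onelake.dfs.fabric.microsoft.com/lh.Lakehouse/Files/cache",
//     "abfss://ws@onelake.dfs.fabric.microsoft.com/lh.Lakehouse/Files/cache/model_en")
//   => Some("abfss://ws@onelake.dfs.fabric.microsoft.com/lh.Lakehouse/Files/cache/model_en")
// Hypothetical Blob Storage cache path: rewritten via transformURIToWASB, as before.
//   buildAzureModelPathWhenExists("<blob-storage-cache-path>", "model_en")
//   => Some(transformURIToWASB("<blob-storage-cache-path>/model_en")), i.e. a wasb:// URI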

private def doesModelExistInExternalCloudStorage(
modelName: String,
destinationURI: String,
@@ -111,11 +129,18 @@ object CloudResources {
gcpClient.doesBucketPathExist(destinationBucketName, modelPath)
}
case azureClient: AzureClient => {
val (destinationBucketName, destinationStoragePath) =
CloudHelper.parseAzureBlobURI(destinationURI)
val modelPath = destinationStoragePath + "/" + modelName
if (CloudHelper.isFabricAbfss(destinationURI)) {
val fabricUri =
if (destinationURI.endsWith("/")) destinationURI + modelName
else destinationURI + "/" + modelName
azureClient.doesBucketPathExist(fabricUri, "")
} else {
val (destinationBucketName, destinationStoragePath) =
CloudHelper.parseAzureBlobURI(destinationURI)
val modelPath = destinationStoragePath + "/" + modelName

azureClient.doesBucketPathExist(destinationBucketName, modelPath)
azureClient.doesBucketPathExist(destinationBucketName, modelPath)
}
}
}

@@ -133,7 +158,6 @@ object CloudResources {
val zipFile = sourceKey.split("/").last
val modelName = zipFile.substring(0, zipFile.indexOf(".zip"))

println(s"Uploading model $modelName to external Cloud Storage URI: $destinationStorageURI")
while (zipEntry != null) {
if (!zipEntry.isDirectory) {
val outputStream = new ByteArrayOutputStream()
@@ -165,16 +189,23 @@ object CloudResources {
inputStream)
}
case azureClient: AzureClient => {
val (destinationBucketName, destinationStoragePath) =
CloudHelper.parseAzureBlobURI(destinationStorageURI)

val destinationAzureStoragePath =
s"$destinationStoragePath/$modelName/${zipEntry.getName}".stripPrefix("/")

azureClient.copyFileToBucket(
destinationBucketName,
destinationAzureStoragePath,
inputStream)
if (CloudHelper.isFabricAbfss(destinationStorageURI)) {
val fabricUri = (if (destinationStorageURI.endsWith("/")) destinationStorageURI
else destinationStorageURI + "/") +
s"$modelName/${zipEntry.getName}"
azureClient.copyFileToBucket(fabricUri, "", inputStream)
} else {
val (destinationBucketName, destinationStoragePath) =
CloudHelper.parseAzureBlobURI(destinationStorageURI)

val destinationAzureStoragePath =
s"$destinationStoragePath/$modelName/${zipEntry.getName}".stripPrefix("/")

azureClient.copyFileToBucket(
destinationBucketName,
destinationAzureStoragePath,
inputStream)
}
}
}

@@ -286,9 +317,18 @@ object CloudResources {
Paths.get(directory, keyPrefix).toUri
}
case azureClient: AzureClient => {
val (bucketName, keyPrefix) = CloudHelper.parseAzureBlobURI(bucketURI)
azureClient.downloadFilesFromBucketToDirectory(bucketName, keyPrefix, directory, isIndex)
Paths.get(directory, keyPrefix).toUri
if (CloudHelper.isFabricAbfss(bucketURI)) {
azureClient.downloadFilesFromBucketToDirectory(bucketURI, "", directory, isIndex)
Paths.get(directory).toUri
} else {
val (bucketName, keyPrefix) = CloudHelper.parseAzureBlobURI(bucketURI)
azureClient.downloadFilesFromBucketToDirectory(
bucketName,
keyPrefix,
directory,
isIndex)
Paths.get(directory, keyPrefix).toUri
}
}
}

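All of the new branches above key on CloudHelper.isFabricAbfss, whose implementation is not part of this diff. A hypothetical sketch of such a check, assuming it only needs to recognise OneLake ABFSS URIs, could look like:

// Hypothetical sketch only -- not the actual CloudHelper.isFabricAbfss from this PR.
def looksLikeFabricAbfss(uri: String): Boolean =
  uri.startsWith("abfss://") && uri.contains("onelake.dfs.fabric.microsoft.com")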
18 changes: 12 additions & 6 deletions src/main/scala/com/johnsnowlabs/client/azure/AzureClient.scala
@@ -1,5 +1,6 @@
package com.johnsnowlabs.client.azure

import com.johnsnowlabs.client.util.CloudHelper
import com.johnsnowlabs.client.{CloudClient, CloudStorage}
import com.johnsnowlabs.util.ConfigHelper

@@ -8,12 +9,17 @@ class AzureClient(parameters: Map[String, String] = Map.empty) extends CloudClient
private lazy val azureStorageConnection = cloudConnect()

override protected def cloudConnect(): CloudStorage = {
val storageAccountName = parameters.getOrElse(
"storageAccountName",
throw new Exception("Azure client requires storageAccountName"))
val accountKey =
parameters.getOrElse("accountKey", ConfigHelper.getHadoopAzureConfig(storageAccountName))
new AzureGateway(storageAccountName, accountKey)
if (CloudHelper.isMicrosoftFabric) {
// These params are NOT required for Fabric
new AzureGateway("", "", isFabricLakehouse = true)
} else {
val storageAccountName = parameters.getOrElse(
"storageAccountName",
throw new Exception("Azure client requires storageAccountName"))
val accountKey =
parameters.getOrElse("accountKey", ConfigHelper.getHadoopAzureConfig(storageAccountName))
new AzureGateway(storageAccountName, accountKey)
}
}

override def doesBucketPathExist(bucketName: String, filePath: String): Boolean = {
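Outside Fabric the construction path is unchanged: the caller still supplies a storage account name, and the account key either comes from the parameters or falls back to the Hadoop Azure config. An illustrative call with placeholder values:

// Placeholder values for illustration only.
val client = new AzureClient(
  Map(
    "storageAccountName" -> "mystorageaccount",
    "accountKey" -> "<account-key>"))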
193 changes: 131 additions & 62 deletions src/main/scala/com/johnsnowlabs/client/azure/AzureGateway.scala
@@ -7,10 +7,14 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.IOUtils

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, InputStream}
import java.io._
import scala.jdk.CollectionConverters.asScalaIteratorConverter

class AzureGateway(storageAccountName: String, accountKey: String) extends CloudStorage {
class AzureGateway(
storageAccountName: String,
accountKey: String,
isFabricLakehouse: Boolean = false)
extends CloudStorage {

private lazy val blobServiceClient: BlobServiceClient = {
val connectionString =
@@ -23,93 +27,158 @@ class AzureGateway(storageAccountName: String, accountKey: String) extends CloudStorage
blobServiceClient
}

private def getHadoopFS(path: String): FileSystem = {
val uri = new java.net.URI(path)
FileSystem.get(uri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
}

override def doesBucketPathExist(bucketName: String, filePath: String): Boolean = {
val blobContainerClient = blobServiceClient
.getBlobContainerClient(bucketName)
if (isFabricLakehouse) {
doesPathExistAbfss(bucketName)
} else {
val blobContainerClient = blobServiceClient
.getBlobContainerClient(bucketName)

val prefix = if (filePath.endsWith("/")) filePath else filePath + "/"
val prefix = if (filePath.endsWith("/")) filePath else filePath + "/"

val blobs = blobContainerClient
.listBlobs()
.iterator()
.asScala
.filter(_.getName.startsWith(prefix))
val blobs = blobContainerClient
.listBlobs()
.iterator()
.asScala
.filter(_.getName.startsWith(prefix))

blobs.nonEmpty
blobs.nonEmpty
}
}

override def copyFileToBucket(
bucketName: String,
destinationPath: String,
inputStream: InputStream): Unit = {

val blockBlobClient = blobServiceClient
.getBlobContainerClient(bucketName)
.getBlobClient(destinationPath)
.getBlockBlobClient

val streamSize = inputStream.available()
blockBlobClient.upload(inputStream, streamSize)
if (isFabricLakehouse) {
copyInputStreamToAbfssUri(bucketName, inputStream)
} else {
val blockBlobClient = blobServiceClient
.getBlobContainerClient(bucketName)
.getBlobClient(destinationPath)
.getBlockBlobClient

val streamSize = inputStream.available()
blockBlobClient.upload(inputStream, streamSize)
}
}

override def copyInputStreamToBucket(
bucketName: String,
filePath: String,
sourceFilePath: String): Unit = {
val fileSystem = FileSystem.get(ResourceHelper.spark.sparkContext.hadoopConfiguration)
val inputStream = fileSystem.open(new Path(sourceFilePath))

val byteArrayOutputStream = new ByteArrayOutputStream()
IOUtils.copyBytes(inputStream, byteArrayOutputStream, 4096, true)

val byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray)

val blockBlobClient = blobServiceClient
.getBlobContainerClient(bucketName)
.getBlobClient(filePath)
.getBlockBlobClient

val streamSize = byteArrayInputStream.available()
blockBlobClient.upload(byteArrayInputStream, streamSize)
if (isFabricLakehouse) {
val abfssPath = s"abfss://$bucketName/$filePath"
val fs = getHadoopFS(abfssPath)
val inputStream = fs.open(new Path(sourceFilePath))
val outputStream = fs.create(new Path(abfssPath), true)
IOUtils.copyBytes(inputStream, outputStream, 4096, true)
} else {
val fileSystem = FileSystem.get(ResourceHelper.spark.sparkContext.hadoopConfiguration)
val inputStream = fileSystem.open(new Path(sourceFilePath))

val byteArrayOutputStream = new ByteArrayOutputStream()
IOUtils.copyBytes(inputStream, byteArrayOutputStream, 4096, true)

val byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray)

val blockBlobClient = blobServiceClient
.getBlobContainerClient(bucketName)
.getBlobClient(filePath)
.getBlockBlobClient

val streamSize = byteArrayInputStream.available()
blockBlobClient.upload(byteArrayInputStream, streamSize)
}
}

override def downloadFilesFromBucketToDirectory(
bucketName: String,
filePath: String,
directoryPath: String,
isIndex: Boolean): Unit = {
try {
val blobContainerClient = blobServiceClient.getBlobContainerClient(bucketName)
val blobOptions = new ListBlobsOptions().setPrefix(filePath)
val blobs = blobContainerClient
.listBlobs(blobOptions, null)
.iterator()
.asScala
.toSeq
if (isFabricLakehouse) {
downloadFilesFromAbfssUri(bucketName, directoryPath)
} else {
try {
val blobContainerClient = blobServiceClient.getBlobContainerClient(bucketName)
val blobOptions = new ListBlobsOptions().setPrefix(filePath)
val blobs = blobContainerClient
.listBlobs(blobOptions, null)
.iterator()
.asScala
.toSeq

if (blobs.isEmpty) {
throw new Exception(
s"Not found blob path $filePath in container $bucketName when downloading files from Azure Blob Storage")
}

if (blobs.isEmpty) {
throw new Exception(
s"Not found blob path $filePath in container $bucketName when downloading files from Azure Blob Storage")
blobs.foreach { blobItem =>
val blobName = blobItem.getName
val blobClient = blobContainerClient.getBlobClient(blobName)

val file = new File(s"$directoryPath/$blobName")
if (blobName.endsWith("/")) {
file.mkdirs()
} else {
file.getParentFile.mkdirs()
val outputStream = new FileOutputStream(file)
blobClient.downloadStream(outputStream)
outputStream.close()
}
}
} catch {
case e: Exception =>
throw new Exception(
"Error when downloading files from Azure Blob Storage: " + e.getMessage)
}
}
}

blobs.foreach { blobItem =>
val blobName = blobItem.getName
val blobClient = blobContainerClient.getBlobClient(blobName)

val file = new File(s"$directoryPath/$blobName")
if (blobName.endsWith("/")) {
file.mkdirs()
} else {
file.getParentFile.mkdirs()
val outputStream = new FileOutputStream(file)
blobClient.downloadStream(outputStream)
outputStream.close()
}
/** Download all files under abfss URI to local directory (Fabric) */
private def downloadFilesFromAbfssUri(uri: String, directory: String): Unit = {
val fs =
FileSystem.get(new Path(uri).toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
val files = fs.globStatus(new Path(uri + "/*"))
if (files == null || files.isEmpty) throw new Exception(s"No files found at $uri")
files.foreach { status =>
val fileName = status.getPath.getName
val localFile = new File(s"$directory/$fileName")
if (status.isDirectory) {
localFile.mkdirs()
} else {
localFile.getParentFile.mkdirs()
val out = new FileOutputStream(localFile)
val in = fs.open(status.getPath)
IOUtils.copyBytes(in, out, 4096, true)
out.close()
in.close()
}
} catch {
case e: Exception =>
throw new Exception(
"Error when downloading files from Azure Blob Storage: " + e.getMessage)
}
}

private def doesPathExistAbfss(uri: String): Boolean = {
val path = new Path(uri)
val fs = FileSystem.get(path.toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
fs.exists(path)
}

private def copyInputStreamToAbfssUri(uri: String, in: InputStream): Unit = {
val path = new Path(uri)
val fs = FileSystem.get(path.toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
val out = fs.create(path, true)
try {
IOUtils.copyBytes(in, out, 4096, true)
} finally {
out.close()
in.close()
}
}

}
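Taken together, the three files let a Microsoft Fabric Lakehouse path act as the pretrained cache folder, with reads and writes going through the Hadoop FileSystem instead of the Blob SDK. A usage sketch, assuming Spark NLP's usual cache-folder setting and a hypothetical Lakehouse path:

import org.apache.spark.sql.SparkSession

// Hypothetical workspace and lakehouse names; the config key is Spark NLP's standard cache-folder setting.
val spark = SparkSession
  .builder()
  .appName("Spark NLP on Fabric")
  .config(
    "spark.jsl.settings.pretrained.cache_folder",
    "abfss://my_ws@onelake.dfs.fabric.microsoft.com/my_lh.Lakehouse/Files/spark_nlp_cache")
  .getOrCreate()
// Pretrained models downloaded after this point are unzipped directly to the ABFSS cache path.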