From d91febd744a21cbce3eb223b5aa26ad0d2d57ff6 Mon Sep 17 00:00:00 2001
From: Danilo Burbano
Date: Tue, 1 Jul 2025 16:26:45 -0500
Subject: [PATCH] [SPARKNLP-1215] Update Microsoft Fabric support to allow
 downloading models into lakehouse containers

Detect OneLake abfss:// URIs (onelake.dfs.fabric.microsoft.com) and route
them through the Hadoop FileSystem API instead of the Azure Blob SDK, so
pretrained models can be downloaded to and cached in Microsoft Fabric
lakehouse storage.
---
 .../johnsnowlabs/client/CloudResources.scala  |  80 ++++++--
 .../client/azure/AzureClient.scala            |  18 +-
 .../client/azure/AzureGateway.scala           | 193 ++++++++++++------
 .../client/util/CloudHelper.scala             |  12 +-
 .../nlp/pretrained/S3ResourceDownloader.scala |  24 ++-
 5 files changed, 230 insertions(+), 97 deletions(-)

diff --git a/src/main/scala/com/johnsnowlabs/client/CloudResources.scala b/src/main/scala/com/johnsnowlabs/client/CloudResources.scala
index 322f459e952e9e..c03022fbb7c484 100644
--- a/src/main/scala/com/johnsnowlabs/client/CloudResources.scala
+++ b/src/main/scala/com/johnsnowlabs/client/CloudResources.scala
@@ -79,9 +79,9 @@ object CloudResources {
         if (!modelExists) {
           val destination =
             unzipInExternalCloudStorage(sourceS3URI, cachePath, azureClient, zippedModel)
-          modelPath = Some(transformURIToWASB(destination))
+          modelPath = buildAzureModelPath(cachePath, destination)
         } else {
-          modelPath = Some(transformURIToWASB(cachePath + "/" + modelName))
+          modelPath = buildAzureModelPathWhenExists(cachePath, modelName)
         }
 
         modelPath
@@ -90,6 +90,24 @@ object CloudResources {
 
   }
 
+  private def buildAzureModelPath(cachePath: String, destination: String): Option[String] = {
+    if (CloudHelper.isFabricAbfss(cachePath)) {
+      Some(destination)
+    } else {
+      Some(transformURIToWASB(destination))
+    }
+  }
+
+  private def buildAzureModelPathWhenExists(
+      cachePath: String,
+      modelName: String): Option[String] = {
+    if (CloudHelper.isFabricAbfss(cachePath)) {
+      Some(cachePath + "/" + modelName)
+    } else {
+      Some(transformURIToWASB(cachePath + "/" + modelName))
+    }
+  }
+
   private def doesModelExistInExternalCloudStorage(
       modelName: String,
       destinationURI: String,
@@ -111,11 +129,18 @@ object CloudResources {
           gcpClient.doesBucketPathExist(destinationBucketName, modelPath)
         }
       case azureClient: AzureClient => {
-        val (destinationBucketName, destinationStoragePath) =
-          CloudHelper.parseAzureBlobURI(destinationURI)
-        val modelPath = destinationStoragePath + "/" + modelName
+        if (CloudHelper.isFabricAbfss(destinationURI)) {
+          val fabricUri =
+            if (destinationURI.endsWith("/")) destinationURI + modelName
+            else destinationURI + "/" + modelName
+          azureClient.doesBucketPathExist(fabricUri, "")
+        } else {
+          val (destinationBucketName, destinationStoragePath) =
+            CloudHelper.parseAzureBlobURI(destinationURI)
+          val modelPath = destinationStoragePath + "/" + modelName
 
-        azureClient.doesBucketPathExist(destinationBucketName, modelPath)
+          azureClient.doesBucketPathExist(destinationBucketName, modelPath)
+        }
       }
     }
 
@@ -133,7 +158,6 @@ object CloudResources {
 
     val zipFile = sourceKey.split("/").last
     val modelName = zipFile.substring(0, zipFile.indexOf(".zip"))
-    println(s"Uploading model $modelName to external Cloud Storage URI: $destinationStorageURI")
     while (zipEntry != null) {
       if (!zipEntry.isDirectory) {
         val outputStream = new ByteArrayOutputStream()
@@ -165,16 +189,23 @@ object CloudResources {
               inputStream)
           }
         case azureClient: AzureClient => {
-          val (destinationBucketName, destinationStoragePath) =
-            CloudHelper.parseAzureBlobURI(destinationStorageURI)
-
-          val destinationAzureStoragePath =
-            s"$destinationStoragePath/$modelName/${zipEntry.getName}".stripPrefix("/")
-
-          azureClient.copyFileToBucket(
-            destinationBucketName,
-            destinationAzureStoragePath,
-            inputStream)
+          if (CloudHelper.isFabricAbfss(destinationStorageURI)) {
+            val fabricUri = (if (destinationStorageURI.endsWith("/")) destinationStorageURI
+                             else destinationStorageURI + "/") +
+              s"$modelName/${zipEntry.getName}"
+            azureClient.copyFileToBucket(fabricUri, "", inputStream)
+          } else {
+            val (destinationBucketName, destinationStoragePath) =
+              CloudHelper.parseAzureBlobURI(destinationStorageURI)
+
+            val destinationAzureStoragePath =
+              s"$destinationStoragePath/$modelName/${zipEntry.getName}".stripPrefix("/")
+
+            azureClient.copyFileToBucket(
+              destinationBucketName,
+              destinationAzureStoragePath,
+              inputStream)
+          }
         }
       }
 
@@ -286,9 +317,18 @@ object CloudResources {
           Paths.get(directory, keyPrefix).toUri
         }
       case azureClient: AzureClient => {
-        val (bucketName, keyPrefix) = CloudHelper.parseAzureBlobURI(bucketURI)
-        azureClient.downloadFilesFromBucketToDirectory(bucketName, keyPrefix, directory, isIndex)
-        Paths.get(directory, keyPrefix).toUri
+        if (CloudHelper.isFabricAbfss(bucketURI)) {
+          azureClient.downloadFilesFromBucketToDirectory(bucketURI, "", directory, isIndex)
+          Paths.get(directory).toUri
+        } else {
+          val (bucketName, keyPrefix) = CloudHelper.parseAzureBlobURI(bucketURI)
+          azureClient.downloadFilesFromBucketToDirectory(
+            bucketName,
+            keyPrefix,
+            directory,
+            isIndex)
+          Paths.get(directory, keyPrefix).toUri
+        }
       }
     }
 
diff --git a/src/main/scala/com/johnsnowlabs/client/azure/AzureClient.scala b/src/main/scala/com/johnsnowlabs/client/azure/AzureClient.scala
index cf05dc68c43e6e..9f66ad9c853160 100644
--- a/src/main/scala/com/johnsnowlabs/client/azure/AzureClient.scala
+++ b/src/main/scala/com/johnsnowlabs/client/azure/AzureClient.scala
@@ -1,5 +1,6 @@
 package com.johnsnowlabs.client.azure
 
+import com.johnsnowlabs.client.util.CloudHelper
 import com.johnsnowlabs.client.{CloudClient, CloudStorage}
 import com.johnsnowlabs.util.ConfigHelper
 
@@ -8,12 +9,17 @@ class AzureClient(parameters: Map[String, String] = Map.empty) extends CloudClie
   private lazy val azureStorageConnection = cloudConnect()
 
   override protected def cloudConnect(): CloudStorage = {
-    val storageAccountName = parameters.getOrElse(
-      "storageAccountName",
-      throw new Exception("Azure client requires storageAccountName"))
-    val accountKey =
-      parameters.getOrElse("accountKey", ConfigHelper.getHadoopAzureConfig(storageAccountName))
-    new AzureGateway(storageAccountName, accountKey)
+    if (CloudHelper.isMicrosoftFabric) {
+      // Storage account name and key are not required for Fabric lakehouse access
+      new AzureGateway("", "", isFabricLakehouse = true)
+    } else {
+      val storageAccountName = parameters.getOrElse(
+        "storageAccountName",
+        throw new Exception("Azure client requires storageAccountName"))
+      val accountKey =
+        parameters.getOrElse("accountKey", ConfigHelper.getHadoopAzureConfig(storageAccountName))
+      new AzureGateway(storageAccountName, accountKey)
+    }
   }
 
   override def doesBucketPathExist(bucketName: String, filePath: String): Boolean = {
diff --git a/src/main/scala/com/johnsnowlabs/client/azure/AzureGateway.scala b/src/main/scala/com/johnsnowlabs/client/azure/AzureGateway.scala
index 527c890655a8b0..21d8245df3044c 100644
--- a/src/main/scala/com/johnsnowlabs/client/azure/AzureGateway.scala
+++ b/src/main/scala/com/johnsnowlabs/client/azure/AzureGateway.scala
@@ -7,10 +7,14 @@ import com.johnsnowlabs.nlp.util.io.ResourceHelper
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.io.IOUtils
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, InputStream}
+import java.io._
 import scala.jdk.CollectionConverters.asScalaIteratorConverter
 
-class AzureGateway(storageAccountName: String, accountKey: String) extends CloudStorage {
+class AzureGateway(
+    storageAccountName: String,
+    accountKey: String,
+    isFabricLakehouse: Boolean = false)
+    extends CloudStorage {
 
   private lazy val blobServiceClient: BlobServiceClient = {
     val connectionString =
@@ -23,54 +27,74 @@ class AzureGateway(storageAccountName: String, accountKey: String) extends Cloud
     blobServiceClient
   }
 
+  private def getHadoopFS(path: String): FileSystem = {
+    val uri = new java.net.URI(path)
+    FileSystem.get(uri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
+  }
+
   override def doesBucketPathExist(bucketName: String, filePath: String): Boolean = {
-    val blobContainerClient = blobServiceClient
-      .getBlobContainerClient(bucketName)
+    if (isFabricLakehouse) {
+      doesPathExistAbfss(bucketName)
+    } else {
+      val blobContainerClient = blobServiceClient
+        .getBlobContainerClient(bucketName)
 
-    val prefix = if (filePath.endsWith("/")) filePath else filePath + "/"
+      val prefix = if (filePath.endsWith("/")) filePath else filePath + "/"
 
-    val blobs = blobContainerClient
-      .listBlobs()
-      .iterator()
-      .asScala
-      .filter(_.getName.startsWith(prefix))
+      val blobs = blobContainerClient
+        .listBlobs()
+        .iterator()
+        .asScala
+        .filter(_.getName.startsWith(prefix))
 
-    blobs.nonEmpty
+      blobs.nonEmpty
+    }
   }
 
   override def copyFileToBucket(
       bucketName: String,
       destinationPath: String,
       inputStream: InputStream): Unit = {
-
-    val blockBlobClient = blobServiceClient
-      .getBlobContainerClient(bucketName)
-      .getBlobClient(destinationPath)
-      .getBlockBlobClient
-
-    val streamSize = inputStream.available()
-    blockBlobClient.upload(inputStream, streamSize)
+    if (isFabricLakehouse) {
+      copyInputStreamToAbfssUri(bucketName, inputStream)
+    } else {
+      val blockBlobClient = blobServiceClient
+        .getBlobContainerClient(bucketName)
+        .getBlobClient(destinationPath)
+        .getBlockBlobClient
+
+      val streamSize = inputStream.available()
+      blockBlobClient.upload(inputStream, streamSize)
+    }
   }
 
   override def copyInputStreamToBucket(
      bucketName: String,
      filePath: String,
      sourceFilePath: String): Unit = {
-    val fileSystem = FileSystem.get(ResourceHelper.spark.sparkContext.hadoopConfiguration)
-    val inputStream = fileSystem.open(new Path(sourceFilePath))
-
-    val byteArrayOutputStream = new ByteArrayOutputStream()
-    IOUtils.copyBytes(inputStream, byteArrayOutputStream, 4096, true)
-
-    val byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray)
-
-    val blockBlobClient = blobServiceClient
-      .getBlobContainerClient(bucketName)
-      .getBlobClient(filePath)
-      .getBlockBlobClient
-
-    val streamSize = byteArrayInputStream.available()
-    blockBlobClient.upload(byteArrayInputStream, streamSize)
+    if (isFabricLakehouse) {
+      val abfssPath = s"abfss://$bucketName/$filePath"
+      val fs = getHadoopFS(abfssPath)
+      val inputStream = fs.open(new Path(sourceFilePath))
+      val outputStream = fs.create(new Path(abfssPath), true)
+      IOUtils.copyBytes(inputStream, outputStream, 4096, true)
+    } else {
+      val fileSystem = FileSystem.get(ResourceHelper.spark.sparkContext.hadoopConfiguration)
+      val inputStream = fileSystem.open(new Path(sourceFilePath))
+
+      val byteArrayOutputStream = new ByteArrayOutputStream()
+      IOUtils.copyBytes(inputStream, byteArrayOutputStream, 4096, true)
+
+      val byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray)
+
+      val blockBlobClient = blobServiceClient
+        .getBlobContainerClient(bucketName)
+        .getBlobClient(filePath)
+        .getBlockBlobClient
+
+      val streamSize = byteArrayInputStream.available()
+      blockBlobClient.upload(byteArrayInputStream, streamSize)
+    }
   }
 
   override def downloadFilesFromBucketToDirectory(
@@ -78,38 +102,83 @@ class AzureGateway(storageAccountName: String, accountKey: String) extends Cloud
       bucketName: String,
       filePath: String,
       directoryPath: String,
       isIndex: Boolean): Unit = {
-    try {
-      val blobContainerClient = blobServiceClient.getBlobContainerClient(bucketName)
-      val blobOptions = new ListBlobsOptions().setPrefix(filePath)
-      val blobs = blobContainerClient
-        .listBlobs(blobOptions, null)
-        .iterator()
-        .asScala
-        .toSeq
+    if (isFabricLakehouse) {
+      downloadFilesFromAbfssUri(bucketName, directoryPath)
+    } else {
+      try {
+        val blobContainerClient = blobServiceClient.getBlobContainerClient(bucketName)
+        val blobOptions = new ListBlobsOptions().setPrefix(filePath)
+        val blobs = blobContainerClient
+          .listBlobs(blobOptions, null)
+          .iterator()
+          .asScala
+          .toSeq
+
+        if (blobs.isEmpty) {
+          throw new Exception(
+            s"Blob path $filePath not found in container $bucketName when downloading files from Azure Blob Storage")
+        }
 
-      if (blobs.isEmpty) {
-        throw new Exception(
-          s"Not found blob path $filePath in container $bucketName when downloading files from Azure Blob Storage")
+        blobs.foreach { blobItem =>
+          val blobName = blobItem.getName
+          val blobClient = blobContainerClient.getBlobClient(blobName)
+
+          val file = new File(s"$directoryPath/$blobName")
+          if (blobName.endsWith("/")) {
+            file.mkdirs()
+          } else {
+            file.getParentFile.mkdirs()
+            val outputStream = new FileOutputStream(file)
+            blobClient.downloadStream(outputStream)
+            outputStream.close()
+          }
+        }
+      } catch {
+        case e: Exception =>
+          throw new Exception(
+            "Error when downloading files from Azure Blob Storage: " + e.getMessage)
       }
+    }
+  }
 
-      blobs.foreach { blobItem =>
-        val blobName = blobItem.getName
-        val blobClient = blobContainerClient.getBlobClient(blobName)
-
-        val file = new File(s"$directoryPath/$blobName")
-        if (blobName.endsWith("/")) {
-          file.mkdirs()
-        } else {
-          file.getParentFile.mkdirs()
-          val outputStream = new FileOutputStream(file)
-          blobClient.downloadStream(outputStream)
-          outputStream.close()
-        }
+  /** Downloads all files under an abfss URI to a local directory (Fabric) */
+  private def downloadFilesFromAbfssUri(uri: String, directory: String): Unit = {
+    val fs =
+      FileSystem.get(new Path(uri).toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
+    val files = fs.globStatus(new Path(uri.stripSuffix("/") + "/*"))
+    if (files == null || files.isEmpty) throw new Exception(s"No files found at $uri")
+    files.foreach { status =>
+      val fileName = status.getPath.getName
+      val localFile = new File(s"$directory/$fileName")
+      if (status.isDirectory) {
+        localFile.mkdirs()
+        // Recurse so nested model folders (e.g. metadata/) are fully downloaded
+        downloadFilesFromAbfssUri(status.getPath.toString, localFile.getPath)
+      } else {
+        localFile.getParentFile.mkdirs()
+        val out = new FileOutputStream(localFile)
+        val in = fs.open(status.getPath)
+        // copyBytes with close = true closes both streams when the copy finishes
+        IOUtils.copyBytes(in, out, 4096, true)
       }
-    } catch {
-      case e: Exception =>
-        throw new Exception(
-          "Error when downloading files from Azure Blob Storage: " + e.getMessage)
     }
   }
+
+  private def doesPathExistAbfss(uri: String): Boolean = {
+    val path = new Path(uri)
+    val fs = FileSystem.get(path.toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
+    fs.exists(path)
+  }
+
+  private def copyInputStreamToAbfssUri(uri: String, in: InputStream): Unit = {
+    val path = new Path(uri)
+    val fs = FileSystem.get(path.toUri, ResourceHelper.spark.sparkContext.hadoopConfiguration)
+    val out = fs.create(path, true)
+    try {
+      // close the streams in the finally block, so pass close = false here
+      IOUtils.copyBytes(in, out, 4096, false)
+    } finally {
+      out.close()
+      in.close()
+    }
+  }
+
 }
diff --git a/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala b/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala
index 8ab06e13e08c2e..35faeac2719725 100644
--- a/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala
+++ b/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala
@@ -71,8 +71,7 @@ object CloudHelper {
   }
 
   def isCloudPath(uri: String): Boolean = {
-    val intraCloudPath = isIntraCloudPath(uri)
-    (isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)) && !intraCloudPath
+    isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)
   }
 
   def isS3Path(uri: String): Boolean = {
@@ -83,17 +82,16 @@ object CloudHelper {
 
   private def isAzureBlobPath(uri: String): Boolean = {
     (uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")) || uri.startsWith(
-      "abfss://")
-  }
-
-  private def isIntraCloudPath(uri: String): Boolean = {
-    uri.startsWith("abfss://") && isMicrosoftFabric
+      "abfss://")
   }
 
   def isMicrosoftFabric: Boolean = {
     ResourceHelper.spark.conf.getAll.keys.exists(_.startsWith("spark.fabric"))
   }
 
+  def isFabricAbfss(uri: String): Boolean =
+    uri.startsWith("abfss://") && uri.contains("onelake.dfs.fabric.microsoft.com")
+
   def cloudType(uri: String): CloudStorageType = {
     if (isS3Path(uri)) {
       CloudStorageType.S3
diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala
index ef338e1811f317..4255a3203fe597 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala
@@ -21,6 +21,7 @@ import com.johnsnowlabs.client.aws.AWSGateway
 import com.johnsnowlabs.client.util.CloudHelper
 import com.johnsnowlabs.util.FileHelper
 import org.apache.hadoop.fs.Path
+import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.File
 import java.nio.file.Files
@@ -35,11 +36,25 @@ class S3ResourceDownloader(
     region: String = "us-east-1")
     extends ResourceDownloader {
 
+  private val logger: Logger = LoggerFactory.getLogger(this.getClass)
+
   private val repoFolder2Metadata: mutable.Map[String, RepositoryMetadata] =
     mutable.Map[String, RepositoryMetadata]()
   val cachePath = new Path(cacheFolder)
 
-  if (!CloudHelper.isCloudPath(cacheFolder) && !fileSystem.exists(cachePath)) {
+  private val isNotCloudPath = !CloudHelper.isCloudPath(cacheFolder)
+
+  private lazy val doesNotExistCachePath = {
+    try {
+      !fileSystem.exists(cachePath)
+    } catch {
+      case e: Exception =>
+        logger.error(s"Error checking cache path existence: ${e.getMessage}")
+        false
+    }
+  }
+
+  if (isNotCloudPath && doesNotExistCachePath) {
     fileSystem.mkdirs(cachePath)
   }
 
@@ -83,16 +98,20 @@ class S3ResourceDownloader(
     val link = resolveLink(request)
     link.flatMap { resource =>
       val s3FilePath = awsGateway.getS3File(s3Path, request.folder, resource.fileName)
-
+      logger.info(s"Downloading pretrained resource from S3 path: $s3FilePath")
       if (!awsGateway.doesS3ObjectExist(bucket, s3FilePath)) {
+        logger.info(s"Resource $s3FilePath not found in S3 bucket $bucket")
         None
       } else {
+        logger.info(s"Resource $s3FilePath found in S3 bucket $bucket")
         val sourceS3URI = s"s3a://$bucket/$s3FilePath"
         val zipFile = sourceS3URI.split("/").last
         val modelName = zipFile.substring(0, zipFile.indexOf(".zip"))
+        logger.info(s"Using cache path: $cachePath")
 
         cachePath.toString match {
           case path if CloudHelper.isCloudPath(path) => {
+            logger.info(s"Cache path is a cloud path, downloading model to cloud storage: $path")
             CloudResources.downloadModelFromCloud(
               awsGateway,
               cachePath.toString,
@@ -100,6 +119,7 @@ class S3ResourceDownloader(
               sourceS3URI)
           }
           case _ => {
+            logger.info(s"Cache path is local, downloading and unzipping model to $cachePath")
             val destinationFile = new Path(cachePath.toString, resource.fileName)
             downloadAndUnzipFile(destinationFile, resource, s3FilePath)
           }
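
Usage sketch (illustrative, not part of the patch): with this change, pointing
the pretrained cache at a OneLake URI should route model downloads through the
new Fabric code path, since CloudHelper.isFabricAbfss matches abfss:// URIs on
the onelake.dfs.fabric.microsoft.com endpoint. The workspace and lakehouse
names below are hypothetical, and the cache location is assumed to be set via
Spark NLP's documented spark.jsl.settings.pretrained.cache_folder property:

    import com.johnsnowlabs.nlp.pretrained.PretrainedPipeline
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession
      .builder()
      .appName("SparkNLPFabricCache")
      // Hypothetical OneLake path; it must satisfy isFabricAbfss to take the Fabric branch
      .config(
        "spark.jsl.settings.pretrained.cache_folder",
        "abfss://MyWorkspace@onelake.dfs.fabric.microsoft.com/MyLakehouse.Lakehouse/Files/spark_nlp_cache")
      .getOrCreate()

    // The first call downloads and unzips the model into the lakehouse folder;
    // subsequent calls resolve the cached abfss path and skip the download.
    val pipeline = PretrainedPipeline("explain_document_ml", lang = "en")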