Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}

MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4];
NS::Error* error[5];
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
Expand All @@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
return lib;
}

// Try lo load resources from Framework resources if SwiftPM wrapped as a
// dynamic framework.
std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
if (lib) {
return lib;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_colocated_library will get the directory of the binary that contains the get_binary_directory(). In the case of an embedded .framework it would be something like x.app/Contents/Frameworks/MLX.framework/Versions/A. This would then append Resources/default and .metallib onto that and load the file.

I think this works if the default.metallib is hoisted out of the mlx-swift_Cmlx.bundle/Contents/Resources. I am not sure it is because I don't have a .framework wrapping the MLX swiftpm library myself, but such a thing could surely be done.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, I am surprised that this does not do the same thing:

std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
    MTL::Device* device,
    const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
...
  auto bundles = NS::Bundle::allBundles();
  for (int i = 0, c = (int)bundles->count(); i < c; i++) {
    auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
    library = try_load_bundle(device, bundle->resourceURL(), lib_name);

That should scan all bundles (frameworks are considered bundles) looking for a default.metallib.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't work this way, because bundles doesn't contain MLX framework path. In current configuration, it always looks for for main bundle and tries to load default.metallib from mlx-swift_Cmlx.bundle. But with dynamic frameworks, app directory looks next way:

  • App.app
    • Contents
      • Frameworks
        • MLX
          • Resources
            • default.metallib
      • Resources
NSBundle <~Work/DerivedData/Eney-dyyiohebtzyjuwbwalhqtrcvkvcn/Build/Products/Debug/App.app> (loaded),
NSBundle </System/Library/Extensions/AGXMetalG13X.bundle> (loaded),
NSBundle </System/Library/Input Methods/PressAndHold.app/Contents/PlugIns/PAH_Extension.appex> (not yet loaded),
NSBundle </System/Library/CoreServices/SystemAppearance.bundle> (not yet loaded)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkoski I've covered a questions related to load_swiftpm_library above

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, if the default.metallib is in the Resources directory, I agree this looks good.


// Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) {
for (int i = 0; i < 5; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}
Expand Down