
How to use the PJRT C API? #7038

@hayden-donnelly


I'm trying to use the PJRT C API to execute an HLO module, but I don't understand the instructions in the integration guide. It says I can either implement the C API directly, or implement the C++ API and use it to wrap the C API. The former seems less complicated, so that's what I'm trying. So far I've built the pjrt_c_api_cpu library with `bazel build --strip=never //xla/pjrt/c:pjrt_c_api_cpu` and linked it with the following CMake file:

```cmake
cmake_minimum_required(VERSION 3.10)
project(execute_hlo CXX)

# Set the language standard and build type before defining targets.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_BUILD_TYPE Release)

# Paths are relative to this CMakeLists.txt; the XLA checkout lives in ./xla.
# Note: a Bazel cc_library archive (.a) contains only that target's own
# objects, not its transitive dependencies.
set(PJRT_C_API_LIB xla/bazel-bin/xla/pjrt/c/libpjrt_c_api_cpu.a)
set(XLA_INCLUDE xla)

add_executable(execute_hlo execute_hlo.cpp)
target_link_libraries(execute_hlo PRIVATE "${CMAKE_SOURCE_DIR}/${PJRT_C_API_LIB}")
target_include_directories(execute_hlo PRIVATE "${CMAKE_SOURCE_DIR}/${XLA_INCLUDE}")

# Place the executable in <source dir>/build/ (no POST_BUILD copy needed).
set_target_properties(execute_hlo PROPERTIES
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/build")
```

After that I created execute_hlo.cpp to load my HLO module into a PJRT_Program:

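```cpp
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include "xla/pjrt/c/pjrt_c_api.h"

int main(int argc, char** argv)
{
    // pjrt_c_api.h documents the "hlo" format as a *serialized* (binary)
    // xla::HloModuleProto, so read the file in binary mode.
    std::ifstream t("hlo_comp_proto.txt", std::ios::binary);
    if (!t) {
        std::cerr << "Could not open hlo_comp_proto.txt\n";
        return 1;
    }
    std::stringstream buffer;
    buffer << t.rdbuf();
    std::string hlo_code = buffer.str();
    const std::string format = "hlo";

    // Value-initialize so every field not set below is zero/null.
    PJRT_Program pjrt_program{};
    // PJRT C API structs carry their own size for ABI versioning.
    pjrt_program.struct_size = PJRT_Program_STRUCT_SIZE;
    // `code` is a non-const char*; std::string::data() is non-const in C++17.
    pjrt_program.code = hlo_code.data();
    pjrt_program.code_size = hlo_code.size();
    pjrt_program.format = format.c_str();
    pjrt_program.format_size = format.size();

    // A binary proto may contain NUL bytes, so don't stream `code` as a
    // C string; print only the metadata.
    std::cout << "Code size: " << pjrt_program.code_size << "\n";
    std::cout << "Format: " << pjrt_program.format << "\n";
    std::cout << "Format size: " << pjrt_program.format_size << "\n";
    return 0;
}
```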

This works, but I think I've hit a dead end: I don't see any functions that take this struct. After looking around, it seems I'm supposed to call some implementation of `GetPjrtApi()` that returns a `PJRT_Api` pointer, but I only see C++ implementations. Does that mean the C -> C++ wrapper is the simpler approach? If so, which targets do I need to build to use it? If not, what am I doing wrong with my current approach?
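For context on the `GetPjrtApi()` question: as far as I can tell from `pjrt_c_api.h`, the `PJRT_Api` it returns is a plain C struct whose members are function pointers (for example `PJRT_Client_Create`, `PJRT_Client_Compile`, `PJRT_LoadedExecutable_Execute`). Each takes a single `*_Args` struct and returns a `PJRT_Error*` that is null on success. A plugin may implement those functions in C++, but the caller only ever sees the C function table. The self-contained sketch below imitates that pattern with hypothetical `Demo_*` types and a hypothetical `GetDemoApi()`; they are stand-ins, not the real PJRT declarations, and it compiles without XLA.

```cpp
#include <cstddef>

// Hypothetical stand-ins mimicking the *shape* of pjrt_c_api.h.
// In a real C API header these would be plain C declarations.
typedef struct Demo_Error Demo_Error;    // opaque error handle
typedef struct Demo_Client Demo_Client;  // opaque client handle

// Every call takes one args struct; struct_size lets the callee tell
// which version of the struct the caller was compiled against.
typedef struct {
    std::size_t struct_size;
    Demo_Client* client;  // out
} Demo_Client_Create_Args;

typedef struct {
    std::size_t struct_size;
    Demo_Client* client;  // in
} Demo_Client_Destroy_Args;

typedef Demo_Error* Demo_Client_Create_Fn(Demo_Client_Create_Args* args);
typedef Demo_Error* Demo_Client_Destroy_Fn(Demo_Client_Destroy_Args* args);

// The API object is just a table of function pointers.
typedef struct {
    std::size_t struct_size;
    Demo_Client_Create_Fn* client_create;
    Demo_Client_Destroy_Fn* client_destroy;
} Demo_Api;

// ---- "Plugin" side: free to use C++ internally. ----
struct Demo_Client {
    int device_count;
};

static Demo_Error* ClientCreateImpl(Demo_Client_Create_Args* args) {
    args->client = new Demo_Client{1};  // pretend one device was found
    return nullptr;                     // nullptr means success
}

static Demo_Error* ClientDestroyImpl(Demo_Client_Destroy_Args* args) {
    delete args->client;
    return nullptr;
}

// Exported with C linkage, playing the role of GetPjrtApi().
extern "C" const Demo_Api* GetDemoApi() {
    static const Demo_Api api = {sizeof(Demo_Api), ClientCreateImpl,
                                 ClientDestroyImpl};
    return &api;
}

// ---- "Caller" side: only touches the function table. ----
// Returns the fake client's device count, or -1 on error.
int RunDemo() {
    const Demo_Api* api = GetDemoApi();

    Demo_Client_Create_Args create_args{};  // zero every field first
    create_args.struct_size = sizeof(create_args);
    if (api->client_create(&create_args) != nullptr) return -1;

    int devices = create_args.client->device_count;

    Demo_Client_Destroy_Args destroy_args{};
    destroy_args.struct_size = sizeof(destroy_args);
    destroy_args.client = create_args.client;
    if (api->client_destroy(&destroy_args) != nullptr) return -1;
    return devices;
}
```

With the real header the flow should look similar: `const PJRT_Api* api = GetPjrtApi();`, fill a zeroed `PJRT_Client_Create_Args` with `struct_size = PJRT_Client_Create_Args_STRUCT_SIZE`, call `api->PJRT_Client_Create(&args)`, then hand the `PJRT_Program` to `api->PJRT_Client_Compile`. Field names have changed between PJRT C API versions, so check them against the header you actually built.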
