Description
I'm trying to use the PJRT C API to execute an HLO module, but I don't understand the instructions. The integration guide says I can either implement the C API directly, or implement the C++ API and use the provided wrapper to expose it through the C API. The former seemed less complicated, so that's what I'm trying to do. So far I've built the `pjrt_c_api_cpu` library with `bazel build --strip=never //xla/pjrt/c:pjrt_c_api_cpu` and linked it with the following CMake file:
```cmake
cmake_minimum_required(VERSION 3.10)
project(execute_hlo CXX)

# Equivalent to passing -std=c++17; must be set before add_executable.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_BUILD_TYPE Release)

set(PJRT_C_API_LIB xla/bazel-bin/xla/pjrt/c/libpjrt_c_api_cpu.a)
set(XLA_INCLUDE xla)

add_executable(execute_hlo execute_hlo.cpp)
target_link_libraries(execute_hlo "${CMAKE_SOURCE_DIR}/${PJRT_C_API_LIB}")
target_include_directories(execute_hlo PRIVATE "${CMAKE_SOURCE_DIR}/${XLA_INCLUDE}")
set_target_properties(execute_hlo PROPERTIES RUNTIME_OUTPUT_DIRECTORY /build/)
add_custom_command(TARGET execute_hlo POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:execute_hlo> ${CMAKE_SOURCE_DIR}/build/)
```

After that I created `execute_hlo.cpp` to load my HLO module into a `PJRT_Program`:
```cpp
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include "xla/pjrt/c/pjrt_c_api.h"

int main(int argc, char** argv)
{
    // The "hlo" format expects a serialized (binary) HloModuleProto,
    // so read the file as raw bytes rather than as text.
    std::ifstream t("hlo_comp_proto.txt", std::ios::binary);
    std::stringstream buffer;
    buffer << t.rdbuf();
    std::string hlo_code = buffer.str();
    std::string format = "hlo";

    // Zero-initialize so fields not set below (e.g. the extension pointer)
    // are null. PJRT C API structs carry a struct_size field that the
    // implementation checks, so it must be filled in.
    PJRT_Program pjrt_program{};
    pjrt_program.struct_size = PJRT_Program_STRUCT_SIZE;
    pjrt_program.code = hlo_code.data();  // non-const char* since C++17
    pjrt_program.code_size = hlo_code.size();
    pjrt_program.format = format.c_str();
    pjrt_program.format_size = format.size();

    // write() prints all code_size bytes, even if the proto contains '\0'.
    std::cout << "HLO Code:\n\n";
    std::cout.write(pjrt_program.code, pjrt_program.code_size);
    std::cout << "\n\n";
    std::cout << "Code size: " << pjrt_program.code_size << "\n";
    std::cout << "Format: " << pjrt_program.format << "\n";
    std::cout << "Format size: " << pjrt_program.format_size << "\n";
    return 0;
}
```

This works, but I think I've hit a dead end: I don't see any functions that will let me do anything with this struct. After looking around a bit, it seems I'm supposed to call some implementation of `GetPjrtApi()` that returns a `PJRT_Api` pointer, but I only see C++ implementations. Does that mean the C -> C++ wrapper is the simpler approach? If so, which targets would I have to build to use it? If not, what am I doing wrong with my current approach?