Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
548 changes: 236 additions & 312 deletions include/af/oneapi.h

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/backend/oneapi/device_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ DeviceManager::DeviceManager()
make_unique<sycl::queue>(*mContexts.back(), *devices[i],
arrayfire_exception_handler));
mIsGLSharingOn.push_back(false);
// TODO:
// mDeviceTypes.push_back(getDeviceTypeEnum(*devices[i]));
// mPlatforms.push_back(getPlatformEnum(*devices[i]));

mDeviceTypes.push_back(
devices[i]->get_info<sycl::info::device::device_type>());
mPlatforms.push_back(devices[i]->get_platform());
mDevices.emplace_back(std::move(devices[i]));

std::string options;
Expand Down
15 changes: 7 additions & 8 deletions src/backend/oneapi/device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DeviceManager {

friend int getDeviceCount() noexcept;

// friend int getDeviceIdFromNativeId(cl_device_id id);
friend int getDeviceIdFromNativeDevice(sycl::device dev);

friend const sycl::context& getContext();

Expand All @@ -101,16 +101,15 @@ class DeviceManager {

friend int setDevice(int device);

friend void addDeviceContext(sycl::device& dev, sycl::context& ctx,
sycl::queue& que);
friend void addDeviceContext(sycl::queue& que);

friend void setDeviceContext(sycl::device& dev, sycl::context& ctx);

friend void removeDeviceContext(sycl::device& dev, sycl::context& ctx);
friend void removeDevice(sycl::device& dev);

friend int getActiveDeviceType();
friend sycl::info::device_type getActiveDeviceType();

friend int getActivePlatform();
friend sycl::platform getActivePlatform();

public:
static const int MAX_DEVICES = 32;
Expand Down Expand Up @@ -141,8 +140,8 @@ class DeviceManager {
std::vector<std::unique_ptr<sycl::queue>> mQueues;
std::vector<bool> mIsGLSharingOn;
std::vector<std::string> mBaseOpenCLBuildFlags;
std::vector<int> mDeviceTypes;
std::vector<int> mPlatforms;
std::vector<sycl::info::device_type> mDeviceTypes;
std::vector<sycl::platform> mPlatforms;
unsigned mUserDeviceOffset;

std::unique_ptr<arrayfire::common::ForgeManager> fgMngr;
Expand Down
4 changes: 2 additions & 2 deletions src/backend/oneapi/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::string genPlanHashStr(int rank, ::oneapi::mkl::dft::precision precision,

if (istrides != nullptr) {
for (int r = 0; r < rank + 1; ++r) {
sprintf(key_str_temp, "%ld:", istrides[r]);
sprintf(key_str_temp, "%lld:", istrides[r]);
key_string.append(std::string(key_str_temp));
}
sprintf(key_str_temp, "%d:", ibatch);
Expand All @@ -76,7 +76,7 @@ std::string genPlanHashStr(int rank, ::oneapi::mkl::dft::precision precision,

if (ostrides != nullptr) {
for (int r = 0; r < rank + 1; ++r) {
sprintf(key_str_temp, "%ld:", ostrides[r]);
sprintf(key_str_temp, "%lld:", ostrides[r]);
key_string.append(std::string(key_str_temp));
}
sprintf(key_str_temp, "%d:", obatch);
Expand Down
1 change: 1 addition & 0 deletions src/backend/oneapi/kernel/sort_by_key_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// temporary ignores for DPL internals
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#endif

// oneDPL headers should be included before standard headers
Expand Down
113 changes: 41 additions & 72 deletions src/backend/oneapi/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,6 @@ inline string platformMap(string& platStr) {
}
}

af_oneapi_platform getPlatformEnum(sycl::device dev) {
string pname = getPlatformName(dev);
if (verify_present(pname, "AMD"))
return AF_ONEAPI_PLATFORM_AMD;
else if (verify_present(pname, "NVIDIA"))
return AF_ONEAPI_PLATFORM_NVIDIA;
else if (verify_present(pname, "INTEL"))
return AF_ONEAPI_PLATFORM_INTEL;
else if (verify_present(pname, "APPLE"))
return AF_ONEAPI_PLATFORM_APPLE;
else if (verify_present(pname, "BEIGNET"))
return AF_ONEAPI_PLATFORM_BEIGNET;
else if (verify_present(pname, "POCL"))
return AF_ONEAPI_PLATFORM_POCL;
return AF_ONEAPI_PLATFORM_UNKNOWN;
}

string getDeviceInfo() noexcept {
ostringstream info;
info << "ArrayFire v" << AF_VERSION << " (oneAPI, " << get_system()
Expand All @@ -150,8 +133,6 @@ string getDeviceInfo() noexcept {
common::lock_guard_t lock(devMngr.deviceMutex);
unsigned nDevices = 0;
for (auto& device : devMngr.mDevices) {
// const Platform platform(device->getInfo<CL_DEVICE_PLATFORM>());

string dstr = device->get_info<sycl::info::device::name>();
bool show_braces =
(static_cast<unsigned>(getActiveDeviceId()) == nDevices);
Expand Down Expand Up @@ -233,24 +214,18 @@ unsigned getActiveDeviceId() {
return get<1>(tlocalActiveDeviceId());
}

/*
int getDeviceIdFromNativeId(cl_device_id id) {
DeviceManager& devMngr = DeviceManager::getInstance();
int getDeviceIdFromNativeDevice(sycl::device dev) {
DeviceManager& mngr = DeviceManager::getInstance();

common::lock_guard_t lock(devMngr.deviceMutex);
common::lock_guard_t lock(mngr.deviceMutex);

int nDevices = static_cast<int>(devMngr.mDevices.size());
int devId = 0;
for (devId = 0; devId < nDevices; ++devId) {
//TODO: how to get cl_device_id from sycl::device
if (id == devMngr.mDevices[devId]->get()) { return devId; }
for (int devId = 0; devId < mngr.mDevices.size(); ++devId) {
if (dev == *mngr.mDevices[devId]) { return devId; }
}
// TODO: reasonable if no match??
return -1;
}
*/

int getActiveDeviceType() {
sycl::info::device_type getActiveDeviceType() {
device_id_t& devId = tlocalActiveDeviceId();

DeviceManager& devMngr = DeviceManager::getInstance();
Expand All @@ -260,7 +235,7 @@ int getActiveDeviceType() {
return devMngr.mDeviceTypes[get<1>(devId)];
}

int getActivePlatform() {
sycl::platform getActivePlatform() {
device_id_t& devId = tlocalActiveDeviceId();

DeviceManager& devMngr = DeviceManager::getInstance();
Expand Down Expand Up @@ -415,23 +390,25 @@ void sync(int device) {
setDevice(currDevice);
}

void addDeviceContext(sycl::device& dev, sycl::context& ctx, sycl::queue& que) {
void addDeviceContext(sycl::queue& que) {
DeviceManager& devMngr = DeviceManager::getInstance();

int nDevices = 0;
{
sycl::context ctx = que.get_context();
sycl::device dev = que.get_device();
common::lock_guard_t lock(devMngr.deviceMutex);

auto tDevice = make_unique<sycl::device>(dev);
auto tContext = make_unique<sycl::context>(ctx);
// queue atleast has implicit context and device if created
auto tQueue = make_unique<sycl::queue>(que);

devMngr.mPlatforms.push_back(getPlatformEnum(*tDevice));
devMngr.mPlatforms.push_back(tDevice->get_platform());
// FIXME: add OpenGL Interop for user provided contexts later
devMngr.mIsGLSharingOn.push_back(false);
devMngr.mDeviceTypes.push_back(static_cast<int>(
tDevice->get_info<sycl::info::device::device_type>()));
devMngr.mDeviceTypes.push_back(
tDevice->get_info<sycl::info::device::device_type>());

devMngr.mDevices.push_back(move(tDevice));
devMngr.mContexts.push_back(move(tContext));
Expand Down Expand Up @@ -461,8 +438,8 @@ void setDeviceContext(sycl::device& dev, sycl::context& ctx) {
AF_ERROR("No matching device found", AF_ERR_ARG);
}

void removeDeviceContext(sycl::device& dev, sycl::context& ctx) {
if (getDevice() == dev && getContext() == ctx) {
void removeDevice(sycl::device& dev) {
if (getDevice() == dev) {
AF_ERROR("Cannot pop the device currently in use", AF_ERR_ARG);
}

Expand All @@ -474,7 +451,7 @@ void removeDeviceContext(sycl::device& dev, sycl::context& ctx) {

const int dCount = static_cast<int>(devMngr.mDevices.size());
for (int i = 0; i < dCount; ++i) {
if (*devMngr.mDevices[i] == dev && *devMngr.mContexts[i] == ctx) {
if (*devMngr.mDevices[i] == dev) {
deleteIdx = i;
break;
}
Expand All @@ -501,6 +478,7 @@ void removeDeviceContext(sycl::device& dev, sycl::context& ctx) {
devMngr.mContexts.erase(devMngr.mContexts.begin() + deleteIdx);
devMngr.mQueues.erase(devMngr.mQueues.begin() + deleteIdx);
devMngr.mPlatforms.erase(devMngr.mPlatforms.begin() + deleteIdx);
devMngr.mDeviceTypes.erase(devMngr.mDeviceTypes.begin() + deleteIdx);

// FIXME: add OpenGL Interop for user provided contexts later
devMngr.mIsGLSharingOn.erase(devMngr.mIsGLSharingOn.begin() +
Expand Down Expand Up @@ -648,84 +626,75 @@ PlanCache& fftManager() { return *oneFFTManager(getActiveDeviceId()); }
} // namespace oneapi
} // namespace arrayfire

/*
//TODO: select which external api functions to expose and add to
header+implement

using namespace oneapi;

af_err afcl_get_device_type(afcl_device_type* res) {
try {
*res = static_cast<afcl_device_type>(getActiveDeviceType());
}
CATCHALL;
return AF_SUCCESS;
}
using namespace arrayfire::oneapi;

af_err afcl_get_platform(afcl_platform* res) {
af_err afoneapi_get_device_type(af_sycl_device_type res) {
try {
*res = static_cast<afcl_platform>(getActivePlatform());
*static_cast<sycl::info::device_type*>(res) = getActiveDeviceType();
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_get_context(cl_context* ctx, const bool retain) {
af_err afoneapi_get_platform(af_sycl_platform res) {
try {
*ctx = getContext()();
if (retain) { clRetainContext(*ctx); }
*static_cast<sycl::platform*>(res) = getActivePlatform();
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_get_queue(cl_command_queue* queue, const bool retain) {
af_err afoneapi_get_context(af_sycl_context ctx) {
try {
*queue = getQueue()();
if (retain) { clRetainCommandQueue(*queue); }
*static_cast<sycl::context*>(ctx) = getContext();
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_get_device_id(cl_device_id* id) {
af_err afoneapi_get_queue(af_sycl_queue queue) {
try {
*id = getDevice()();
*static_cast<sycl::queue*>(queue) = getQueue();
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_set_device_id(cl_device_id id) {
af_err afoneapi_get_device(af_sycl_device dev) {
try {
setDevice(getDeviceIdFromNativeId(id));
*static_cast<sycl::device*>(dev) = getDevice();
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_add_device_context(cl_device_id dev, cl_context ctx,
cl_command_queue que) {
af_err afoneapi_set_device(af_sycl_device dev) {
try {
addDeviceContext(dev, ctx, que);
int devId =
getDeviceIdFromNativeDevice(*static_cast<sycl::device*>(dev));
if (devId != -1) {
setDevice(devId);
} else {
sycl::queue que(*static_cast<sycl::device*>(dev));
addDeviceContext(que);
}
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_set_device_context(cl_device_id dev, cl_context ctx) {
af_err afoneapi_add_queue(af_sycl_queue que) {
try {
setDeviceContext(dev, ctx);
addDeviceContext(*static_cast<sycl::queue*>(que));
}
CATCHALL;
return AF_SUCCESS;
}

af_err afcl_delete_device_context(cl_device_id dev, cl_context ctx) {
af_err afoneapi_delete_device(af_sycl_device dev) {
try {
removeDeviceContext(dev, ctx);
removeDevice(*static_cast<sycl::device*>(dev));
}
CATCHALL;
return AF_SUCCESS;
}
*/
8 changes: 4 additions & 4 deletions src/backend/oneapi/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,19 @@ std::string getPlatformName(const sycl::device& device);

int setDevice(int device);

void addDeviceContext(sycl::device& dev, sycl::context& ctx, sycl::queue& que);
void addDeviceContext(sycl::queue& que);

void setDeviceContext(sycl::device& dev, sycl::context& ctx);

void removeDeviceContext(sycl::device& dev, sycl::context& ctx);
void removeDevice(sycl::device& dev);

void sync(int device);

bool synchronize_calls();

int getActiveDeviceType();
sycl::info::device_type getActiveDeviceType();

int getActivePlatform();
sycl::platform getActivePlatform();

bool& evalFlag();

Expand Down
33 changes: 32 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ target_link_libraries(arrayfire_test
# 'BACKENDS' Backends to target for this test. If not set then the test will
# compiled againat all backends
function(make_test)
set(options CXX11 SERIAL USE_MMIO NO_ARRAYFIRE_TEST)
set(options CXX11 CXX17 SERIAL USE_MMIO NO_ARRAYFIRE_TEST)
set(single_args SRC)
set(multi_args LIBRARIES DEFINITIONS BACKENDS)
cmake_parse_arguments(mt_args "${options}" "${single_args}" "${multi_args}" ${ARGN})
Expand Down Expand Up @@ -238,6 +238,11 @@ function(make_test)
PROPERTIES
CXX_STANDARD 11)
endif(${mt_args_CXX11})
if(${mt_args_CXX17})
set_target_properties(${target}
PROPERTIES
CXX_STANDARD 17)
endif(${mt_args_CXX17})

set_target_properties(${target}
PROPERTIES
Expand Down Expand Up @@ -370,6 +375,32 @@ if(OpenCL_FOUND)
CXX11)
endif()

if(AF_BUILD_ONEAPI)
make_test(SRC interop_sycl_external_context_snippet.cpp
LIBRARIES -fsycl
BACKENDS "oneapi"
NO_ARRAYFIRE_TEST
CXX17)
target_compile_options(test_interop_sycl_external_context_snippet_oneapi
PUBLIC
-fsycl)
make_test(SRC interop_sycl_custom_kernel_snippet.cpp
LIBRARIES -fsycl
BACKENDS "oneapi"
NO_ARRAYFIRE_TEST
CXX17)
target_compile_options(test_interop_sycl_custom_kernel_snippet_oneapi
PUBLIC
-fsycl)
make_test(SRC interop_sycl.cpp
LIBRARIES -fsycl
BACKENDS "oneapi"
CXX17)
target_compile_options(test_interop_sycl_oneapi
PUBLIC
-fsycl)
endif()

if(CUDA_FOUND)
include(AFcuda_helpers)
foreach(backend ${enabled_backends})
Expand Down
Loading