Ouadie EL FAROUKI commited on
Commit
64976cd
·
1 Parent(s): f14c1ad

Updated SYCL device filtering (llama/8901)

Browse files

* Updated device filter to depend on default_selector (fixes non-intel device issues)
* Small related update to example/sycl Readme

Files changed (1) hide show
  1. ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
ggml/src/ggml-sycl/dpct/helper.hpp CHANGED
@@ -874,7 +874,7 @@ namespace dpct
874
  inline std::string get_preferred_gpu_platform_name() {
875
  std::string result;
876
 
877
- std::string filter = "level-zero";
878
  char* env = getenv("ONEAPI_DEVICE_SELECTOR");
879
  if (env) {
880
  if (std::strstr(env, "level_zero")) {
@@ -892,11 +892,24 @@ namespace dpct
892
  else {
893
  throw std::runtime_error("invalid device filter: " + std::string(env));
894
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
895
  }
896
 
897
- auto plaform_list = sycl::platform::get_platforms();
898
 
899
- for (const auto& platform : plaform_list) {
900
  auto devices = platform.get_devices();
901
  auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
902
  return d.is_gpu();
 
874
  inline std::string get_preferred_gpu_platform_name() {
875
  std::string result;
876
 
877
+ std::string filter = "";
878
  char* env = getenv("ONEAPI_DEVICE_SELECTOR");
879
  if (env) {
880
  if (std::strstr(env, "level_zero")) {
 
892
  else {
893
  throw std::runtime_error("invalid device filter: " + std::string(env));
894
  }
895
+ } else {
896
+ auto default_device = sycl::device(sycl::default_selector_v);
897
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
898
+
899
+ if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
900
+ filter = "level-zero";
901
+ }
902
+ else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
903
+ filter = "cuda";
904
+ }
905
+ else if (std::strstr(default_platform_name.c_str(), "HIP")) {
906
+ filter = "hip";
907
+ }
908
  }
909
 
910
+ auto platform_list = sycl::platform::get_platforms();
911
 
912
+ for (const auto& platform : platform_list) {
913
  auto devices = platform.get_devices();
914
  auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
915
  return d.is_gpu();