ggerganov committed
Commit 87b427e · unverified · Parent: 7925ae3

whisper : fix gpu device selection (#2728)

Files changed (1): src/whisper.cpp (+35 −13)
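Both touched functions, whisper_backend_init_gpu and whisper_default_buffer_type, previously returned the first GPU device they encountered during enumeration, so params.gpu_device was ignored. The patch makes them count GPU devices and select the one whose GPU ordinal equals params.gpu_device, keeping the first GPU as a fallback; a condensed sketch of the selection rule follows the first hunk below.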
src/whisper.cpp CHANGED
```diff
@@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
 static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
+    ggml_backend_dev_t dev = nullptr;
+
+    int cnt = 0;
     if (params.use_gpu) {
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-                WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
-                ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
-                if (!result) {
-                    WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    dev = dev_cur;
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
                 }
-                return result;
             }
         }
     }
 
-    return nullptr;
+    if (dev == nullptr) {
+        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+        return nullptr;
+    }
+
+    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+    if (!result) {
+        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+    }
+
+    return result;
 }
 
 static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
```
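The new loop walks every backend device, counts only GPUs, remembers the first GPU as a fallback, and overwrites the choice when the GPU ordinal reaches params.gpu_device. A condensed, standalone sketch of that rule (pick_gpu and the bool-flag device list are illustrative inventions, not part of the patch or of ggml):

```cpp
#include <cstdio>
#include <vector>

// Returns the device-list index of the chosen device, or -1 if no GPU exists.
// `is_gpu[i]` marks whether device i is a GPU; `gpu_device` is the requested
// GPU ordinal, counting GPU devices only.
static int pick_gpu(const std::vector<bool> & is_gpu, int gpu_device) {
    int chosen = -1;
    int cnt = 0; // how many GPUs have been seen so far

    for (size_t i = 0; i < is_gpu.size(); ++i) {
        if (!is_gpu[i]) {
            continue; // non-GPU devices do not advance the GPU ordinal
        }
        if (cnt == 0 || cnt == gpu_device) {
            chosen = (int) i; // keep the first GPU, then switch to the requested one
        }
        if (++cnt > gpu_device) {
            break; // the requested ordinal has been reached
        }
    }

    return chosen;
}

int main() {
    // devices: [CPU, GPU0, GPU1, GPU2]
    const std::vector<bool> devs = { false, true, true, true };

    printf("gpu_device=1 -> index %d\n", pick_gpu(devs, 1)); // index 2 (GPU1)
    printf("gpu_device=9 -> index %d\n", pick_gpu(devs, 9)); // index 1 (falls back to first GPU)
    return 0;
}
```

Note that an out-of-range gpu_device (e.g. 9 with three GPUs) silently falls back to the first GPU rather than failing, which matches both patched functions.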
 
```diff
@@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
 }
 
 static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
+
     if (!params.use_gpu) {
-        return ggml_backend_cpu_buffer_type();
+        return result;
     }
 
-    // if we have a GPU device - use it
+    int cnt = 0;
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
         if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-            WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
-            return ggml_backend_dev_buffer_type(dev);
+            if (cnt == 0 || cnt == params.gpu_device) {
+                result = ggml_backend_dev_buffer_type(dev);
+            }
+
+            if (++cnt > params.gpu_device) {
+                break;
+            }
         }
     }
 
-    return ggml_backend_cpu_buffer_type();
+    return result;
 }
 
 // load the model from a ggml file
```
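On the caller side, the ordinal comes from the gpu_device field of whisper_context_params in the public whisper.h API. A minimal usage sketch (the model path is a placeholder, not part of this commit):

```cpp
#include "whisper.h"

int main() {
    struct whisper_context_params cparams = whisper_context_default_params();

    cparams.use_gpu    = true;
    cparams.gpu_device = 1; // second GPU, counting GPU devices only

    // placeholder model path
    struct whisper_context * ctx =
        whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
    if (ctx == nullptr) {
        return 1;
    }

    // ... run inference ...

    whisper_free(ctx);
    return 0;
}
```

Before this commit, the same parameters would load the model onto the first enumerated GPU regardless of the gpu_device value.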