Spaces:
Running
Running
command : adding guided mode
Browse files- examples/command/command.cpp +272 -64
- examples/command/commands.txt +9 -0
- whisper.cpp +4 -4
examples/command/command.cpp
CHANGED
|
@@ -19,6 +19,7 @@
|
|
| 19 |
#include <string>
|
| 20 |
#include <thread>
|
| 21 |
#include <vector>
|
|
|
|
| 22 |
|
| 23 |
// command-line parameters
|
| 24 |
struct whisper_params {
|
|
@@ -41,6 +42,7 @@ struct whisper_params {
|
|
| 41 |
std::string language = "en";
|
| 42 |
std::string model = "models/ggml-base.en.bin";
|
| 43 |
std::string fname_out = "";
|
|
|
|
| 44 |
};
|
| 45 |
|
| 46 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -68,6 +70,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 68 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 69 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 70 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
|
|
| 71 |
else {
|
| 72 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 73 |
whisper_print_usage(argc, argv, params);
|
|
@@ -83,22 +86,23 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|
| 83 |
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
| 84 |
fprintf(stderr, "\n");
|
| 85 |
fprintf(stderr, "options:\n");
|
| 86 |
-
fprintf(stderr, " -h,
|
| 87 |
-
fprintf(stderr, " -t N,
|
| 88 |
-
fprintf(stderr, " -pms N,
|
| 89 |
-
fprintf(stderr, " -cms N,
|
| 90 |
-
fprintf(stderr, " -c ID,
|
| 91 |
-
fprintf(stderr, " -mt N,
|
| 92 |
-
fprintf(stderr, " -ac N,
|
| 93 |
-
fprintf(stderr, " -vth N,
|
| 94 |
-
fprintf(stderr, " -fth N,
|
| 95 |
-
fprintf(stderr, " -su,
|
| 96 |
-
fprintf(stderr, " -tr,
|
| 97 |
-
fprintf(stderr, " -ps,
|
| 98 |
-
fprintf(stderr, " -pe,
|
| 99 |
-
fprintf(stderr, " -l LANG,
|
| 100 |
-
fprintf(stderr, " -m FNAME,
|
| 101 |
-
fprintf(stderr, " -f FNAME,
|
|
|
|
| 102 |
fprintf(stderr, "\n");
|
| 103 |
}
|
| 104 |
|
|
@@ -484,6 +488,28 @@ float similarity(const std::string & s0, const std::string & s1) {
|
|
| 484 |
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
| 485 |
}
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
int main(int argc, char ** argv) {
|
| 488 |
whisper_params params;
|
| 489 |
|
|
@@ -521,7 +547,6 @@ int main(int argc, char ** argv) {
|
|
| 521 |
fprintf(stderr, "\n");
|
| 522 |
}
|
| 523 |
|
| 524 |
-
|
| 525 |
// init audio
|
| 526 |
|
| 527 |
audio_async audio(30*1000);
|
|
@@ -532,6 +557,8 @@ int main(int argc, char ** argv) {
|
|
| 532 |
|
| 533 |
audio.resume();
|
| 534 |
|
|
|
|
|
|
|
| 535 |
bool is_running = true;
|
| 536 |
bool have_prompt = false;
|
| 537 |
bool ask_prompt = true;
|
|
@@ -542,7 +569,94 @@ int main(int argc, char ** argv) {
|
|
| 542 |
std::vector<float> pcmf32_cur;
|
| 543 |
std::vector<float> pcmf32_prompt;
|
| 544 |
|
| 545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
// main loop
|
| 548 |
while (is_running) {
|
|
@@ -568,78 +682,172 @@ int main(int argc, char ** argv) {
|
|
| 568 |
// delay
|
| 569 |
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 570 |
|
| 571 |
-
if (
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
fprintf(stdout, "\n");
|
| 575 |
|
| 576 |
-
ask_prompt
|
| 577 |
-
|
|
|
|
|
|
|
| 578 |
|
| 579 |
-
|
|
|
|
| 580 |
|
| 581 |
-
|
| 582 |
-
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
|
| 587 |
-
|
| 588 |
-
audio.get(params.prompt_ms, pcmf32_cur);
|
| 589 |
|
| 590 |
-
|
| 591 |
|
| 592 |
-
|
| 593 |
|
| 594 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
|
|
|
| 599 |
} else {
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 603 |
-
fprintf(stdout, "\n");
|
| 604 |
|
| 605 |
-
//
|
| 606 |
-
pcmf32_prompt
|
| 607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
}
|
| 609 |
-
} else {
|
| 610 |
-
audio.get(params.command_ms, pcmf32_cur);
|
| 611 |
|
| 612 |
-
|
| 613 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
-
|
| 616 |
|
| 617 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
|
| 619 |
-
|
|
|
|
| 620 |
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
size_t best_len = 0;
|
| 624 |
-
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
| 625 |
-
const auto prompt = txt.substr(0, n);
|
| 626 |
|
| 627 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
-
|
|
|
|
| 630 |
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
|
|
|
| 635 |
}
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
|
| 639 |
-
|
|
|
|
| 640 |
fprintf(stdout, "\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
}
|
| 642 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
audio.clear();
|
| 644 |
}
|
| 645 |
}
|
|
|
|
| 19 |
#include <string>
|
| 20 |
#include <thread>
|
| 21 |
#include <vector>
|
| 22 |
+
#include <map>
|
| 23 |
|
| 24 |
// command-line parameters
|
| 25 |
struct whisper_params {
|
|
|
|
| 42 |
std::string language = "en";
|
| 43 |
std::string model = "models/ggml-base.en.bin";
|
| 44 |
std::string fname_out = "";
|
| 45 |
+
std::string commands = "";
|
| 46 |
};
|
| 47 |
|
| 48 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 70 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 71 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 72 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
| 73 |
+
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
| 74 |
else {
|
| 75 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 76 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 86 |
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
| 87 |
fprintf(stderr, "\n");
|
| 88 |
fprintf(stderr, "options:\n");
|
| 89 |
+
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
| 90 |
+
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
| 91 |
+
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
| 92 |
+
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
| 93 |
+
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
| 94 |
+
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
| 95 |
+
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
| 96 |
+
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
| 97 |
+
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
| 98 |
+
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
| 99 |
+
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
| 100 |
+
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 101 |
+
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 102 |
+
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 103 |
+
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 104 |
+
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
| 105 |
+
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
| 106 |
fprintf(stderr, "\n");
|
| 107 |
}
|
| 108 |
|
|
|
|
| 488 |
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
| 489 |
}
|
| 490 |
|
| 491 |
+
std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
| 492 |
+
std::vector<std::string> allowed_commands;
|
| 493 |
+
|
| 494 |
+
std::ifstream ifs(fname);
|
| 495 |
+
if (!ifs.is_open()) {
|
| 496 |
+
return allowed_commands;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
std::string line;
|
| 500 |
+
while (std::getline(ifs, line)) {
|
| 501 |
+
line = trim(line);
|
| 502 |
+
if (line.empty()) {
|
| 503 |
+
continue;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
std::transform(line.begin(), line.end(),line.begin(), ::tolower);
|
| 507 |
+
allowed_commands.push_back(std::move(line));
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
return allowed_commands;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
int main(int argc, char ** argv) {
|
| 514 |
whisper_params params;
|
| 515 |
|
|
|
|
| 547 |
fprintf(stderr, "\n");
|
| 548 |
}
|
| 549 |
|
|
|
|
| 550 |
// init audio
|
| 551 |
|
| 552 |
audio_async audio(30*1000);
|
|
|
|
| 557 |
|
| 558 |
audio.resume();
|
| 559 |
|
| 560 |
+
int max_len = 0;
|
| 561 |
+
|
| 562 |
bool is_running = true;
|
| 563 |
bool have_prompt = false;
|
| 564 |
bool ask_prompt = true;
|
|
|
|
| 569 |
std::vector<float> pcmf32_cur;
|
| 570 |
std::vector<float> pcmf32_prompt;
|
| 571 |
|
| 572 |
+
std::vector<std::string> allowed_commands;
|
| 573 |
+
std::vector<std::vector<whisper_token>> allowed_tokens;
|
| 574 |
+
|
| 575 |
+
std::string k_prompt = "";
|
| 576 |
+
std::vector<whisper_token> k_tokens;
|
| 577 |
+
|
| 578 |
+
if (params.commands != "") {
|
| 579 |
+
fprintf(stderr, "\n");
|
| 580 |
+
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 581 |
+
|
| 582 |
+
allowed_commands = read_allowed_commands(params.commands);
|
| 583 |
+
|
| 584 |
+
if (allowed_commands.empty()) {
|
| 585 |
+
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
| 586 |
+
return 2;
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
for (const auto & cmd : allowed_commands) {
|
| 590 |
+
whisper_token tokens[1024];
|
| 591 |
+
allowed_tokens.emplace_back();
|
| 592 |
+
|
| 593 |
+
for (int l = 0; l < cmd.size(); ++l) {
|
| 594 |
+
// NOTE: very important to add the whitespace !
|
| 595 |
+
// the reason is that the first decoded token starts with a whitespace too!
|
| 596 |
+
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
| 597 |
+
|
| 598 |
+
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
| 599 |
+
if (n < 0) {
|
| 600 |
+
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
| 601 |
+
return 3;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
if (n == 1) {
|
| 605 |
+
allowed_tokens.back().push_back(tokens[0]);
|
| 606 |
+
}
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
max_len = std::max(max_len, (int) cmd.size());
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
| 613 |
+
fprintf(stderr, "\n");
|
| 614 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 615 |
+
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
| 616 |
+
for (const auto & token : allowed_tokens[i]) {
|
| 617 |
+
fprintf(stderr, " %d", token);
|
| 618 |
+
}
|
| 619 |
+
fprintf(stderr, " ]\n");
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
k_prompt = "select one from the available words: ";
|
| 623 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 624 |
+
if (i > 0) {
|
| 625 |
+
k_prompt += ", ";
|
| 626 |
+
}
|
| 627 |
+
k_prompt += allowed_commands[i];
|
| 628 |
+
}
|
| 629 |
+
k_prompt += ". selected word: ";
|
| 630 |
+
|
| 631 |
+
// tokenize prompt
|
| 632 |
+
{
|
| 633 |
+
k_tokens.resize(1024);
|
| 634 |
+
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
| 635 |
+
if (n < 0) {
|
| 636 |
+
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
| 637 |
+
return 4;
|
| 638 |
+
}
|
| 639 |
+
k_tokens.resize(n);
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
fprintf(stderr, "\n");
|
| 643 |
+
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
| 644 |
+
fprintf(stderr, "%s: tokens: [", __func__);
|
| 645 |
+
for (const auto & token : k_tokens) {
|
| 646 |
+
fprintf(stderr, " %d", token);
|
| 647 |
+
}
|
| 648 |
+
fprintf(stderr, " ]\n");
|
| 649 |
+
|
| 650 |
+
fprintf(stderr, "\n");
|
| 651 |
+
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
| 652 |
+
fprintf(stderr, "\n");
|
| 653 |
+
|
| 654 |
+
} else {
|
| 655 |
+
fprintf(stderr, "\n");
|
| 656 |
+
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
| 657 |
+
|
| 658 |
+
k_prompt = "Ok Whisper, start listening for commands.";
|
| 659 |
+
}
|
| 660 |
|
| 661 |
// main loop
|
| 662 |
while (is_running) {
|
|
|
|
| 682 |
// delay
|
| 683 |
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
| 684 |
|
| 685 |
+
if (allowed_commands.empty()) {
|
| 686 |
+
// general-purpose mode
|
| 687 |
+
// freely transcribe the voice into text
|
|
|
|
| 688 |
|
| 689 |
+
if (ask_prompt) {
|
| 690 |
+
fprintf(stdout, "\n");
|
| 691 |
+
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 692 |
+
fprintf(stdout, "\n");
|
| 693 |
|
| 694 |
+
ask_prompt = false;
|
| 695 |
+
}
|
| 696 |
|
| 697 |
+
{
|
| 698 |
+
int64_t t_ms = 0;
|
| 699 |
|
| 700 |
+
audio.get(2000, pcmf32_cur);
|
| 701 |
+
|
| 702 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 703 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 704 |
+
|
| 705 |
+
if (!have_prompt) {
|
| 706 |
+
// wait for activation phrase
|
| 707 |
+
audio.get(params.prompt_ms, pcmf32_cur);
|
| 708 |
|
| 709 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
|
|
|
| 710 |
|
| 711 |
+
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
| 712 |
|
| 713 |
+
const float sim = similarity(txt, k_prompt);
|
| 714 |
|
| 715 |
+
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
| 716 |
+
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
| 717 |
+
ask_prompt = true;
|
| 718 |
+
} else {
|
| 719 |
+
fprintf(stdout, "\n");
|
| 720 |
+
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
| 721 |
+
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 722 |
+
fprintf(stdout, "\n");
|
| 723 |
|
| 724 |
+
// save the audio for the prompt
|
| 725 |
+
pcmf32_prompt = pcmf32_cur;
|
| 726 |
+
have_prompt = true;
|
| 727 |
+
}
|
| 728 |
} else {
|
| 729 |
+
// we have heard the activation phrase, now detect the commands
|
| 730 |
+
audio.get(params.command_ms, pcmf32_cur);
|
|
|
|
|
|
|
| 731 |
|
| 732 |
+
// prepend the prompt audio
|
| 733 |
+
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
| 734 |
+
|
| 735 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
| 736 |
+
|
| 737 |
+
prob = 100.0f*(prob - prob0);
|
| 738 |
+
|
| 739 |
+
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
| 740 |
+
|
| 741 |
+
// find the prompt in the text
|
| 742 |
+
float best_sim = 0.0f;
|
| 743 |
+
size_t best_len = 0;
|
| 744 |
+
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
| 745 |
+
const auto prompt = txt.substr(0, n);
|
| 746 |
+
|
| 747 |
+
const float sim = similarity(prompt, k_prompt);
|
| 748 |
+
|
| 749 |
+
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
| 750 |
+
|
| 751 |
+
if (sim > best_sim) {
|
| 752 |
+
best_sim = sim;
|
| 753 |
+
best_len = n;
|
| 754 |
+
}
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
const std::string command = ::trim(txt.substr(best_len));
|
| 758 |
+
|
| 759 |
+
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
| 760 |
+
fprintf(stdout, "\n");
|
| 761 |
}
|
|
|
|
|
|
|
| 762 |
|
| 763 |
+
audio.clear();
|
| 764 |
+
}
|
| 765 |
+
}
|
| 766 |
+
} else {
|
| 767 |
+
// command-list mode
|
| 768 |
+
// guide the transcription to match the most likely command from a provided list
|
| 769 |
+
|
| 770 |
+
audio.get(2000, pcmf32_cur);
|
| 771 |
+
|
| 772 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 773 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 774 |
+
|
| 775 |
+
const auto t_start = std::chrono::high_resolution_clock::now();
|
| 776 |
|
| 777 |
+
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 778 |
|
| 779 |
+
wparams.print_progress = false;
|
| 780 |
+
wparams.print_special = params.print_special;
|
| 781 |
+
wparams.print_realtime = false;
|
| 782 |
+
wparams.print_timestamps = !params.no_timestamps;
|
| 783 |
+
wparams.translate = params.translate;
|
| 784 |
+
wparams.no_context = true;
|
| 785 |
+
wparams.single_segment = true;
|
| 786 |
+
wparams.max_tokens = 1;
|
| 787 |
+
wparams.language = params.language.c_str();
|
| 788 |
+
wparams.n_threads = params.n_threads;
|
| 789 |
|
| 790 |
+
wparams.audio_ctx = params.audio_ctx;
|
| 791 |
+
wparams.speed_up = params.speed_up;
|
| 792 |
|
| 793 |
+
wparams.prompt_tokens = k_tokens.data();
|
| 794 |
+
wparams.prompt_n_tokens = k_tokens.size();
|
|
|
|
|
|
|
|
|
|
| 795 |
|
| 796 |
+
// run the transformer and a single decoding pass
|
| 797 |
+
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
| 798 |
+
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
| 799 |
+
break;
|
| 800 |
+
}
|
| 801 |
|
| 802 |
+
const auto * probs = whisper_get_probs(ctx);
|
| 803 |
+
std::vector<std::pair<float, int>> probs_id;
|
| 804 |
|
| 805 |
+
double psum = 0.0;
|
| 806 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 807 |
+
probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
|
| 808 |
+
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 809 |
+
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 810 |
}
|
| 811 |
+
probs_id.back().first /= allowed_tokens[i].size();
|
| 812 |
+
psum += probs_id.back().first;
|
| 813 |
+
}
|
| 814 |
|
| 815 |
+
// normalize
|
| 816 |
+
for (auto & p : probs_id) {
|
| 817 |
+
p.first /= psum;
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
// sort descending
|
| 821 |
+
{
|
| 822 |
+
using pair_type = decltype(probs_id)::value_type;
|
| 823 |
+
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 824 |
+
return a.first > b.first;
|
| 825 |
+
});
|
| 826 |
+
}
|
| 827 |
|
| 828 |
+
// print the commands and the respective probabilities
|
| 829 |
+
{
|
| 830 |
fprintf(stdout, "\n");
|
| 831 |
+
for (const auto & cmd : probs_id) {
|
| 832 |
+
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
| 833 |
+
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
| 834 |
+
fprintf(stdout, "%f ", probs[allowed_tokens[cmd.second][i]]);
|
| 835 |
+
}
|
| 836 |
+
fprintf(stdout, "\n");
|
| 837 |
+
}
|
| 838 |
}
|
| 839 |
|
| 840 |
+
// best command
|
| 841 |
+
{
|
| 842 |
+
fprintf(stdout, "\n");
|
| 843 |
+
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
| 844 |
+
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
| 845 |
+
(int) std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - t_start).count());
|
| 846 |
+
fprintf(stdout, "\n");
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
const auto t_end = std::chrono::high_resolution_clock::now();
|
| 850 |
+
|
| 851 |
audio.clear();
|
| 852 |
}
|
| 853 |
}
|
examples/command/commands.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
enable
|
| 2 |
+
disable
|
| 3 |
+
cat
|
| 4 |
+
dog
|
| 5 |
+
apple
|
| 6 |
+
red
|
| 7 |
+
blue
|
| 8 |
+
green
|
| 9 |
+
lightblue
|
whisper.cpp
CHANGED
|
@@ -2826,13 +2826,13 @@ int whisper_full(
|
|
| 2826 |
|
| 2827 |
//{
|
| 2828 |
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
| 2829 |
-
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
| 2830 |
//}
|
| 2831 |
|
| 2832 |
// end of segment
|
| 2833 |
-
if (token.id == whisper_token_eot(ctx) ||
|
| 2834 |
-
(params.max_tokens > 0 && i
|
| 2835 |
-
(has_ts && seek + seek_delta + 100 >= seek_end)
|
| 2836 |
) {
|
| 2837 |
if (result_len == 0) {
|
| 2838 |
if (seek + seek_delta + 100 >= seek_end) {
|
|
|
|
| 2826 |
|
| 2827 |
//{
|
| 2828 |
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
| 2829 |
+
// printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
| 2830 |
//}
|
| 2831 |
|
| 2832 |
// end of segment
|
| 2833 |
+
if (token.id == whisper_token_eot(ctx) || // end of text token
|
| 2834 |
+
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
| 2835 |
+
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
| 2836 |
) {
|
| 2837 |
if (result_len == 0) {
|
| 2838 |
if (seek + seek_delta + 100 >= seek_end) {
|