Commit 2c5665b · 1 Parent(s): d111436
Update app.py
app.py CHANGED
@@ -4,11 +4,9 @@ import time
 
 import gradio as gr
 import numpy as np
-import torch
 import yt_dlp as youtube_dl
 from gradio_client import Client
 from pyannote.audio import Pipeline
-from transformers.pipelines.audio_utils import ffmpeg_read
 
 
 YT_LENGTH_LIMIT_S = 36000  # limit to 1 hour YouTube files
@@ -189,11 +187,11 @@ def align(transcription, segments, group_by_speaker=True):
     return transcription
 
 
-def transcribe(audio_path, group_by_speaker=True):
+def transcribe(audio_path, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
     # run Whisper JAX asynchronously using Gradio client (endpoint)
     job = client.submit(
         audio_path,
-
+        task,
         True,
         api_name="/predict_1",
     )
@@ -211,11 +209,11 @@ def transcribe(audio_path, group_by_speaker=True):
     return transcription
 
 
-def transcribe_yt(yt_url, group_by_speaker=True):
+def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
     # run Whisper JAX asynchronously using Gradio client (endpoint)
     job = client.submit(
         yt_url,
-
+        task,
         True,
         api_name="/predict_2",
     )
@@ -224,17 +222,8 @@ def transcribe_yt(yt_url, group_by_speaker=True):
     with tempfile.TemporaryDirectory() as tmpdirname:
         filepath = os.path.join(tmpdirname, "video.mp4")
         download_yt_audio(yt_url, filepath)
+        diarization = diarization_pipeline(filepath)
 
-        with open(filepath, "rb") as f:
-            inputs = f.read()
-
-        inputs = ffmpeg_read(inputs, SAMPLING_RATE)
-        inputs = torch.from_numpy(inputs).float()
-        inputs = inputs.unsqueeze(0)
-
-        diarization = diarization_pipeline(
-            {"waveform": inputs, "sample_rate": SAMPLING_RATE},
-        )
     segments = diarization.for_json()["content"]
 
     # only fetch the transcription result after performing diarization
@@ -257,6 +246,7 @@ microphone = gr.Interface(
     fn=transcribe,
     inputs=[
         gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
+        gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
         gr.inputs.Checkbox(default=True, label="Group by speaker"),
     ],
     outputs=[
@@ -272,6 +262,7 @@ audio_file = gr.Interface(
     fn=transcribe,
     inputs=[
         gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
+        gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
         gr.inputs.Checkbox(default=True, label="Group by speaker"),
     ],
     outputs=[
@@ -287,6 +278,7 @@ youtube = gr.Interface(
     fn=transcribe_yt,
     inputs=[
         gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
+        gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
         gr.inputs.Checkbox(default=True, label="Group by speaker"),
     ],
     outputs=[
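
The functional effect of the change to transcribe and transcribe_yt is that the Whisper JAX endpoint now receives a user-selected task string ("transcribe" or "translate") instead of an implicit default. A minimal sketch of that call pattern follows; the Space ID passed to gradio_client.Client is a placeholder, since the diff does not show which Space the app's client points at, and the sketch is an illustration rather than the app's exact code.

# Sketch only: the Space ID below is a placeholder, not taken from the diff.
from gradio_client import Client

client = Client("username/whisper-jax")  # hypothetical Space ID

def run_transcription(audio_path, task="transcribe"):
    # submit() starts the prediction asynchronously and returns a Job;
    # the task string is forwarded instead of being hard-coded.
    job = client.submit(
        audio_path,
        task,   # "transcribe" or "translate", as exposed by the new Radio input
        True,   # return timestamps, matching the call in the diff
        api_name="/predict_1",
    )
    return job.result()  # block until the endpoint finishes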
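
The other substantive change drops the manual audio loading (ffmpeg_read plus a torch tensor) in favour of calling the pyannote pipeline directly on the downloaded file, which pyannote.audio supports by decoding the file itself. A rough sketch of that simplified flow is below; the pipeline checkpoint name and the output handling are illustrative assumptions, not taken from the diff.

# Sketch only: checkpoint name and output handling are assumptions.
from pyannote.audio import Pipeline

diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization",  # assumed checkpoint; the diff does not show it
    use_auth_token=True,             # may require a Hugging Face access token
)

# Previously the audio was read with ffmpeg_read, wrapped in a torch tensor,
# and passed as {"waveform": ..., "sample_rate": ...}; now the pipeline is
# simply handed the path to the downloaded file.
diarization = diarization_pipeline("video.mp4")

for segment, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{speaker}: {segment.start:.1f}s - {segment.end:.1f}s")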