Upload app.py with huggingface_hub
app.py CHANGED
@@ -69,6 +69,24 @@ PROVIDERS = [
     "nscale",
 ]
 
+# Mapping from display provider names to inference provider API names
+PROVIDER_TO_INFERENCE_NAME = {
+    "togethercomputer": "together",
+    "fal": "fal-ai",
+    "sambanovasystems": "sambanova",
+    "Hyperbolic": "hyperbolic",
+    "CohereLabs": "cohere",
+    # Other providers may not have inference provider support or use different names
+    "fireworks-ai": "fireworks-ai",
+    "nebius": "nebius",
+    "groq": "groq",
+    "cerebras": "cerebras",
+    "replicate": "replicate",
+    "novita": "novita",
+    "featherless-ai": "featherless-ai",
+    "nscale": "nscale",
+}
+
 templates = Jinja2Templates(directory="templates")
 
 async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, str]:
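
For reference, a minimal sanity check of the new mapping might look like the sketch below (hypothetical, assuming app.py is importable as app); an unmapped display name resolves to None, which get_provider_models treats as "no models".

# Hypothetical sanity check, assuming app.py is importable as `app`.
from app import PROVIDER_TO_INFERENCE_NAME

print(PROVIDER_TO_INFERENCE_NAME.get("togethercomputer"))   # -> "together"
print(PROVIDER_TO_INFERENCE_NAME.get("some-new-provider"))  # -> None (provider would be skipped)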
@@ -99,6 +117,33 @@ async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) ->
         "monthly_requests_int": 0
     }
 
+async def get_provider_models(session: aiohttp.ClientSession, provider: str) -> List[str]:
+    """Get supported models for a provider from HuggingFace API"""
+    if not HF_TOKEN:
+        return []
+
+    # Map display provider name to inference provider API name
+    inference_provider = PROVIDER_TO_INFERENCE_NAME.get(provider)
+    if not inference_provider:
+        logger.warning(f"No inference provider mapping found for {provider}")
+        return []
+
+    url = f"https://huggingface.co/api/models?inference_provider={inference_provider}&limit=50&sort=downloads&direction=-1"
+    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+
+    try:
+        async with session.get(url, headers=headers) as response:
+            if response.status == 200:
+                models_data = await response.json()
+                model_ids = [model.get('id', '') for model in models_data if model.get('id')]
+                return model_ids
+            else:
+                logger.warning(f"Failed to fetch models for {provider} (inference_provider={inference_provider}): {response.status}")
+                return []
+    except Exception as e:
+        logger.error(f"Error fetching models for {provider} (inference_provider={inference_provider}): {e}")
+        return []
+
 async def collect_and_store_data():
     """Collect current data and store it in the dataset"""
     if not HF_TOKEN:
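
Outside the app, the request that get_provider_models wraps can be reproduced with a short standalone sketch like the one below (my example, not part of the commit); it reuses the same /api/models query parameters and only sends an Authorization header if an HF_TOKEN environment variable is set.

# Standalone sketch of the underlying Hub query (assumed usage, not from the commit).
import asyncio
import os

import aiohttp

async def fetch_models(inference_provider: str) -> list:
    # Same query parameters as get_provider_models: top 50 models by downloads.
    url = (
        "https://huggingface.co/api/models"
        f"?inference_provider={inference_provider}&limit=50&sort=downloads&direction=-1"
    )
    headers = {}
    token = os.getenv("HF_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            response.raise_for_status()
            data = await response.json()
            return [m["id"] for m in data if m.get("id")]

print(asyncio.run(fetch_models("groq"))[:5])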
@@ -309,6 +354,61 @@ async def get_historical_data():
         "message": "Historical data temporarily unavailable"
     }
 
+@app.get("/api/models")
+async def get_provider_models_data():
+    """API endpoint to get supported models matrix for all providers"""
+    if not HF_TOKEN:
+        return {"error": "HF_TOKEN required for models data", "matrix": [], "providers": PROVIDERS}
+
+    async with aiohttp.ClientSession() as session:
+        tasks = [get_provider_models(session, provider) for provider in PROVIDERS]
+        results = await asyncio.gather(*tasks)
+
+        # Create provider -> models mapping
+        provider_models = {}
+        all_models = set()
+
+        for provider, models in zip(PROVIDERS, results):
+            provider_models[provider] = set(models)
+            all_models.update(models)
+
+        # Convert to list and sort by popularity (number of providers supporting each model)
+        model_popularity = []
+        for model in all_models:
+            provider_count = sum(1 for provider in PROVIDERS if model in provider_models.get(provider, set()))
+            model_popularity.append((model, provider_count))
+
+        # Sort by popularity (descending) then by model name
+        model_popularity.sort(key=lambda x: (-x[1], x[0]))
+
+        # Build matrix data
+        matrix = []
+        for model_id, popularity in model_popularity:
+            row = {
+                "model_id": model_id,
+                "total_providers": popularity,
+                "providers": {}
+            }
+
+            for provider in PROVIDERS:
+                row["providers"][provider] = model_id in provider_models.get(provider, set())
+
+            matrix.append(row)
+
+        # Calculate totals per provider
+        provider_totals = {}
+        for provider in PROVIDERS:
+            provider_totals[provider] = len(provider_models.get(provider, set()))
+
+        return {
+            "matrix": matrix,
+            "providers": PROVIDERS,
+            "provider_totals": provider_totals,
+            "provider_mapping": PROVIDER_TO_INFERENCE_NAME,
+            "total_models": len(all_models),
+            "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        }
+
 @app.post("/api/collect-now")
 async def trigger_data_collection(background_tasks: BackgroundTasks):
     """Manual trigger for data collection"""
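
Once the app is running, the new endpoint can be exercised with a small client sketch like this (hypothetical; the base URL assumes the Space's usual local port 7860). Because the matrix is sorted by provider count descending, the first rows are the most widely supported models.

# Hypothetical client call; http://localhost:7860 is an assumed local base URL.
import requests

data = requests.get("http://localhost:7860/api/models", timeout=120).json()
print(f"{data['total_models']} models across {len(data['providers'])} providers")
for row in data["matrix"][:5]:
    supported = [p for p, ok in row["providers"].items() if ok]
    print(row["model_id"], row["total_providers"], supported)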