Spaces:
Running
on
Zero
Running
on
Zero
add tree+image
Browse files
app.py
CHANGED
|
@@ -16,6 +16,7 @@ import multiprocessing as mp
|
|
| 16 |
from einops import rearrange
|
| 17 |
from matplotlib import pyplot as plt
|
| 18 |
import matplotlib
|
|
|
|
| 19 |
USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
|
| 20 |
DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
|
| 21 |
|
|
@@ -2342,13 +2343,99 @@ with demo:
|
|
| 2342 |
|
| 2343 |
pil_image = Image.fromarray(image)
|
| 2344 |
return pil_image
|
| 2345 |
-
|
| 2346 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2347 |
if len(eigvecs) == 0:
|
| 2348 |
gr.Warning("Please run NCUT first.")
|
| 2349 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2350 |
eigvecs = torch.tensor(eigvecs)
|
| 2351 |
-
|
| 2352 |
gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
|
| 2353 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 2354 |
from sklearn.manifold import TSNE
|
|
@@ -2357,8 +2444,8 @@ with demo:
|
|
| 2357 |
torch.manual_seed(seed)
|
| 2358 |
np.random.seed(seed)
|
| 2359 |
|
| 2360 |
-
fps_idx = farthest_point_sampling(
|
| 2361 |
-
fps_eigvecs =
|
| 2362 |
fps_eigvecs = fps_eigvecs.numpy()
|
| 2363 |
|
| 2364 |
tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
|
|
@@ -2371,19 +2458,29 @@ with demo:
|
|
| 2371 |
metric='cosine',
|
| 2372 |
random_state=seed,
|
| 2373 |
).fit_transform(fps_eigvecs)
|
|
|
|
|
|
|
|
|
|
| 2374 |
|
| 2375 |
edges = build_tree(tsne_embed)
|
|
|
|
|
|
|
| 2376 |
|
| 2377 |
# Plot the t-SNE points
|
| 2378 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
|
| 2379 |
|
| 2380 |
-
|
|
|
|
|
|
|
|
|
|
| 2381 |
|
| 2382 |
-
|
| 2383 |
-
|
| 2384 |
-
|
| 2385 |
-
|
| 2386 |
-
|
|
|
|
|
|
|
| 2387 |
gr.Markdown('---')
|
| 2388 |
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
| 2389 |
gr.Markdown('---')
|
|
|
|
| 16 |
from einops import rearrange
|
| 17 |
from matplotlib import pyplot as plt
|
| 18 |
import matplotlib
|
| 19 |
+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
|
| 20 |
USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
|
| 21 |
DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]
|
| 22 |
|
|
|
|
| 2343 |
|
| 2344 |
pil_image = Image.fromarray(image)
|
| 2345 |
return pil_image
|
| 2346 |
+
|
| 2347 |
+
def get_top1_heatmap_for_each_dot(images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed):
    """Build one bordered image/heatmap thumbnail for each displayed dot.

    For every selected dot, the image with the highest mean response to that
    dot's eigenvector is picked, blended 50/50 with its 'hot'-colormapped
    response map, and framed with a border in the dot's colour.

    Args:
        images: indexable collection of images exposing ``convert`` and
            ``resize`` (PIL images in the caller visible in this file).
        eigvecs: per-pixel eigenvectors, ``[B, H, W, C]`` according to the
            original inline comment.
        fps_eigvecs: eigenvectors of the sampled dots, ``[N, C]``.
        max_display_dots: maximum number of dots that get a thumbnail.
        fps_tsne_rgb: per-dot colours; the first three channels are scaled by
            255 for the border, so values are presumably in [0, 1] -- confirm.
        tsne_embed: dot coordinates, used only to spread the subsample when
            there are more dots than ``max_display_dots``.

    Returns:
        ``(top1_image_blended, dots_idx)``: a list of ``uint8`` arrays of
        shape ``(296, 296, 3)`` (256 px thumbnail plus a 20 px border on each
        side) and the indices of the dots they belong to.
    """
    thumb_size = (256, 256)
    border_width = 20

    n_dots = fps_eigvecs.shape[0]
    if n_dots > max_display_dots:
        # Farthest-point subsample in the embedding so thumbnails are spread
        # over the plot. The original also drew a np.random.choice subsample
        # here, but it was overwritten immediately, so it is dropped.
        import fpsample
        dots_idx = fpsample.bucket_fps_kdline_sampling(tsne_embed, max_display_dots, 5).astype(np.int64)
    else:
        dots_idx = np.arange(n_dots)
    fps_eigvecs = fps_eigvecs[dots_idx]
    fps_tsne_rgb = fps_tsne_rgb[dots_idx]

    # The caller in this file passes a torch tensor for eigvecs and a numpy
    # array for fps_eigvecs; convert explicitly instead of relying on
    # mixed-type operator fallback.
    heatmaps = np.asarray(eigvecs) @ np.asarray(fps_eigvecs).T  # [B, H, W, C] @ [N, C] -> [B, H, W, N]
    value = heatmaps.mean(1).mean(1)  # [B, N]
    top1_image_idxs = value.argmax(axis=0)  # [N]

    def pad_image_with_border(image, border_color, border_width):
        """Return `image` centred on a canvas filled with `border_color`."""
        height, width, channels = image.shape
        new_image = np.empty(
            (height + 2 * border_width, width + 2 * border_width, channels),
            dtype=image.dtype,
        )
        new_image[:, :] = border_color
        # Explicit end indices: a `w:-w` slice would be empty for w == 0.
        new_image[border_width:border_width + height, border_width:border_width + width] = image
        return new_image

    top1_image_blended = []
    cm = matplotlib.colormaps['hot']
    for i_fps, image_idx in enumerate(top1_image_idxs):
        image = images[image_idx].convert("RGB").resize(thumb_size)
        # NOTE(review): the raw dot products go into the colormap without
        # normalisation, exactly as before -- confirm the value range.
        heatmap = cm(heatmaps[image_idx, :, :, i_fps])
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap).resize(thumb_size).convert("RGB")
        blended = 0.5 * np.array(image) + 0.5 * np.array(heatmap)
        blended = np.clip(blended, 0, 255).astype(np.uint8)
        border_color = fps_tsne_rgb[i_fps, :3] * 255
        top1_image_blended.append(pad_image_with_border(blended, border_color, border_width))

    return top1_image_blended, dots_idx
|
| 2386 |
+
|
| 2387 |
+
|
| 2388 |
+
|
| 2389 |
+
def plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne_rgb, max_display_dots=100):
    """Render the 2-D embedding with image/heatmap thumbnails on the dots.

    Args:
        images, eigvecs, fps_eigvecs, fps_tsne_rgb: forwarded unchanged to
            ``get_top1_heatmap_for_each_dot``.
        tsne_embed: array whose first two columns are the x/y coordinates of
            the dots.
        max_display_dots: maximum number of dots that get a thumbnail.

    Returns:
        A PIL image of the rendered figure, decoded from an in-memory PNG.
    """
    top1_image_blended, dots_idx = get_top1_heatmap_for_each_dot(
        images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed
    )

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    try:
        # Plot the t-SNE points
        ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne_rgb)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        # Scaling the extremes by 1.1 only widens the view when min < 0 < max;
        # the caller visible in this file normalises the embedding to [-1, 1].
        ax.set_xlim(tsne_embed[:, 0].min() * 1.1, tsne_embed[:, 0].max() * 1.1)
        ax.set_ylim(tsne_embed[:, 1].min() * 1.1, tsne_embed[:, 1].max() * 1.1)

        # Add the thumbnails at the coordinates of their dots
        for (x, y), img in zip(tsne_embed[dots_idx], top1_image_blended):
            imgbox = OffsetImage(img, zoom=0.15)
            ax.add_artist(AnnotationBbox(imgbox, (x, y), frameon=False))
        # Second scatter pass kept from the original.
        # NOTE(review): presumably meant to redraw the dots over the
        # thumbnails -- confirm the z-order actually achieves that.
        ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne_rgb)

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        # Save this figure (not pyplot's "current" figure) to an in-memory
        # buffer and decode it before the buffer is closed.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            image = np.array(Image.open(buf))
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)

    pil_image = Image.fromarray(image)
    return pil_image
|
| 2427 |
+
|
| 2428 |
+
|
| 2429 |
+
def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, max_display_dots=300):
|
| 2430 |
if len(eigvecs) == 0:
|
| 2431 |
gr.Warning("Please run NCUT first.")
|
| 2432 |
return
|
| 2433 |
+
images = [image[0] for image in image_gallery]
|
| 2434 |
+
if isinstance(images[0], str):
|
| 2435 |
+
images = [Image.open(image) for image in images]
|
| 2436 |
+
|
| 2437 |
eigvecs = torch.tensor(eigvecs)
|
| 2438 |
+
_eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
|
| 2439 |
gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
|
| 2440 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 2441 |
from sklearn.manifold import TSNE
|
|
|
|
| 2444 |
torch.manual_seed(seed)
|
| 2445 |
np.random.seed(seed)
|
| 2446 |
|
| 2447 |
+
fps_idx = farthest_point_sampling(_eigvecs, num_sample_fps)
|
| 2448 |
+
fps_eigvecs = _eigvecs[fps_idx]
|
| 2449 |
fps_eigvecs = fps_eigvecs.numpy()
|
| 2450 |
|
| 2451 |
tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
|
|
|
|
| 2458 |
metric='cosine',
|
| 2459 |
random_state=seed,
|
| 2460 |
).fit_transform(fps_eigvecs)
|
| 2461 |
+
# normalize = [-1, 1]
|
| 2462 |
+
tsne_embed[:, 0] = (tsne_embed[:, 0] - tsne_embed[:, 0].min()) / (tsne_embed[:, 0].max() - tsne_embed[:, 0].min()) * 2 - 1
|
| 2463 |
+
tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
|
| 2464 |
|
| 2465 |
edges = build_tree(tsne_embed)
|
| 2466 |
+
# edges = build_tree(fps_eigvecs, dist='cosine')
|
| 2467 |
+
# edges = build_tree(fps_tsne3d_rgb)
|
| 2468 |
|
| 2469 |
# Plot the t-SNE points
|
| 2470 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
|
| 2471 |
|
| 2472 |
+
# Plot the t-SNE points with image heatmaps
|
| 2473 |
+
big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
|
| 2474 |
+
|
| 2475 |
+
return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, big_pil_image
|
| 2476 |
|
| 2477 |
+
big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png')
|
| 2478 |
+
|
| 2479 |
+
run_hierarchical_button.click(
|
| 2480 |
+
run_fps_tsne_hierarchical,
|
| 2481 |
+
inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
|
| 2482 |
+
outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, big_tsne_plot],
|
| 2483 |
+
)
|
| 2484 |
gr.Markdown('---')
|
| 2485 |
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
| 2486 |
gr.Markdown('---')
|