import os
import os.path as osp
import math

import cv2
from PIL import Image
import torch
from torchvision import transforms
from plyfile import PlyData, PlyElement
import numpy as np


def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
    """
    Loads images from a directory or video, resizes them to a uniform size,
    then converts and stacks them into a single [N, 3, H, W] PyTorch tensor.
    """
    sources = []

    # --- 1. Load image paths or video frames ---
    if osp.isdir(path):
        print(f"Loading images from directory: {path}")
        filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))])
        for i in range(0, len(filenames), interval):
            img_path = osp.join(path, filenames[i])
            try:
                sources.append(Image.open(img_path).convert('RGB'))
            except Exception as e:
                print(f"Could not load image {filenames[i]}: {e}")
    elif path.lower().endswith('.mp4'):
        print(f"Loading frames from video: {path}")
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {path}")
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % interval == 0:
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                sources.append(Image.fromarray(rgb_frame))
            frame_idx += 1
        cap.release()
    else:
        raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")

    if not sources:
        print("No images found or loaded.")
        return torch.empty(0)
    print(f"Found {len(sources)} images/frames. Processing...")

    # --- 2. Determine a uniform target size for all images based on the first image ---
    # This is necessary to ensure all tensors have the same dimensions for stacking.
    first_img = sources[0]
    W_orig, H_orig = first_img.size
    scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1
    W_target, H_target = W_orig * scale, H_orig * scale
    k, m = round(W_target / 14), round(H_target / 14)
    while (k * 14) * (m * 14) > PIXEL_LIMIT:
        if k / m > W_target / H_target:
            k -= 1
        else:
            m -= 1
    TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14
    print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})")

    # --- 3. Resize images and convert them to tensors in the [0, 1] range ---
    tensor_list = []
    # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0, 1]
    to_tensor_transform = transforms.ToTensor()
    for img_pil in sources:
        try:
            # Resize to the uniform target size
            resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS)
            # Convert to tensor
            img_tensor = to_tensor_transform(resized_img)
            tensor_list.append(img_tensor)
        except Exception as e:
            print(f"Error processing an image: {e}")

    if not tensor_list:
        print("No images were successfully processed.")
        return torch.empty(0)

    # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor ---
    return torch.stack(tensor_list, dim=0)
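
# Usage sketch (illustrative; 'data/truck' is just the default path from the signature
# above and may not exist locally):
#
#   imgs = load_images_as_tensor('data/truck', interval=2)
#   print(imgs.shape)  # torch.Size([N, 3, H, W]); H, W are multiples of 14, values in [0, 1]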


def tensor_to_pil(tensor):
    """
    Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension
    (if it has size 3) to the last axis before converting.

    Args:
        tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W].

    Returns:
        PIL.Image: The converted PIL image.
    """
    if torch.is_tensor(tensor):
        array = tensor.detach().cpu().numpy()
    else:
        array = tensor
    return array_to_pil(array)


def array_to_pil(array):
    """
    Converts a NumPy array with values in [0, 1] to a PIL image. Automatically:
    - Squeezes dimensions of size 1.
    - Moves the channel dimension (if it has size 3) to the last axis.

    Args:
        array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W].

    Returns:
        PIL.Image: The converted PIL image.
    """
    # Remove singleton dimensions
    array = np.squeeze(array)
    # Ensure the array has the channel dimension as the last axis
    if array.ndim == 3 and array.shape[0] == 3:  # If the channel is the first axis
        array = np.transpose(array, (1, 2, 0))  # Move channel to the last axis
    # Handle single-channel grayscale images
    if array.ndim == 2:  # [H, W]
        return Image.fromarray((array * 255).astype(np.uint8), mode="L")
    elif array.ndim == 3 and array.shape[2] == 3:  # [H, W, C] with 3 channels
        return Image.fromarray((array * 255).astype(np.uint8), mode="RGB")
    else:
        raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")
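
# Usage sketch (illustrative; 'preview.png' is a placeholder filename):
#
#   t = torch.rand(3, 64, 64)      # CHW tensor with values in [0, 1]
#   img = tensor_to_pil(t)         # -> 64x64 RGB PIL.Image
#   img.save('preview.png')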


def rotate_target_dim_to_last_axis(x, target_dim=3):
    """
    Moves the last axis whose size equals `target_dim` to the final position.
    Accepts either a NumPy array or a PyTorch tensor.
    """
    shape = x.shape
    axis_to_move = -1
    # 1. Iterate backwards to find the first occurrence from the end
    # (i.e. the last axis of size `target_dim` in the original order).
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == target_dim:
            axis_to_move = i
            break
    # 2. If the axis is found and it's not already in the last position, move it.
    if axis_to_move != -1 and axis_to_move != len(shape) - 1:
        # Create the new dimension order.
        dims_order = list(range(len(shape)))
        dims_order.pop(axis_to_move)
        dims_order.append(axis_to_move)
        # Reorder the dimensions. torch's transpose() only swaps two axes,
        # so tensors need permute(); NumPy's transpose() accepts a full order.
        if torch.is_tensor(x):
            ret = x.permute(*dims_order)
        else:
            ret = x.transpose(*dims_order)
    else:
        ret = x
    return ret
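
# Usage sketch (illustrative):
#
#   rotate_target_dim_to_last_axis(np.zeros((3, 100)), 3).shape   # -> (100, 3)
#   rotate_target_dim_to_last_axis(np.zeros((100, 3)), 3).shape   # -> (100, 3), unchanged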


def write_ply(
    xyz,
    rgb=None,
    path='output.ply',
) -> None:
    if torch.is_tensor(xyz):
        xyz = xyz.detach().cpu().numpy()
    if torch.is_tensor(rgb):
        rgb = rgb.detach().cpu().numpy()
    if rgb is not None and rgb.max() > 1:
        rgb = rgb / 255.

    xyz = rotate_target_dim_to_last_axis(xyz, 3)
    xyz = xyz.reshape(-1, 3)
    if rgb is not None:
        rgb = rotate_target_dim_to_last_axis(rgb, 3)
        rgb = rgb.reshape(-1, 3)
    if rgb is None:
        # No colors given: derive a pseudo-color for each point from its normalized
        # position, then convert the resulting HSV triplets to RGB.
        min_coord = np.min(xyz, axis=0)
        max_coord = np.max(xyz, axis=0)
        normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8)
        hue = 0.7 * normalized_coord[:, 0] + 0.2 * normalized_coord[:, 1] + 0.1 * normalized_coord[:, 2]
        hsv = np.stack([hue, 0.9 * np.ones_like(hue), 0.8 * np.ones_like(hue)], axis=1)
        # Standard HSV -> RGB conversion, handled per 60-degree hue sector.
        c = hsv[:, 2:] * hsv[:, 1:2]
        x = c * (1 - np.abs((hsv[:, 0:1] * 6) % 2 - 1))
        m = hsv[:, 2:] - c
        rgb = np.zeros_like(hsv)
        cond = (0 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 1)
        rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])])
        cond = (1 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 2)
        rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])])
        cond = (2 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 3)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]])
        cond = (3 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 4)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]])
        cond = (4 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 5)
        rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]])
        cond = (5 <= hsv[:, 0] * 6 % 6) & (hsv[:, 0] * 6 % 6 < 6)
        rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]])
        rgb = rgb + m
    dtype = [
        ("x", "f4"),
        ("y", "f4"),
        ("z", "f4"),
        ("nx", "f4"),
        ("ny", "f4"),
        ("nz", "f4"),
        ("red", "u1"),
        ("green", "u1"),
        ("blue", "u1"),
    ]
    normals = np.zeros_like(xyz)
    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb * 255), axis=1)
    elements[:] = list(map(tuple, attributes))
    vertex_element = PlyElement.describe(elements, "vertex")
    ply_data = PlyData([vertex_element])
    ply_data.write(path)
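

if __name__ == "__main__":
    # Minimal smoke test (illustrative; the output filenames below are placeholders,
    # not part of the original module).
    pts = np.random.rand(1000, 3)
    write_ply(pts, path='demo_points.ply')  # colors derived from point positions
    write_ply(pts, rgb=np.random.rand(1000, 3), path='demo_colored.ply')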