LuJingyi-John commited on
Commit
722e880
·
1 Parent(s): 11c0865

Fix Gradio compatibility issues and canvas processing bugs

Browse files
Files changed (2) hide show
  1. app.py +44 -5
  2. utils/ui_utils.py +108 -20
app.py CHANGED
@@ -21,14 +21,37 @@ def create_interface():
21
  # Draw Region Column
22
  with gr.Column():
23
  gr.Markdown("""<p style="text-align: center; font-size: 20px">1. Draw Regions</p>""")
24
- canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  with gr.Row():
26
  fit_btn = gr.Button("Resize Image")
27
 
28
  # Control Points Column
29
  with gr.Column():
30
  gr.Markdown("""<p style="text-align: center; font-size: 20px">2. Control Points</p>""")
31
- input_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=True)
 
 
 
 
 
 
32
  with gr.Row():
33
  undo_btn = gr.Button("Undo Point")
34
  clear_btn = gr.Button("Clear Points")
@@ -36,14 +59,27 @@ def create_interface():
36
  # Results Column
37
  with gr.Column():
38
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
39
- output_img = gr.Image(type="numpy", label=" ", height=CANVAS_SIZE, width=CANVAS_SIZE, interactive=False)
 
 
 
 
 
 
40
  with gr.Row():
41
  run_btn = gr.Button("Inpaint")
42
  reset_btn = gr.Button("Reset All")
43
 
44
  # Generation Parameters
45
  with gr.Row():
46
- inpaint_ks = gr.Slider(minimum=0, maximum=25, value=5, step=1, label='How much to expand inpainting mask', interactive=True)
 
 
 
 
 
 
 
47
 
48
  setup_events(
49
  components={
@@ -95,7 +131,10 @@ def setup_events(components, state, buttons):
95
 
96
  # Canvas interaction events
97
  def setup_canvas_events():
98
- components['canvas'].edit(
 
 
 
99
  visualize_user_drag,
100
  [components['canvas'], state['points_list']],
101
  [components['input_img']]
 
21
  # Draw Region Column
22
  with gr.Column():
23
  gr.Markdown("""<p style="text-align: center; font-size: 20px">1. Draw Regions</p>""")
24
+ # Use ImageEditor for newer Gradio versions, fallback to Image with brush
25
+ try:
26
+ canvas = gr.ImageEditor(
27
+ label=" ",
28
+ height=CANVAS_SIZE,
29
+ width=CANVAS_SIZE,
30
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed")
31
+ )
32
+ except:
33
+ # Fallback for older Gradio versions
34
+ canvas = gr.Image(
35
+ type="numpy",
36
+ label=" ",
37
+ height=CANVAS_SIZE,
38
+ width=CANVAS_SIZE,
39
+ sources=["upload", "webcam", "clipboard"]
40
+ )
41
+
42
  with gr.Row():
43
  fit_btn = gr.Button("Resize Image")
44
 
45
  # Control Points Column
46
  with gr.Column():
47
  gr.Markdown("""<p style="text-align: center; font-size: 20px">2. Control Points</p>""")
48
+ input_img = gr.Image(
49
+ type="numpy",
50
+ label=" ",
51
+ height=CANVAS_SIZE,
52
+ width=CANVAS_SIZE,
53
+ interactive=True
54
+ )
55
  with gr.Row():
56
  undo_btn = gr.Button("Undo Point")
57
  clear_btn = gr.Button("Clear Points")
 
59
  # Results Column
60
  with gr.Column():
61
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
62
+ output_img = gr.Image(
63
+ type="numpy",
64
+ label=" ",
65
+ height=CANVAS_SIZE,
66
+ width=CANVAS_SIZE,
67
+ interactive=False
68
+ )
69
  with gr.Row():
70
  run_btn = gr.Button("Inpaint")
71
  reset_btn = gr.Button("Reset All")
72
 
73
  # Generation Parameters
74
  with gr.Row():
75
+ inpaint_ks = gr.Slider(
76
+ minimum=0,
77
+ maximum=25,
78
+ value=5,
79
+ step=1,
80
+ label='How much to expand inpainting mask',
81
+ interactive=True
82
+ )
83
 
84
  setup_events(
85
  components={
 
131
 
132
  # Canvas interaction events
133
  def setup_canvas_events():
134
+ # Handle both ImageEditor and Image events
135
+ canvas_event = components['canvas'].change if hasattr(components['canvas'], 'change') else components['canvas'].edit
136
+
137
+ canvas_event(
138
  visualize_user_drag,
139
  [components['canvas'], state['points_list']],
140
  [components['input_img']]
utils/ui_utils.py CHANGED
@@ -25,14 +25,18 @@ pipe = None
25
  # UI functions
26
  def clear_all(length):
27
  """Reset UI by clearing all input images and parameters."""
28
- return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 2, None)
29
 
30
  def resize(canvas, gen_length, canvas_length):
31
  """Resize canvas while maintaining aspect ratio."""
32
  if not canvas:
33
  return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3
34
 
35
- image = process_canvas(canvas)[0]
 
 
 
 
36
  aspect_ratio = image.shape[1] / image.shape[0]
37
  is_landscape = aspect_ratio >= 1
38
 
@@ -49,8 +53,61 @@ def resize(canvas, gen_length, canvas_length):
49
 
50
  def process_canvas(canvas):
51
  """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
52
- image = canvas["image"].copy()
53
- mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  return image, mask
55
 
56
  # Point manipulation functions
@@ -82,22 +139,37 @@ def visualize_user_drag(canvas, points):
82
  if canvas is None:
83
  return None
84
 
85
- image, mask = process_canvas(canvas)
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # Apply colored mask overlay
88
- result = image.copy()
89
- result[mask == 1] = [255, 0, 0] # Red color
90
- image = cv2.addWeighted(result, 0.3, image, 0.7, 0)
 
91
 
92
  # Draw mask outline
93
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
94
- cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
 
95
 
96
  # Draw control points and motion vectors
 
97
  for idx, point in enumerate(points, 1):
98
  if idx % 2 == 0:
99
  cv2.circle(image, tuple(point), 10, (0, 0, 255), -1) # End point
100
- cv2.arrowedLine(image, prev_point, point, (255, 255, 255), 4, tipLength=0.5)
 
101
  else:
102
  cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
103
  prev_point = point
@@ -109,7 +181,19 @@ def preview_out_image(canvas, points, inpaint_ks):
109
  if canvas is None:
110
  return None, None
111
 
112
- image, mask = process_canvas(canvas)
 
 
 
 
 
 
 
 
 
 
 
 
113
  if len(points) < 2:
114
  return image, None
115
 
@@ -120,15 +204,19 @@ def preview_out_image(canvas, points, inpaint_ks):
120
  gr.Warning('Click Resize Image Button first.')
121
  return image, None
122
 
123
- handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
124
- image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
 
125
 
126
- # Add grid pattern to highlight inpainting regions
127
- background = np.ones_like(mask) * 255
128
- background[::10] = background[:, ::10] = 0
129
- image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
130
-
131
- return image, (inpaint_mask * 255).astype(np.uint8)
 
 
 
132
 
133
  # Inpaint tools
134
  def setup_pipeline(device='cuda', model_version='v1-5'):
 
25
  # UI functions
26
  def clear_all(length):
27
  """Reset UI by clearing all input images and parameters."""
28
+ return (gr.Image(value=None, height=length, width=length),) * 3 + ([], 5, None)
29
 
30
  def resize(canvas, gen_length, canvas_length):
31
  """Resize canvas while maintaining aspect ratio."""
32
  if not canvas:
33
  return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3
34
 
35
+ result = process_canvas(canvas)
36
+ if result[0] is None: # Check if image is None
37
+ return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3
38
+
39
+ image = result[0]
40
  aspect_ratio = image.shape[1] / image.shape[0]
41
  is_landscape = aspect_ratio >= 1
42
 
 
53
 
54
  def process_canvas(canvas):
55
  """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
56
+ # Handle None canvas
57
+ if canvas is None:
58
+ return None, None
59
+
60
+ # Handle new ImageEditor format
61
+ if isinstance(canvas, dict):
62
+ if 'background' in canvas and 'layers' in canvas:
63
+ # New ImageEditor format
64
+ if canvas["background"] is None:
65
+ return None, None
66
+ image = canvas["background"].copy()
67
+
68
+ # Ensure image is 3-channel RGB
69
+ if len(image.shape) == 3 and image.shape[2] == 4:
70
+ image = image[:, :, :3] # Remove alpha channel
71
+ elif len(image.shape) == 2:
72
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
73
+
74
+ # Try to extract mask from layers
75
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
76
+ if canvas["layers"]:
77
+ for layer in canvas["layers"]:
78
+ if isinstance(layer, np.ndarray) and len(layer.shape) >= 2:
79
+ layer_mask = np.uint8(layer[:, :, 0] > 0) if len(layer.shape) == 3 else np.uint8(layer > 0)
80
+ mask = np.logical_or(mask, layer_mask).astype(np.uint8)
81
+ elif 'image' in canvas and 'mask' in canvas:
82
+ # Old format
83
+ if canvas["image"] is None:
84
+ return None, None
85
+ image = canvas["image"].copy()
86
+
87
+ # Ensure image is 3-channel RGB
88
+ if len(image.shape) == 3 and image.shape[2] == 4:
89
+ image = image[:, :, :3] # Remove alpha channel
90
+ elif len(image.shape) == 2:
91
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
92
+
93
+ mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy() if canvas["mask"] is not None else np.zeros(image.shape[:2], dtype=np.uint8)
94
+ else:
95
+ # Fallback
96
+ return None, None
97
+ else:
98
+ # Direct numpy array
99
+ if canvas is None:
100
+ return None, None
101
+ image = canvas.copy() if isinstance(canvas, np.ndarray) else np.array(canvas)
102
+
103
+ # Ensure image is 3-channel RGB
104
+ if len(image.shape) == 3 and image.shape[2] == 4:
105
+ image = image[:, :, :3] # Remove alpha channel
106
+ elif len(image.shape) == 2:
107
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
108
+
109
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
110
+
111
  return image, mask
112
 
113
  # Point manipulation functions
 
139
  if canvas is None:
140
  return None
141
 
142
+ result = process_canvas(canvas)
143
+ if result[0] is None: # Check if image is None
144
+ return None
145
+
146
+ image, mask = result
147
+
148
+ # Ensure image is uint8 and 3-channel
149
+ if image.dtype != np.uint8:
150
+ image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
151
+
152
+ if len(image.shape) != 3 or image.shape[2] != 3:
153
+ return None
154
 
155
  # Apply colored mask overlay
156
+ result_img = image.copy()
157
+ if np.any(mask == 1):
158
+ result_img[mask == 1] = [255, 0, 0] # Red color
159
+ image = cv2.addWeighted(result_img, 0.3, image, 0.7, 0)
160
 
161
  # Draw mask outline
162
+ if np.any(mask > 0):
163
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
164
+ cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
165
 
166
  # Draw control points and motion vectors
167
+ prev_point = None
168
  for idx, point in enumerate(points, 1):
169
  if idx % 2 == 0:
170
  cv2.circle(image, tuple(point), 10, (0, 0, 255), -1) # End point
171
+ if prev_point is not None:
172
+ cv2.arrowedLine(image, prev_point, point, (255, 255, 255), 4, tipLength=0.5)
173
  else:
174
  cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) # Start point
175
  prev_point = point
 
181
  if canvas is None:
182
  return None, None
183
 
184
+ result = process_canvas(canvas)
185
+ if result[0] is None: # Check if image is None
186
+ return None, None
187
+
188
+ image, mask = result
189
+
190
+ # Ensure image is uint8 and 3-channel
191
+ if image.dtype != np.uint8:
192
+ image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
193
+
194
+ if len(image.shape) != 3 or image.shape[2] != 3:
195
+ return image, None
196
+
197
  if len(points) < 2:
198
  return image, None
199
 
 
204
  gr.Warning('Click Resize Image Button first.')
205
  return image, None
206
 
207
+ try:
208
+ handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
209
+ image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
210
 
211
+ # Add grid pattern to highlight inpainting regions
212
+ background = np.ones_like(mask) * 255
213
+ background[::10] = background[:, ::10] = 0
214
+ image = np.where(inpaint_mask[..., np.newaxis]==1, background[..., np.newaxis], image)
215
+
216
+ return image, (inpaint_mask * 255).astype(np.uint8)
217
+ except Exception as e:
218
+ gr.Warning(f"Preview failed: {str(e)}")
219
+ return image, None
220
 
221
  # Inpaint tools
222
  def setup_pipeline(device='cuda', model_version='v1-5'):