Update app.py
app.py CHANGED
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import sys
import os
os.system('pip install dlib')
import dlib
import argparse
import numpy as np
from PIL import Image
import cv2
import torch
from huggingface_hub import hf_hub_download
import gradio as gr

import models_vit
from util.datasets import build_dataset
from engine_finetune import test_two_class, test_multi_class
import matplotlib.pyplot as plt
from torchvision import transforms
import traceback
from pytorch_grad_cam import (
    GradCAM, ScoreCAM,
    XGradCAM, EigenCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image


def reshape_transform(tensor, height=14, width=14):
    # Drop the CLS token and reshape the ViT patch tokens into a (B, C, H, W) map for Grad-CAM.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)
    return result

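# The argparse defaults below mirror the fine-tuning/testing configuration; at inference time the demo
# only relies on a subset of them (e.g. input_size, drop_path, global_pool, nb_classes, data_path,
# batch_size, num_workers, pin_mem), which the handlers overwrite per request.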
def get_args_parser():
    parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
    parser.set_defaults(normalize_from_IMN=True)
    parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
    parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
    parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
    parser.add_argument('--recount', type=int, default=1, help='Random erase count')
    parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
    parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
    parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
    parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
    parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
    parser.add_argument('--output_dir', default='', help='path where to save')
    parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.set_defaults(eval=True)
    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser

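# load_model(): builds the ViT selected in the dropdown, downloads its checkpoint from the
# 'Wolowolo/fsfm-3c' Hugging Face repo on first use, loads the weights (non-strict), and prepares a
# Grad-CAM instance on the last transformer block for the visual explanation shown in the image tab.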
def load_model(select_skpt):
    global ckpt, device, model, checkpoint
    if select_skpt not in CKPT_NAME:
        return gr.update(), "Select a correct model"
    ckpt = select_skpt
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.nb_classes = CKPT_CLASS[ckpt]
    model = models_vit.__dict__[CKPT_MODEL[ckpt]](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    ).to(device)

    args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    if not os.path.isfile(args.resume):
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        local_dir_use_symlinks=False,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=CKPT_PATH[ckpt])
        args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    global cam
    cam = GradCAM(model=model,
                  target_layers=[model.blocks[-1].norm1],
                  reshape_transform=reshape_transform)
    return gr.update(), f"[Loaded Model Successfully] {args.resume}"


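# Face pre-processing helpers: dlib's frontal face detector returns candidate faces, the first one is
# kept, its bounding box is enlarged by 1.3x into a square crop, and video frames are sampled uniformly
# across the clip before being cropped and resized to 224x224 for the ViT input.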
def get_boundingbox(face, width, height, minsize=None):
    x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
    x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)
    return x1, y1, size_bb


def extract_face(frame):
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    faces = face_detector(image, 1)
    if faces:
        face = faces[0]
        x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
        cropped_face = image[y:y + size, x:x + size]
        return Image.fromarray(cropped_face)
    return None


def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
    return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()


def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
    video_capture = cv2.VideoCapture(src_video)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
    for frame_index in frame_indices:
        video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = video_capture.read()
        if not ret:
            continue
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img = extract_face(image)
        if img:
            img = img.resize((224, 224), Image.BICUBIC)
            save_img_name = f"frame_{frame_index}.png"
            img.save(os.path.join(dst_path, '0', save_img_name))
    video_capture.release()
    return frame_indices


class TargetCategory:
    def __init__(self, category_index):
        self.category_index = category_index

    def __call__(self, output):
        return output[self.category_index]


def preprocess_image_cam(pil_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    # Convert the PIL image to a numpy array
    img_np = np.array(pil_img)

    # Scale pixel values to [0, 1]
    img_np = img_np.astype(np.float32) / 255.0

    # Normalize with the given mean and std
    img_np = (img_np - mean) / std

    # Reorder dimensions to (C, H, W) for the model input
    img_np = np.transpose(img_np, (2, 0, 1))

    # Add a batch dimension (B, C, H, W)
    img_np = np.expand_dims(img_np, axis=0)

    return img_np

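# Image pipeline: crop the face, save it as a one-image dataset, run the selected checkpoint, and report
# either the 4-class probabilities (with a Grad-CAM heatmap for the top class) or the binary
# real/deepfake (DfD) or bonafide/spoofing (FAS) probability.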
def FSFM3C_image_detection(image):
    frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
    os.makedirs(frame_path, exist_ok=True)
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
    img = extract_face(image)
    if img is None:
        # Return one value per Gradio output component (result text, heatmap, detected class).
        return 'No face detected, please upload a clear face!', None, None
    img = img.resize((224, 224), Image.BICUBIC)
    img.save(os.path.join(frame_path, '0', "frame_0.png"))
    args.data_path = frame_path
    args.batch_size = 1
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size,
                                                  num_workers=args.num_workers, pin_memory=args.pin_mem,
                                                  drop_last=False)

    if CKPT_CLASS[ckpt] > 2:
        frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
        class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
        avg_video_pred = np.mean(video_pred_list, axis=0)
        max_prob_index = np.argmax(avg_video_pred)
        max_prob_class = class_names[max_prob_index]
        probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
        image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"

        # Generate a Grad-CAM heatmap for the detected class, on the same device as the model
        input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        input_tensor = input_tensor.to(device)

        # Dynamically determine the target category from the maximum-probability class
        category_names_to_index = {
            'Real or Bonafide': 0,
            'Deepfake': 1,
            'Diffusion or AIGC generated': 2,
            'Spoofing or Presentation-attack': 3
        }
        target_category = TargetCategory(category_names_to_index[max_prob_class])

        grayscale_cam = cam(input_tensor=input_tensor, targets=[target_category])
        grayscale_cam = 1 - grayscale_cam[0, :]
        img = np.array(img)
        if img.shape[2] == 4:
            img = img[:, :, :3]
        img = img.astype(np.float32) / 255.0
        visualization = show_cam_on_image(img, grayscale_cam)
        visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)

        # Add text overlay to the heatmap
        # text = f"Detected: {max_prob_class}"
        # cv2.putText(visualization, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        os.makedirs("./CAM_images", exist_ok=True)
        output_path = "./CAM_images/output_heatmap.png"
        cv2.imwrite(output_path, visualization)
        return image_results, output_path, probabilities[max_prob_index]

    if CKPT_CLASS[ckpt] == 2:
        frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
        if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
            prob = sum(video_pred_list) / len(video_pred_list)
            label = "Deepfake" if prob <= 0.5 else "Real"
            prob = prob if label == "Real" else 1 - prob
        if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
            prob = sum(video_pred_list) / len(video_pred_list)
            label = "Spoofing" if prob <= 0.5 else "Bonafide"
            prob = prob if label == "Bonafide" else 1 - prob
        image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
        return image_results, None, None


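# Video pipeline: uniformly sample up to num_frames frames, crop a face from each, batch them through
# the model, and return the averaged video-level prediction together with per-frame probabilities.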
def FSFM3C_video_detection(video, num_frames):
    try:
        frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
        os.makedirs(frame_path, exist_ok=True)
        os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
        frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
        args.data_path = frame_path
        args.batch_size = num_frames
        dataset_val = build_dataset(is_train=False, args=args)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size,
                                                      num_workers=args.num_workers, pin_memory=args.pin_mem,
                                                      drop_last=False)

        if CKPT_CLASS[ckpt] > 2:
            frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
            class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
            avg_video_pred = np.mean(video_pred_list, axis=0)
            max_prob_index = np.argmax(avg_video_pred)
            max_prob_class = class_names[max_prob_index]
            probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]

            frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%"
                                                           for j, prob in enumerate(frame_preds_list[i])]
                             for i in range(len(frame_indices))}
            video_results = (f"The largest face in this video may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
                             f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
            return video_results

        if CKPT_CLASS[ckpt] == 2:
            frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
            if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
                prob = sum(video_pred_list) / len(video_pred_list)
                label = "Deepfake" if prob <= 0.5 else "Real"
                prob = prob if label == "Real" else 1 - prob
                frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%"
                                  for i in range(len(frame_indices))} if label == "Real"
                                 else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%"
                                       for i in range(len(frame_indices))})

            if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
                prob = sum(video_pred_list) / len(video_pred_list)
                label = "Spoofing" if prob <= 0.5 else "Bonafide"
                prob = prob if label == "Bonafide" else 1 - prob
                frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%"
                                  for i in range(len(frame_indices))} if label == "Bonafide"
                                 else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%"
                                       for i in range(len(frame_indices))})

            video_results = (f"The largest face in this video may be {label} with probability {prob * 100:.1f}%\n \n"
                             f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
            return video_results
    except Exception:
        return "Error occurred. Please provide a clear face video or reduce the number of frames."

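# Checkpoints are cached under ./checkpoints (fetched on demand from the 'Wolowolo/fsfm-3c' repo) and
# extracted frames under ./frame, both relative to this file.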
# Paths and Constants
P = os.path.abspath(__file__)
FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
CKPT_NAME = [
    '✨Unified-detector_v1_Fine-tuned_on_4_classes',
    'DfD-Checkpoint_Fine-tuned_on_FF++',
    'FAS-Checkpoint_Fine-tuned_on_MCIO',
]
CKPT_PATH = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
    'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
}
CKPT_CLASS = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
    'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 2
}
CKPT_MODEL = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
    'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
}

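# Gradio UI: a model selector with loading status, an image column (result, Grad-CAM heatmap, detected
# class) and a video column (frame-count slider, result), followed by the event wiring for the three
# callbacks defined above.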
with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
    gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
    gr.Markdown("<b>☉ Powered by fine-tuned ViT models that are pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
                "<b>☉ We do not and cannot access or store the data you have uploaded!</b> <br> "
                "<b>☉ Release (Continuously updating) </b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
                "1) Updated <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM Pre-trained) detector that can identify Real&Bonafide, Deepfake, Diffusion&AIGC, Spoofing&Presentation-attack facial images or videos</b>; 2) Provided the selection of the number of video frames (uniformly sampling 1-32 frames; more frames may be time-consuming on this page without GPU acceleration); 3) Fixed some errors of V0.1, including loading and prediction. <br>"
                "<b>[V0.1] 2024/12-2025/02/21</b>: "
                "Created this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
    gr.Markdown("- Please <b>provide a facial image or video (<100s)</b>, and <b>select the model</b> for detection: <br> <b>[SUGGEST] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b> a (FSFM pre-trained) ViT-B/16-224 for detecting Real/Deepfake/Diffusion/Spoofing facial images&videos <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b> for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b> for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")

    with gr.Row():
        ckpt_select_dropdown = gr.Dropdown(
            label="Select the Model for Detection ⬇️",
            elem_classes="custom-label",
            choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
            multiselect=False,
            value='Choose Model Here 🖱️',
            interactive=True,
        )
        model_loading_status = gr.Textbox(label="Model Loading Status")
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("### Image Detection (Fast Try: copying image from [whichfaceisreal](https://whichfaceisreal.com/))")
            image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
            image_submit_btn = gr.Button("Submit")
            output_results_image = gr.Textbox(label="Detection Result")

            with gr.Row():
                output_heatmap = gr.Image(label="Grad-CAM")
                output_max_prob_class = gr.Textbox(label="Detected Class")
        with gr.Column(scale=5):
            gr.Markdown("### Video Detection")
            video = gr.Video(label="Upload/Capture your video")
            frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
            video_submit_btn = gr.Button("Submit")
            output_results_video = gr.Textbox(label="Detection Result")

    gr.HTML(
        '<div style="display: flex; justify-content: center; gap: 20px; margin-bottom: 20px;">'
        '<a href="https://mapmyvisitors.com/web/1bxvi" title="Visit tracker">'
        '<img src="https://mapmyvisitors.com/map.png?d=FYhBoxLDEaFAxdfRzk5TuchYOBGrnSa98Ky59EkEEpY&cl=ffffff">'
        '</a>'
        '</div>'
    )

    ckpt_select_dropdown.change(
        fn=load_model,
        inputs=[ckpt_select_dropdown],
        outputs=[ckpt_select_dropdown, model_loading_status],
    )
    image_submit_btn.click(
        fn=FSFM3C_image_detection,
        inputs=[image],
        outputs=[output_results_image, output_heatmap, output_max_prob_class],
    )
    video_submit_btn.click(
        fn=FSFM3C_video_detection,
        inputs=[video, frame_slider],
        outputs=[output_results_video],
    )

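# Start-up: parse the default args and preload the 4-class unified detector so the demo responds before
# a model is explicitly chosen from the dropdown.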
if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.nb_classes = CKPT_CLASS[ckpt]
    model = models_vit.__dict__[CKPT_MODEL[ckpt]](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    ).to(device)
    args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    if not os.path.isfile(args.resume):
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        local_dir_use_symlinks=False,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=CKPT_PATH[ckpt])
        args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    # Mirror load_model() so Grad-CAM is available before a model is re-selected from the dropdown.
    cam = GradCAM(model=model,
                  target_layers=[model.blocks[-1].norm1],
                  reshape_transform=reshape_transform)

    gr.close_all()
    demo.queue()
    demo.launch()