# XGBoost_Gaze/Demo.py
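"""
Demo: extract gaze-prediction features for a batch of ad images.

For each image in `all_paths`, the script crops and resizes the ad, loads the
matching surface sizes, product category and topic embeddings, calls
Predict.Ad_Gaze_Prediction with Ad_Features_Only=True, and collects the
returned features in `Features10`, which is saved with torch.save at the end.
"""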
import Predict  # local module providing Ad_Gaze_Prediction
import cv2 as cv
import numpy as np
import torch
if __name__ == "__main__":
    # Default demo image (overridden by the loop below).
    input_ad_path = 'Demo/MUPI_1185x1750mm_v02.jpg'

    # Pretrained models and resources used by the feature extractor.
    text_detection_model_path = 'EAST-Text-Detection/frozen_east_text_detection.pb'
    LDA_model_pth = 'LDA_Model_trained/lda_model_best_tot.model'
    training_ad_text_dictionary_path = 'LDA_Model_trained/object_word_dictionary'
    training_lang_preposition_path = 'LDA_Model_trained/dutch_preposition'

    # Alternative image sets, kept for reference:
    # all_paths = ['Randomized_Dataset/Single_Page_Ads/ad'+str(i)+'.jpg' for i in range(37)]
    # all_paths = ['Randomized_Dataset/ads/'+str(i+1)+'ad.jpg' for i in range(24)]
    # all_paths = ['Randomized_Dataset/More_Backpage_Ads/Backpage_Imgs/ad'+str(i+1)+'.jpg' for i in range(123)]

    # 'ed' images to process; the indices map back to the original file numbering.
    ed_indices = [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23]
    all_paths = [f'Randomized_Dataset/eds/{i}ed.jpg' for i in ed_indices]

    Features10 = {}
    # ad = cv.imread('Randomized_Dataset/Single_Page_Ads/1ad.jpg')
    # print(ad.shape)

    for i, input_ad_path in enumerate(all_paths):
        print(input_ad_path)
        ad = cv.imread(input_ad_path)
        ad = cv.cvtColor(ad, cv.COLOR_BGR2RGB)
        # Crop the ad region (832 x 640 px) and resize to the expected 640x832 (width, height) input.
        ad = ad[89:921, 320:960, :]
        ad = cv.resize(ad, (640, 832))

        # Per-image metadata: surface sizes, product category, and topic embeddings
        # (the same embedding tensor is used for both the ad and the context page).
        surfaces = list(torch.load('Randomized_Dataset/Single_Page_Ads/DATA/surfaces')[i])
        prod_group = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/Prod_Cat')[i]
        ad_topic = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/embs_single_page_ads')[i]
        ctpg_topic = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/embs_single_page_ads')[i]
        # surfaces = list(torch.load('Randomized_Dataset/surfaces')[i])
        # prod_group = torch.load('Randomized_Dataset/Prod_Cat')[i]
        # ad_topic = torch.load('Randomized_Dataset/embs_randomized.pt')[i]
        # ctpg_topic = torch.load('Randomized_Dataset/embs_randomized.pt')[i]
        # surfaces = list(torch.load('Randomized_Dataset/More_Backpage_Ads/DATA/surface_sizes_data')[i])
        # prod_group = torch.load('Randomized_Dataset/More_Backpage_Ads/DATA/Prod_Cat')[i]
        # ad_topic = torch.load('Randomized_Dataset/More_Backpage_Ads/DATA/embs_backpage_ads')[i]
        # ctpg_topic = torch.load('Randomized_Dataset/More_Backpage_Ads/DATA/embs_backpage_ads')[i]

        # Extract ad-level features only (no gaze-time prediction).
        Features = Predict.Ad_Gaze_Prediction(
            input_ad_path=ad, input_ctpg_path=None, ad_location=None,
            text_detection_model_path=text_detection_model_path, LDA_model_pth=LDA_model_pth,
            training_ad_text_dictionary_path=training_ad_text_dictionary_path,
            training_lang_preposition_path=training_lang_preposition_path, training_language='dutch',
            Ad_var=None, Ctpg_var=None,
            flag_full_page_ad=False,
            ad_embeddings=ad_topic, ctpg_embeddings=ctpg_topic,
            surface_sizes=surfaces, Product_Group=prod_group,
            obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type='Ad',
            Ad_Features_Only=True, Info_printing=False)
        # Key the features by the zero-based index of the source image.
        Features10[ed_indices[i] - 1] = Features

    # Save the collected features once all images have been processed.
    # torch.save(Features10, 'Randomized_Dataset/Single_Page_Ads/DATA/Features10.pt')
    # torch.save(Features10, 'Randomized_Dataset/More_Backpage_Ads/DATA/Features10.pt')
    torch.save(Features10, 'Randomized_Dataset/eds/DATA/Features10.pt')
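
    # Earlier single-image workflow, kept commented out for reference: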
    # ad = cv.imread(input_ad_path)
    # ad = cv.cvtColor(ad, cv.COLOR_BGR2RGB)
    # ad = cv.resize(ad, (640, 832))
    # surfaces = list(np.load('surfaces.npy'))
    # prod_group = np.load('prod_group.npy')
    # ad_topic = torch.load('new_ad_topic.pt')
    # ctpg_topic = torch.load('new_ctxt_topic.pt')
    # Gaze = Predict.Ad_Gaze_Prediction(input_ad_path=ad, input_ctpg_path=None, ad_location=None,
    #                                   text_detection_model_path=text_detection_model_path, LDA_model_pth=LDA_model_pth,
    #                                   training_ad_text_dictionary_path=training_ad_text_dictionary_path, training_lang_preposition_path=training_lang_preposition_path, training_language='dutch',
    #                                   ad_embeddings=ad_topic, ctpg_embeddings=ctpg_topic,
    #                                   surface_sizes=surfaces, Product_Group=prod_group,
    #                                   obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type='BS')
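
    # Sketch: reload the saved features for inspection (path taken from the save above).
    # feats = torch.load('Randomized_Dataset/eds/DATA/Features10.pt')
    # print(len(feats), sorted(feats.keys()))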