Spaces:
Sleeping
Sleeping
import os

import Predict
import cv2 as cv
import numpy as np  # only referenced by the commented-out single-ad example at the bottom
import torch

if __name__ == "__main__":
    # Batch driver: run Predict.Ad_Gaze_Prediction in ad-features-only mode on
    # every image of the "eds" set and save the results as a dict keyed by
    # (ed number - 1).

    # --- Model / dictionary assets forwarded to Predict.Ad_Gaze_Prediction ---
    text_detection_model_path = 'EAST-Text-Detection/frozen_east_text_detection.pb'
    LDA_model_pth = 'LDA_Model_trained/lda_model_best_tot.model'
    training_ad_text_dictionary_path = 'LDA_Model_trained/object_word_dictionary'
    training_lang_preposition_path = 'LDA_Model_trained/dutch_preposition'

    # --- Input images ---
    # Image sets used in earlier runs (swap in together with the matching
    # metadata files and output path below):
    #   ['Randomized_Dataset/Single_Page_Ads/ad'+str(i)+'.jpg' for i in range(37)]
    #   ['Randomized_Dataset/ads/'+str(i+1)+'ad.jpg' for i in range(24)]
    #   ['Randomized_Dataset/More_Backpage_Ads/Backpage_Imgs/ad'+str(i+1)+'.jpg' for i in range(123)]
    ed_indices = [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23]
    all_paths = [f'Randomized_Dataset/eds/{i}ed.jpg' for i in ed_indices]

    # --- Per-ad metadata, loaded once instead of once per image ---
    # NOTE(review): metadata comes from Single_Page_Ads/DATA and is indexed by
    # loop position, while the images come from eds/ and results are keyed by
    # ed number - 1. Confirm this pairing is intended.
    # Metadata sources used in earlier runs:
    #   Randomized_Dataset/{surfaces, Prod_Cat, embs_randomized.pt}
    #   Randomized_Dataset/More_Backpage_Ads/DATA/{surface_sizes_data, Prod_Cat, embs_backpage_ads}
    # NOTE(review): on torch >= 2.6, torch.load defaults to weights_only=True,
    # which may reject non-tensor pickles; pass weights_only=False only for
    # trusted files if loading fails.
    all_surfaces = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/surfaces')
    all_prod_groups = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/Prod_Cat')
    # The embeddings file is loaded twice on purpose so the ad and context
    # embeddings handed to Predict are independent objects, in case the callee
    # modifies either of them in place.
    all_ad_embs = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/embs_single_page_ads')
    all_ctpg_embs = torch.load('Randomized_Dataset/Single_Page_Ads/DATA/embs_single_page_ads')

    Features10 = {}
    for i, (ed_num, ad_path) in enumerate(zip(ed_indices, all_paths)):
        print(ad_path)
        ad = cv.imread(ad_path)
        if ad is None:
            # cv.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path.
            raise FileNotFoundError(f'Could not read image: {ad_path}')
        ad = cv.cvtColor(ad, cv.COLOR_BGR2RGB)
        # Fixed crop of rows 89..920 and cols 320..959 (an 832 x 640 region);
        # presumably cuts the ad out of a larger page render -- TODO confirm.
        ad = ad[89:921, 320:960, :]
        # cv.resize takes (width, height): normalizes to 640 x 832 even when
        # the source image was smaller than the crop window.
        ad = cv.resize(ad, (640, 832))

        Features = Predict.Ad_Gaze_Prediction(
            input_ad_path=ad, input_ctpg_path=None, ad_location=None,
            text_detection_model_path=text_detection_model_path, LDA_model_pth=LDA_model_pth,
            training_ad_text_dictionary_path=training_ad_text_dictionary_path,
            training_lang_preposition_path=training_lang_preposition_path,
            training_language='dutch',
            Ad_var=None, Ctpg_var=None,
            flag_full_page_ad=False,
            ad_embeddings=all_ad_embs[i], ctpg_embeddings=all_ctpg_embs[i],
            surface_sizes=list(all_surfaces[i]), Product_Group=all_prod_groups[i],
            obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type='Ad',
            Ad_Features_Only=True, Info_printing=False)
        Features10[ed_num - 1] = Features

    # Output paths used in earlier runs:
    #   'Randomized_Dataset/Single_Page_Ads/DATA/Features10.pt'
    #   'Randomized_Dataset/More_Backpage_Ads/DATA/Features10.pt'
    out_path = 'Randomized_Dataset/eds/DATA/Features10.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(Features10, out_path)

    # Single-ad example (full gaze prediction, 'BS' gaze-time type):
    # ad = cv.imread('Demo/MUPI_1185x1750mm_v02.jpg')
    # ad = cv.cvtColor(ad, cv.COLOR_BGR2RGB)
    # ad = cv.resize(ad, (640, 832))
    # surfaces = list(np.load('surfaces.npy'))
    # prod_group = np.load('prod_group.npy')
    # ad_topic = torch.load('new_ad_topic.pt')
    # ctpg_topic = torch.load('new_ctxt_topic.pt')
    # Gaze = Predict.Ad_Gaze_Prediction(input_ad_path=ad, input_ctpg_path=None, ad_location=None,
    #     text_detection_model_path=text_detection_model_path, LDA_model_pth=LDA_model_pth,
    #     training_ad_text_dictionary_path=training_ad_text_dictionary_path, training_lang_preposition_path=training_lang_preposition_path, training_language='dutch',
    #     ad_embeddings=ad_topic, ctpg_embeddings=ctpg_topic,
    #     surface_sizes=surfaces, Product_Group=prod_group,
    #     obj_detection_model_pth=None, num_topic=20, Gaze_Time_Type='BS')