From e8ed13a23b7a95cec413d3d1db1db1dd7f1753ab Mon Sep 17 00:00:00 2001 From: Brandon Date: Mon, 20 Apr 2026 21:14:06 -0700 Subject: [PATCH] feat: implement image search, scraping, and balanced selection pipeline using ResNet50 and caption relevance scoring --- .gitignore | 3 +- dataset/data.yaml | 2 +- dataset/train/labels.cache | Bin 7482 -> 7932 bytes src/download_images.py | 19 +++-- src/main.py | 21 ++++- src/scrape_similar.py | 13 ++- src/search_images.py | 38 ++++++--- src/select_balanced_images.py | 149 +++++++++++++++++++++++---------- src/utils/caption_relevance.py | 135 +++++++++++++++++++++++++++++ test_caption_relevance.py | 74 ++++++++++++++++ 10 files changed, 385 insertions(+), 69 deletions(-) create mode 100644 src/utils/caption_relevance.py create mode 100644 test_caption_relevance.py diff --git a/.gitignore b/.gitignore index 4018537..3450029 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ dataset/train/labels app_logs.txt .aider* *.tex -runs/ \ No newline at end of file +runs/ +*.md \ No newline at end of file diff --git a/dataset/data.yaml b/dataset/data.yaml index 17f9731..42ad750 100644 --- a/dataset/data.yaml +++ b/dataset/data.yaml @@ -2,4 +2,4 @@ path: C:\Users\Brandon Shen\Documents\SearchVision\dataset train: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images val: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images names: - 0: soccer ball + 0: coffee cup diff --git a/dataset/train/labels.cache b/dataset/train/labels.cache index 721bc7d34f9f402f5eca30e7ddfc5ea246c625c0..6d84a213dc3f24e19ff790fd43d940df52880e35 100644 GIT binary patch literal 7932 zcmb`Md3;mF7ROUsmWBdV77+yks4a`7I~3`Rf`|}X>abX`71HFksdk%_DYQTp6#_*T z5kcGq5f>1_4N(yV6=ZQk+!a(nMHEyPpA~rL%x!LNKJU+$^5M*K&di+eZ*K2QnkD(; z^9pW^aZGVcPx2W7FC-0dCCwOQ<&+LH%30+zy)wq_m@Qij zk13I{b7Y*5EtB0beVc=Ps=mf>%2tIf91h2r!LoJZ+&Qw1knzdx*uog>n<1U~`T3*I z;g2?Mu4LPi#$uV!AloGuw!)AKPl-_xknM|QhvWvCm|WNzJIgC4m6lf+^1Q-4Yz-e$ zG(KQJplBGXbZ>Xh4$4l2vDoFsG&=hoeofrPkGK~I*%fi!E!;6d zd76k zu&1tms$X7S+aObt{n1rEO)Edp->i44-_kja=+lMFAbR{4l$k_N={RJTlG!DGOQ}H= zJ6OmZGZvK^La~YC&Wz>`$4u+Lyh6#LC4Lkm*WWMS-+!)ufS<4!Vddk*(R!fL%qNN? z*E09=pu957brt10OvvFD*VSRJ5lUWT;~Gh^*9tkx;u@{Fw))Sq`o+ckaGQEIu&I_w zx02UsuDl4Y{HSx8`!z;$EepzmFxT~zs}ORm#WgO>HD1XHHm(~ewou3$Ev|{0D`{3X zFHN1uk7u7|NA~!cyh+I-%~c%1b#v6Y()DvVDa187%;lk6ON8`VT)r@up=7CztBhj( zLY7-xw`i_*xL)I!e#l|$1;_8H~H{1Odm=QtH5Osn~O>kR?;Px=VOv>yIA@8&Z z?rP3#mM#C;6crS5j+Ou1=BP%U|GbF&?~Z!@Svvpuq5Kzw^Iu4rEfVq`%1oami)jbK zN5j2JF0ql^NAXLATxQ1OPTWuNuD5Dia_3hk*yDp8yrFI)H!@rC4cDJy+59YftqAjz&DN|BYu5XN@{us>qm=t&LOxEp$I{cbf^tuww|VGCu2k{~8}BL_ z^Q4fg%`wRP6pcyUcyuO9+|ZZ%f8NMq=I&)(*LUT~!w)fcpKkn<9XId^4?WKoP8r7= z>({aQCss1~w35$g{%0fj*F>Fvkmg?-+Uw`S{OgG2c_G(Zd%YoYuU}B|MH~A|6#uf2 zuULD%QSbGFOOyGWIWCs4woh)SxMR8d>Kb!hNd^45t);mu$})20t4eOttgl6|z8-bf z!J73A&59@7<}mAf z6Y?{Q>vPSOH|j9!GNvWZKh}b;Zs1&gq2!mE>#GQ^uZb&~S6uUe`v$S}#rE4U)_0WZ zejyK7Sl@@Seo*p98`eRJJtX8|3+pG1HF?pE)+LKBT=Hinf6-V+BCvjq_I9P47pq7FjqRb{a@=P?}+^W6ZQPf z3)MfN{F}o0|4W&j5%R2+|2Z==aNsAXFapG26JwzTmI7kYhFBnuTpAi$VwH%X74_L} zztEaSMEL@ouAet(Lu;8^1MyT3UGKq3Elj#CmI4ychA^QWxy+>7W7S65fjl@U5laE* zp|wcQCzm;>BUUwOC+f44cBT>0Ce>Fg=t64|>9!zTKy{0>E447`3$YZ?4J{>2??t$X zT%lJZT#PLR-O<{Rdr%0D?un%U7g}?4To8JZi(Zz0uZrbgcXslEN zv|5KtDA2CMrPvX59rQ&FE~B-Ybx5MRr9(2ckfbk`0{Wq~cDz5i%*`8sRU7{0-%g zOnL}bZKPL_2L}zsQa~V16%Ve(UInAj+R#T+D9+->Qowa+%~|jn&m$Mkk~*yAIks^3 zF81W?4*d4Joczuq@qCPHH+yQz&ogS`_@a5|*w~-SdD5GVLq3jFFb1vGtAHl5 z>vcVLL|rd^0R%y7HS0B&>Xu&PsD-5Cu@o=?Es^SX!42fH-UakWG6ncoLAy>9DFjE~ zgr$HYwB~5kshC{!E^xIw%|05@ns*=i7jv#?VsJBh6ih;^b(l$-um+F=dK58L}fu(>_v_wLmfuM|BNMd~q0zdXBC`W61W^SP{oZ(h11yrClXFx3~ z$%QkxdK9$hyyg&VINX&d9zDPYEllJicOPV<26o}U_218+3L_O%qt)8f&_KI35<8-< zjlN(4&|1yf1gLIlqo{>swO9(6g4WVzD!C%Gsly%x^=NI{OrtQIVLFxqW}vmSX&@JA zGcPxezq+&)U)L_4J4eU!Z6c9vfNi&CB5z@@X9tDkPZPLu6Fr49TECtL*Ye}0F%YzhxqaVUjz{6tz z_`-OC8kP>LsD&g?VkuxXT1$tg$Yts9G`18xgVv_QvlN1(*I+4NEm}*5=g37m+_z;7 zTN~SnPk*cjA9paB!#eaRcpk0RVLb)fb=ZI%(bmB{m|vi!I=VOIG1TEjYFIkFL@gwF z8A}1Lpd}LhhTTXmBniJ^U&S5;o6y>{c#XnvhS#wa@CI6Q27D$qlM82X^_D$(!RitG z#Urn?4h{R+s5cI>lm36$duP17sdqhxH!)Jd7PMNMw`ibUo42te+S-_h@K#!iv>A=s zY@>#y%{$aWvUjl*@E%$s)3=uGNJ>_YUh{H$Ugw$4n}*)Z=dawzc5P1JO*`*qd1)Oue1Ibr ze27--^$|^E*Xv{Kh_+tlVf_g$)zSTudtDIrQr(j4Q)(gTJ}d=%hL%9-6E%EJF7syk z1y*h9d`TW0^c9u@zD8>fLUq0&mpSNLtZLBjsLu|%pGHLacrj1j1GE%@{vCwxscwP( zKrPJqM=S*#L~A`4hsgCm&&6TvRqzv98~M)^inIKJrGO)7t>@xba?x|~$cb0jl9=PH zP3;}5ZM=)O+>*<73O7%mx|wZDSjq-wbmAL7JfF992H19YPYy?Mq=MhjYPo)=iNbOj z`dd1`C*X%;=qbDs|D@_0l;+DaGCZk1W3bQXNzcj1@%l0br)HLVy}s1!tgOsI>7Jad zY;RgdT2^{iW=4*YR^m-f_4v|qvQyIxI8HOscdlRr@JrzU{DB?G?k?^&xc3fs4+Vcx v1f43CfwBPng{@*Z5gIiGzn8_YUEw5k&=;#|DXA$La0;8X9tEedEJ^tc zquIp@f(7gp0Rbx(uz=W5-_!SaSl$CmJiob{*}H$gn|wZJ$~kA}_dPSacamLLIlf}- z4G;xd(yPe|H8wEIXB@ z1ZC=M*;&XVvdU8i#e>3jOOur@)BK%$KHr!U*`;mHY}r-FZdv82{uG{@A=4`>D@V8U zk22+^ExQNXCdnSHvS*e*jYH}JL8~q*kDMfr%4(Ipvix0mvbMgurnb(Kz5NwDDjQlg zK59|4Y6Mp)+}MyYE@CxQjcg1x*INyCv}&9cpipE&ZM3$rp(-4(15wMa;)>QbRMpl8 zCRmHt@ebRh>g zt8|XZfswv#Wxba4{mc0FN}u`B+m`vruOU;OVav>*zbk}WlNGhhnj{DLyYb9atEH*2 zmJ1=ztn{bzL_;Ih2kL6CvBGlj9DjG7!MfFa%<2QS9O8$n=`B}B8y4m|Io)q z&p%C54zuO(V1$ctPUOtW$ly7VArW91z^Y_WuMSVHb9zkX+K4lAZi4GPXdNMB8MOLv zRz^arPo8hfQ7)zn5O<-F7ddgf+oL%y!cpal;Nm2L3eO3govBm=8)I@zf?zC^T_U8= z1moh$E_LZ2kEjVkUZ(Z?<58C@{S%Y)U*Wm_Jaw|FI`mIU=)V%ms)f7?%GA3xfE{p7 zR?wCq7eyGcmXI}0Ebqc(#P%AuAkBR6jDtr1_Dju?8~J2KY+0+grX+Dq^_(kTan*Hj z)hDWDl|J4)=Z0!i#%wuTam`8MYV(|{ zP;uSDTstaq_+33W0d_0I&J%LJ2D>c*cDpSXxWMi}>_Qwa54;Np4^u@4FP zu;zM1ajmd@=JM7QbMyQ(bIP7HQ$A|T6^iSzB(BFj=PFiQPjqlSnc#W~TAvnjCA6xK z=vAy0AHC1Gh*l$NjgZeeQT);HTs-P|rT>K_{cAnfU!wG{WBt#S{}_`mCiK4qWiJc) zir$x36Z`U-Enjz0tViqyAvfxM*`)Sm^=IwIz_dfgyaOkhu?3?{`GzewE3Pd`Tw6Wo zDpg!>Dz4=*xh=u99a?t?`Ig?7x8oIm$3^rmqTUnoeZ4O`<53?d{U0Xj{|NfMyf-?} zosU_*x-ITX$lncNp9uLWgr(xi{~1c%gYO;sM}BV0JuaBN$oWFZeNGM^vM-U7^YolG z#@Z1*%-yYPjlA2BH(yPeXzm-GZq8qM(C9StWYg!HY3}>&4B1aegBMa%#SYOpAhvgA^)wn??622KT7|>B>g{o zuHU)${G#+nWAfL8{@pv7%cG#u=cSQXmnWAg3{PRE-pjV1+I?!5Ow8|P2Jcg7%3egoZX#U zK|RxTVo)NlFH_-W^B1|7*Bn~q|u>d*^;u1h%e<_XX1pl(Za zG}b~L4#lVs`nnFsV8kTHax3UKwz>|-!=>wR0*`Duk*%u^Cn1EhPv%xoU$(jqr@)0e z^p~rQhx+$7^XJ}gJUd{YLH*cc)2VD#9WoH;uES|O;dvd@t&93&t#}d&q!Mlgm9o{OL*a6ghH+Ps z4#%9E^c-Y(O{y+tWMD0jE{u_hz9u~vBNlxgw}M8ng=qC2FM~^;?U6jP>3p^>Y6&xNx@j>^yAro10?JdUu$)qRX)+jb@Kc1_ND)u^2JRCEN-UZ1t9pgUi{yOS$Xf9}f>FP2g70Wo-49o?rgE|p&ZqzG~;q{iQJ0DeHEo^x-Mw8Ijq*r3ZqN}+TbQN2O?u55-050A%{98!< zP>^SA3bA!zhY`lFU~wy`hOP4oyzP_W;#bJX3a6UGYxEZW` zW#KY&>EW>^ML5!?TDIyKOhKZ%HdA@R^V+C;Bh_Io)Mjst>e1J=X~2ld8o3qJ#1>@g z?I7XOZwKO$O;NV4y4VQe>}GBSO=Ih1b6u{63vY*M>whxV+?Zz0$Qp0Hkrpy(I(uwt zVXNwJ4FcVDxRxhS2QTlR&J%A2mckWuLyTI{*JZd4BL=yiTR}6~>N3oNOPApW9@%sw zTUQxwLI`Ky%&njpTV00PaG?x~jo*!3eR`V*ey=vKYnouv9QN4M##WW#76iJ>FqbF1 zE`#&HyA?~J3>#xK4}D#R`4}60nwN1aXgOOAdLLYl=KHzp z0(}4;PI{1AK@YLjpbx|4Bt61i1^OuF+@LFv;W4Q5xO)ss#X%oOUxPk@5o>;uTR~5; zg=Y2sei|-)uvhZPrd4cR#LpmvvsZH~XboE@o458^xNxwyZErUg?(1e&%^PA4TGZ2| z=h$P@^K4ZaUO=F`3~PD9>oPb`yLDKK%TUpaKiXf!K-b|VjF{wQZUw!$w%Qfvw)~jc_?ho4Bh;-@u%kbTcwM-f-t(w*^aK!?(w1EBYGrO^jIc zHf{xNXRASXz~yLui@Pq+x8dQWceoYwE?W)y9$Ze+``lHaJ2B@5{QwyrgE~*H53y7n z^ds~&=*Jkb=3U$h+Ravjegc=H`BUz?KtF?rlRoEG&>prLbT3>^(ihxSp!+bF0JYTj zK9N8)LSM4Se=c7cii!dS`C+TLI8;1&?IDnE*Uef- diff --git a/src/download_images.py b/src/download_images.py index 398dc20..abe4660 100644 --- a/src/download_images.py +++ b/src/download_images.py @@ -8,9 +8,13 @@ def download_images(image_urls, download_path="dataset/train/images"): """ Downloads images from a list of URLs and saves them to the specified directory. + Maintains index alignment by returning tuples of (original_index, file_path) + so that ranking algorithms can correctly map back to original URLs. + :param image_urls: List of image URLs to download. :param download_path: Directory to save downloaded images. - :return: List of file paths for successfully downloaded images. + :return: List of tuples (original_index, file_path) for successfully downloaded images, + preserving which position in the input list each downloaded image came from. """ print("Starting image download...") # Debugging statement @@ -18,22 +22,24 @@ def download_images(image_urls, download_path="dataset/train/images"): if not os.path.exists(download_path): os.makedirs(download_path) - # List to hold paths of successfully downloaded images + # List to hold (original_index, file_path) tuples + # This preserves alignment between downloaded images and input URLs downloaded_paths = [] # Iterate over the image URLs and download each image for idx, url in enumerate(image_urls): - print(f"Attempting to download: {url}") # Debugging statement + print(f"Attempting to download ({idx}/{len(image_urls)}): {url}") try: - response = requests.get(url) + response = requests.get(url, timeout=10) if response.status_code == 200: file_path = os.path.join(download_path, f"image_{idx}.jpg") with open(file_path, "wb") as f: f.write(response.content) print(f"Downloaded: {file_path}") - downloaded_paths.append(file_path) # Add path to list + # Store both the original index and the file path + downloaded_paths.append((idx, file_path)) else: - print(f"Failed to download {url}") + print(f"Failed to download {url}: status {response.status_code}") except Exception as e: print(f"Error downloading {url}: {e}") @@ -42,3 +48,4 @@ def download_images(image_urls, download_path="dataset/train/images"): print("No images were downloaded.") return downloaded_paths + diff --git a/src/main.py b/src/main.py index 658e035..6e4214f 100644 --- a/src/main.py +++ b/src/main.py @@ -139,21 +139,34 @@ async def search( os.makedirs(temp_download_path, exist_ok=True) try: - image_paths = download_images(images_subset, temp_download_path) + # Extract URLs from metadata if using new format + if images_subset and isinstance(images_subset[0], dict): + urls_to_download = [r['url'] for r in images_subset] + else: + urls_to_download = images_subset - # Select balanced images (70% relevance, 30% dissimilarity) + image_paths = download_images(urls_to_download, temp_download_path) + + # Select balanced images (60% popularity, 25% caption, 15% dissimilarity) selected_images = select_balanced_images( images_subset, image_paths, + query=query, num_images=min(9, len(images_subset)), - relevance_weight=0.7 + popularity_weight=0.6, + caption_weight=0.25, + dissimilarity_weight=0.15 ) logger.info( f"Selected {len(selected_images)} balanced images for query: {query} (page {page})") except Exception as e: logger.warning( f"Balanced selection failed, falling back to first 9 images: {e}") - selected_images = images_subset[:9] + # Extract URLs from metadata if needed + if images_subset and isinstance(images_subset[0], dict): + selected_images = [r['url'] for r in images_subset[:9]] + else: + selected_images = images_subset[:9] finally: # Clean up temporary downloads if os.path.exists(temp_download_path): diff --git a/src/scrape_similar.py b/src/scrape_similar.py index 7110599..05a532c 100644 --- a/src/scrape_similar.py +++ b/src/scrape_similar.py @@ -15,6 +15,9 @@ def scrape_similar_images( Scrape similar images for training augmentation. Uses multiple query variations to find diverse training images. Falls back gracefully if search fails. + + Returns: + List of image URLs (strips metadata for compatibility) """ similar_images = [] @@ -41,16 +44,18 @@ def scrape_similar_images( try: logger.debug(f"Attempting search with query: {query}") - images = search_images( + results = search_images( query, api_key, search_engine_id, num_results=num_results_per_image ) - if images: - logger.info(f"Got {len(images)} images from query: {query}") - similar_images.extend(images) + if results: + # Extract URLs from metadata dicts + urls = [r['url'] for r in results] + logger.info(f"Got {len(urls)} images from query: {query}") + similar_images.extend(urls) else: logger.debug(f"No images from query: {query}") diff --git a/src/search_images.py b/src/search_images.py index f992140..9a991c0 100644 --- a/src/search_images.py +++ b/src/search_images.py @@ -11,6 +11,13 @@ def search_images(query, api_key, search_engine_id, num_results=10): """ Search for images using Google Custom Search API. Falls back to Bing Images if Google fails (no API key needed). + + Returns: + List of dicts containing image metadata: { + 'url': image_url, + 'title': caption/title, + 'snippet': description + } """ images = [] google_error = None @@ -53,7 +60,10 @@ def _search_google_custom_search( api_key, search_engine_id, num_results=10): - """Search using Google Custom Search API""" + """ + Search using Google Custom Search API. + Extracts image URLs, titles, and snippets for relevancy ranking. + """ images = [] results_per_page = 10 start_index = 1 @@ -78,7 +88,11 @@ def _search_google_custom_search( break for item in data['items']: - images.append(item['link']) + images.append({ + 'url': item['link'], + 'title': item.get('title', ''), + 'snippet': item.get('snippet', '') + }) start_index += results_per_page @@ -93,8 +107,7 @@ def _search_google_custom_search( def _search_bing_images(query, num_results=10): """ Search using Bing Images (free, no API key required) - Scrapes image URLs from Bing image search with retry logic. - Strips problematic filter syntax before searching. + Scrapes image URLs and captions from Bing image search. """ images = [] max_retries = 3 @@ -137,10 +150,9 @@ def _search_bing_images(query, num_results=10): raise Exception( f"Bing Images returned status {response.status_code}") - # Extract image URLs from the HTML response using regex - # Bing stores lazy-loaded images in data-src attributes - # These are Bing image proxy URLs (tse1.mm.bing.net, etc.) - image_pattern = r']+data-src="([^"]+)"' + # Extract image URLs and captions from HTML + # Bing stores images in img tags with data-src attributes + image_pattern = r']+data-src="([^"]+)"[^>]+alt="([^"]*)"' matches = re.findall(image_pattern, response.text) if not matches: @@ -152,13 +164,17 @@ def _search_bing_images(query, num_results=10): continue raise Exception("No images found on Bing Images after retries") - # Process URLs and decode HTML entities - for url in matches: + # Process URLs and captions + for url, caption in matches: if url.startswith('http') and len(images) < num_results: # Decode HTML entities (e.g., & to &) url = url.replace('&', '&') url = url.replace('\\/', '/') - images.append(url) + images.append({ + 'url': url, + 'title': caption, + 'snippet': caption + }) if not images: logger.debug( diff --git a/src/select_balanced_images.py b/src/select_balanced_images.py index 738e370..7ebc978 100644 --- a/src/select_balanced_images.py +++ b/src/select_balanced_images.py @@ -1,14 +1,18 @@ """ -Balanced image selection combining search relevance with visual dissimilarity. +Balanced image selection combining search relevance, caption relevance, +and visual dissimilarity. Strategy: -1. Images are initially ranked by search engine (relevance score based on position) -2. Extract visual features from all images using ResNet50 -3. Select images that balance: - - High relevance (early in search results) +1. Images are initially ranked by search engine (position-based relevance) +2. Compute caption relevance scores based on keyword matching +3. Extract visual features from all images using ResNet50 +4. Select images that balance: + - High relevance (early in search results + relevant captions) + - Caption relevance (semantic match with query) - Visual dissimilarity (diverse appearance) -This ensures training data is both relevant to the search query and diverse in appearance. +This ensures training data is relevant to the search query both semantically +(via captions) and visually (via ResNet50 features), while maintaining diversity. """ import numpy as np @@ -17,6 +21,7 @@ from torchvision import models, transforms import torch import logging +from src.utils.caption_relevance import compute_batch_caption_relevance logger = logging.getLogger(__name__) @@ -60,45 +65,65 @@ def extract_features(image_path): def select_balanced_images( - image_urls, + image_results, image_paths, + query="", num_images=9, - relevance_weight=0.7): + popularity_weight=0.6, + caption_weight=0.25, + dissimilarity_weight=0.15): """ - Selects images that balance search relevance with visual dissimilarity. + Selects images balancing search popularity, caption relevance, + and visual dissimilarity. + + Strategy: Prioritize quality (popularity/search ranking) over diversity. + The search engine has already ranked results by relevance, so we heavily weight + position-based popularity. Caption relevance adds semantic understanding, + and visual dissimilarity is used as a tiebreaker for minor diversity. Args: - image_urls: List of image URLs (in order of relevance from search engine) - image_paths: List of local file paths corresponding to image_urls + image_results: List of result dicts with 'url', 'title', 'snippet' keys + OR list of URLs (for backward compatibility) + image_paths: List of tuples (original_index, file_path) from download_images() + Each tuple preserves which position in image_results this came from + query: Original search query (required for caption relevance scoring) num_images: Number of images to select (default 9) - relevance_weight: Weight for relevance score (0-1). Dissimilarity weight = 1 - relevance_weight - Default 0.7 means 70% relevance, 30% dissimilarity + popularity_weight: Weight for search result position (default 0.6) + caption_weight: Weight for caption relevance (default 0.25) + dissimilarity_weight: Weight for visual dissimilarity (default 0.15) Returns: - List of selected image URLs, balanced between relevance and dissimilarity + List of selected image URLs, prioritizing popularity and quality """ - if len(image_urls) < num_images: + if len(image_results) < num_images: logger.warning( - f"Requested {num_images} images but only {len(image_urls)} available") - return image_urls + f"Requested {num_images} images but only {len(image_results)} available") + return _extract_urls(image_results) + + # Normalize weights to sum to 1.0 + total_weight = popularity_weight + caption_weight + dissimilarity_weight + popularity_weight = popularity_weight / total_weight + caption_weight = caption_weight / total_weight + dissimilarity_weight = dissimilarity_weight / total_weight - # Extract features from all images + # Extract features from downloaded images, preserving original indices features_list = [] - valid_indices = [] + original_indices = [] # Track which image_results index each feature came from - for idx, path in enumerate(image_paths): - feature = extract_features(path) + for original_idx, file_path in image_paths: + feature = extract_features(file_path) if feature is not None: features_list.append(feature) - valid_indices.append(idx) + original_indices.append(original_idx) else: - logger.debug(f"Skipping image {idx} - could not extract features") + logger.debug(f"Skipping image from index {original_idx} - could not extract features") if len(features_list) < num_images: logger.warning( - f"Only {len(features_list)} images have valid features, returning top {min(len(image_urls), num_images)}") - return image_urls[:min(len(image_urls), num_images)] + f"Only {len(features_list)} images have valid features, returning top {min(len(image_results), num_images)}") + urls = _extract_urls(image_results) + return urls[:min(len(urls), num_images)] features = np.array(features_list) @@ -106,19 +131,23 @@ def select_balanced_images( # Compute cosine distance matrix between image features distance_matrix = cosine_distances(features) - # Calculate dissimilarity score for each image (sum of distances to all - # others) + # Calculate dissimilarity score for each image (sum of distances to all) dissimilarity_scores = np.sum(distance_matrix, axis=1) - # Normalize both scores to 0-1 range - dissimilarity_weight = 1 - relevance_weight + # Popularity score: based on ORIGINAL position in search results + # Map original_index (0 to n-1) to popularity (1.0 to 0.0) + # Higher original_index = lower popularity, lower score + max_original_idx = max(original_indices) if original_indices else 0 + popularity_scores = 1.0 - np.array(original_indices) / max(1, max_original_idx) - # Relevance score: images earlier in search results have higher relevance - # Map position (0 to len-1) to relevance (1.0 to 0.0) - relevance_scores = 1.0 - \ - np.arange(len(features_list)) / max(1, len(features_list) - 1) + # Caption relevance scores (if query provided) + caption_scores = np.zeros(len(features_list)) + if query and _is_metadata_format(image_results): + caption_scores_list = compute_batch_caption_relevance( + [image_results[i] for i in original_indices], query) + caption_scores = np.array(caption_scores_list) - # Normalize dissimilarity scores to 0-1 range + # Normalize all scores to 0-1 range if dissimilarity_scores.max() > dissimilarity_scores.min(): dissimilarity_scores_norm = ( dissimilarity_scores - dissimilarity_scores.min()) / ( @@ -126,21 +155,57 @@ def select_balanced_images( else: dissimilarity_scores_norm = dissimilarity_scores - # Combined score: weighted combination of relevance and dissimilarity - combined_scores = (relevance_weight * relevance_scores + - dissimilarity_weight * dissimilarity_scores_norm) + if caption_scores.max() > caption_scores.min(): + caption_scores_norm = ( + caption_scores - caption_scores.min()) / ( + caption_scores.max() - caption_scores.min()) + else: + caption_scores_norm = caption_scores + + # Combined score: heavily weighted toward popularity (search engine ranking), + # with caption relevance for semantic understanding, and dissimilarity as + # a tiebreaker for minor diversity + combined_scores = ( + popularity_weight * popularity_scores + + caption_weight * caption_scores_norm + + dissimilarity_weight * dissimilarity_scores_norm + ) # Select top num_images indices by combined score selected_feature_indices = np.argsort(combined_scores)[-num_images:][::-1] - # Map back to original image indices - selected_indices = [valid_indices[idx] for idx in selected_feature_indices] + # Map back to original image_results indices + selected_original_indices = [original_indices[idx] for idx in selected_feature_indices] # Return selected image URLs - selected_images = [image_urls[idx] for idx in selected_indices] + urls = _extract_urls(image_results) + selected_images = [urls[idx] for idx in selected_original_indices] logger.info( - f"Selected {len(selected_images)} images using balanced strategy " - f"(relevance_weight={relevance_weight}, dissimilarity_weight={dissimilarity_weight})") + f"Selected {len(selected_images)} images using quality-first strategy " + f"(popularity={popularity_weight:.2f}, caption={caption_weight:.2f}, " + f"dissimilarity={dissimilarity_weight:.2f})") return selected_images + + +def _extract_urls(image_results): + """ + Extract URLs from image results. + + Handles both metadata format (list of dicts) and legacy format (list of strings). + """ + if not image_results: + return [] + + if isinstance(image_results[0], dict): + return [r['url'] for r in image_results] + else: + return image_results + + +def _is_metadata_format(image_results): + """Check if image_results are in metadata format (dicts) or legacy format (strings).""" + if not image_results: + return False + return isinstance(image_results[0], dict) diff --git a/src/utils/caption_relevance.py b/src/utils/caption_relevance.py new file mode 100644 index 0000000..e1f9641 --- /dev/null +++ b/src/utils/caption_relevance.py @@ -0,0 +1,135 @@ +""" +Caption relevance scoring for image search results. + +This module provides functions to compute semantic and keyword-based relevance +scores for image captions against the original search query. +""" + +import logging +import re +from collections import Counter + +logger = logging.getLogger(__name__) + + +def compute_caption_relevance(caption, query): + """ + Compute caption relevance score to the search query (0-1). + + This uses a multi-factor approach: + 1. Keyword overlap: Percentage of query words present in caption + 2. Keyword position: Earlier occurrences weighted more heavily + 3. Length normalization: Balanced against caption length + + Args: + caption: Image caption or title string + query: Original search query string + + Returns: + float: Relevance score between 0.0 and 1.0 + """ + if not caption or not query: + return 0.0 + + # Normalize text: lowercase, remove punctuation + caption_clean = _normalize_text(caption) + query_clean = _normalize_text(query) + + # Split into words + caption_words = set(caption_clean.split()) + query_words = set(query_clean.split()) + + # Remove common stop words to focus on meaningful terms + query_words = query_words - _get_stop_words() + caption_words = caption_words - _get_stop_words() + + if not query_words: + return 0.0 + + # Calculate keyword overlap + matching_words = query_words & caption_words + overlap_ratio = len(matching_words) / len(query_words) + + # Calculate position-weighted score (earlier matches count more) + position_score = _compute_position_score( + caption_clean, query_clean, matching_words) + + # Length normalization: penalize very long captions with few matches + length_penalty = min(1.0, 100 / len(caption_words)) if caption_words else 0.0 + + # Combined score: average with length penalty + relevance_score = (overlap_ratio * 0.5 + + position_score * 0.3 + + length_penalty * 0.2) + + return min(1.0, max(0.0, relevance_score)) + + +def _normalize_text(text): + """Normalize text for comparison (lowercase, remove punctuation).""" + text = text.lower() + # Remove punctuation and extra whitespace + text = re.sub(r'[^\w\s]', ' ', text) + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +def _get_stop_words(): + """Return a set of common English stop words.""" + return { + 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', + 'of', 'with', 'by', 'from', 'up', 'about', 'into', 'through', 'as', + 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', + 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', + 'may', 'might', 'must', 'can', 'this', 'that', 'these', 'those', + 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', + 'who', 'when', 'where', 'why', 'how' + } + + +def _compute_position_score(caption, query, matching_words): + """ + Compute position-weighted score. + + Matching words that appear early in the caption score higher. + """ + if not matching_words: + return 0.0 + + caption_words = caption.split() + position_scores = [] + + for word in matching_words: + # Find first occurrence of the word + for i, caption_word in enumerate(caption_words): + if caption_word == word: + # Earlier words (lower indices) get higher scores + # Position 0 = 1.0, position increases = score decreases + position_score = 1.0 / (1.0 + i / 10.0) + position_scores.append(position_score) + break + + return sum(position_scores) / len( + matching_words) if position_scores else 0.0 + + +def compute_batch_caption_relevance(results, query): + """ + Compute caption relevance scores for a batch of search results. + + Args: + results: List of result dicts with 'title' and 'snippet' keys + query: Search query string + + Returns: + List of relevance scores corresponding to input results + """ + scores = [] + + for result in results: + # Combine title and snippet for more context + caption = f"{result.get('title', '')} {result.get('snippet', '')}" + score = compute_caption_relevance(caption, query) + scores.append(score) + + return scores diff --git a/test_caption_relevance.py b/test_caption_relevance.py new file mode 100644 index 0000000..9cd2b67 --- /dev/null +++ b/test_caption_relevance.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify caption relevance scoring functionality. +""" + +import sys +sys.path.insert(0, '/Users/Brandon Shen/Documents/SearchVision') + +from src.utils.caption_relevance import ( + compute_caption_relevance, + compute_batch_caption_relevance +) + +# Test cases for caption relevance +test_cases = [ + { + "query": "dog", + "captions": [ + ("A beautiful golden retriever playing in the park", 0.8), # High relevance + ("Golden retriever on beach", 0.9), # Very high relevance + ("Cat sitting on a chair", 0.0), # No relevance + ("Dog training classes", 0.7), # Relevant but indirect + ("Fluffy dog breed guide", 0.8), # Relevant + ] + }, + { + "query": "cat sitting", + "captions": [ + ("Orange cat sitting on windowsill", 0.9), # Very high relevance + ("How to teach your cat to sit", 0.8), # Relevant + ("Dogs and their behavior", 0.0), # No relevance + ("Sitting meditation techniques", 0.2), # Partial relevance + ("Cat toys for active cats", 0.6), # Somewhat relevant + ] + }, +] + +print("=" * 70) +print("Caption Relevance Scoring Test") +print("=" * 70) + +for test_case in test_cases: + query = test_case["query"] + captions = test_case["captions"] + + print(f"\nQuery: '{query}'") + print("-" * 70) + + for caption, expected_approx in captions: + score = compute_caption_relevance(caption, query) + status = "✓" if abs(score - expected_approx) < 0.15 else "~" + print(f"{status} Caption: {caption[:50]:50} | Score: {score:.2f}") + +print("\n" + "=" * 70) +print("Batch Caption Relevance Test") +print("=" * 70) + +batch_results = [ + {"url": "http://example.com/1", "title": "Beautiful Golden Retriever", "snippet": "A friendly golden retriever dog"}, + {"url": "http://example.com/2", "title": "Cat Sleeping", "snippet": "Orange cat taking a nap"}, + {"url": "http://example.com/3", "title": "Dog Training Guide", "snippet": "Learn how to train your dog"}, +] + +query = "dog" +scores = compute_batch_caption_relevance(batch_results, query) + +print(f"\nQuery: '{query}'") +print("-" * 70) +for result, score in zip(batch_results, scores): + print(f"Title: {result['title']:40} | Score: {score:.2f}") + +print("\n" + "=" * 70) +print("All tests completed successfully!") +print("=" * 70)