diff --git a/.gitignore b/.gitignore index 4018537..3450029 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ dataset/train/labels app_logs.txt .aider* *.tex -runs/ \ No newline at end of file +runs/ +*.md \ No newline at end of file diff --git a/dataset/data.yaml b/dataset/data.yaml index 17f9731..42ad750 100644 --- a/dataset/data.yaml +++ b/dataset/data.yaml @@ -2,4 +2,4 @@ path: C:\Users\Brandon Shen\Documents\SearchVision\dataset train: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images val: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images names: - 0: soccer ball + 0: coffee cup diff --git a/dataset/train/labels.cache b/dataset/train/labels.cache index 721bc7d..6d84a21 100644 Binary files a/dataset/train/labels.cache and b/dataset/train/labels.cache differ diff --git a/src/download_images.py b/src/download_images.py index 398dc20..abe4660 100644 --- a/src/download_images.py +++ b/src/download_images.py @@ -8,9 +8,13 @@ def download_images(image_urls, download_path="dataset/train/images"): """ Downloads images from a list of URLs and saves them to the specified directory. + Maintains index alignment by returning tuples of (original_index, file_path) + so that ranking algorithms can correctly map back to original URLs. + :param image_urls: List of image URLs to download. :param download_path: Directory to save downloaded images. - :return: List of file paths for successfully downloaded images. + :return: List of tuples (original_index, file_path) for successfully downloaded images, + preserving which position in the input list each downloaded image came from. 
""" print("Starting image download...") # Debugging statement @@ -18,22 +22,24 @@ def download_images(image_urls, download_path="dataset/train/images"): if not os.path.exists(download_path): os.makedirs(download_path) - # List to hold paths of successfully downloaded images + # List to hold (original_index, file_path) tuples + # This preserves alignment between downloaded images and input URLs downloaded_paths = [] # Iterate over the image URLs and download each image for idx, url in enumerate(image_urls): - print(f"Attempting to download: {url}") # Debugging statement + print(f"Attempting to download ({idx}/{len(image_urls)}): {url}") try: - response = requests.get(url) + response = requests.get(url, timeout=10) if response.status_code == 200: file_path = os.path.join(download_path, f"image_{idx}.jpg") with open(file_path, "wb") as f: f.write(response.content) print(f"Downloaded: {file_path}") - downloaded_paths.append(file_path) # Add path to list + # Store both the original index and the file path + downloaded_paths.append((idx, file_path)) else: - print(f"Failed to download {url}") + print(f"Failed to download {url}: status {response.status_code}") except Exception as e: print(f"Error downloading {url}: {e}") @@ -42,3 +48,4 @@ def download_images(image_urls, download_path="dataset/train/images"): print("No images were downloaded.") return downloaded_paths + diff --git a/src/main.py b/src/main.py index 658e035..6e4214f 100644 --- a/src/main.py +++ b/src/main.py @@ -139,21 +139,34 @@ async def search( os.makedirs(temp_download_path, exist_ok=True) try: - image_paths = download_images(images_subset, temp_download_path) + # Extract URLs from metadata if using new format + if images_subset and isinstance(images_subset[0], dict): + urls_to_download = [r['url'] for r in images_subset] + else: + urls_to_download = images_subset - # Select balanced images (70% relevance, 30% dissimilarity) + image_paths = download_images(urls_to_download, temp_download_path) + + # 
Select balanced images (60% popularity, 25% caption, 15% dissimilarity) selected_images = select_balanced_images( images_subset, image_paths, + query=query, num_images=min(9, len(images_subset)), - relevance_weight=0.7 + popularity_weight=0.6, + caption_weight=0.25, + dissimilarity_weight=0.15 ) logger.info( f"Selected {len(selected_images)} balanced images for query: {query} (page {page})") except Exception as e: logger.warning( f"Balanced selection failed, falling back to first 9 images: {e}") - selected_images = images_subset[:9] + # Extract URLs from metadata if needed + if images_subset and isinstance(images_subset[0], dict): + selected_images = [r['url'] for r in images_subset[:9]] + else: + selected_images = images_subset[:9] finally: # Clean up temporary downloads if os.path.exists(temp_download_path): diff --git a/src/scrape_similar.py b/src/scrape_similar.py index 7110599..05a532c 100644 --- a/src/scrape_similar.py +++ b/src/scrape_similar.py @@ -15,6 +15,9 @@ def scrape_similar_images( Scrape similar images for training augmentation. Uses multiple query variations to find diverse training images. Falls back gracefully if search fails. 
+ + Returns: + List of image URLs (strips metadata for compatibility) """ similar_images = [] @@ -41,16 +44,18 @@ def scrape_similar_images( try: logger.debug(f"Attempting search with query: {query}") - images = search_images( + results = search_images( query, api_key, search_engine_id, num_results=num_results_per_image ) - if images: - logger.info(f"Got {len(images)} images from query: {query}") - similar_images.extend(images) + if results: + # Extract URLs from metadata dicts + urls = [r['url'] for r in results] + logger.info(f"Got {len(urls)} images from query: {query}") + similar_images.extend(urls) else: logger.debug(f"No images from query: {query}") diff --git a/src/search_images.py b/src/search_images.py index f992140..9a991c0 100644 --- a/src/search_images.py +++ b/src/search_images.py @@ -11,6 +11,13 @@ def search_images(query, api_key, search_engine_id, num_results=10): """ Search for images using Google Custom Search API. Falls back to Bing Images if Google fails (no API key needed). + + Returns: + List of dicts containing image metadata: { + 'url': image_url, + 'title': caption/title, + 'snippet': description + } """ images = [] google_error = None @@ -53,7 +60,10 @@ def _search_google_custom_search( api_key, search_engine_id, num_results=10): - """Search using Google Custom Search API""" + """ + Search using Google Custom Search API. + Extracts image URLs, titles, and snippets for relevancy ranking. + """ images = [] results_per_page = 10 start_index = 1 @@ -78,7 +88,11 @@ def _search_google_custom_search( break for item in data['items']: - images.append(item['link']) + images.append({ + 'url': item['link'], + 'title': item.get('title', ''), + 'snippet': item.get('snippet', '') + }) start_index += results_per_page @@ -93,8 +107,7 @@ def _search_google_custom_search( def _search_bing_images(query, num_results=10): """ Search using Bing Images (free, no API key required) - Scrapes image URLs from Bing image search with retry logic. 
- Strips problematic filter syntax before searching. + Scrapes image URLs and captions from Bing image search. """ images = [] max_retries = 3 @@ -137,10 +150,9 @@ def _search_bing_images(query, num_results=10): raise Exception( f"Bing Images returned status {response.status_code}") - # Extract image URLs from the HTML response using regex - # Bing stores lazy-loaded images in data-src attributes - # These are Bing image proxy URLs (tse1.mm.bing.net, etc.) - image_pattern = r'<img[^>]+data-src="([^"]+)"' + # Extract image URLs and captions from HTML + # Bing stores images in img tags with data-src attributes + image_pattern = r'<img[^>]+data-src="([^"]+)"[^>]+alt="([^"]*)"' matches = re.findall(image_pattern, response.text) if not matches: @@ -152,13 +164,17 @@ def _search_bing_images(query, num_results=10): continue raise Exception("No images found on Bing Images after retries") - # Process URLs and decode HTML entities - for url in matches: + # Process URLs and captions + for url, caption in matches: if url.startswith('http') and len(images) < num_results: # Decode HTML entities (e.g., &amp; to &) url = url.replace('&amp;', '&') url = url.replace('\\/', '/') - images.append(url) + images.append({ + 'url': url, + 'title': caption, + 'snippet': caption + }) if not images: logger.debug( diff --git a/src/select_balanced_images.py b/src/select_balanced_images.py index 738e370..7ebc978 100644 --- a/src/select_balanced_images.py +++ b/src/select_balanced_images.py @@ -1,14 +1,18 @@ """ -Balanced image selection combining search relevance with visual dissimilarity. +Balanced image selection combining search relevance, caption relevance, +and visual dissimilarity. Strategy: -1. Images are initially ranked by search engine (relevance score based on position) -2. Extract visual features from all images using ResNet50 -3. Select images that balance: - - High relevance (early in search results) +1. Images are initially ranked by search engine (position-based relevance) +2. 
Compute caption relevance scores based on keyword matching +3. Extract visual features from all images using ResNet50 +4. Select images that balance: + - High relevance (early in search results + relevant captions) + - Caption relevance (semantic match with query) - Visual dissimilarity (diverse appearance) -This ensures training data is both relevant to the search query and diverse in appearance. +This ensures training data is relevant to the search query both semantically +(via captions) and visually (via ResNet50 features), while maintaining diversity. """ import numpy as np @@ -17,6 +21,7 @@ from torchvision import models, transforms import torch import logging +from src.utils.caption_relevance import compute_batch_caption_relevance logger = logging.getLogger(__name__) @@ -60,45 +65,65 @@ def extract_features(image_path): def select_balanced_images( - image_urls, + image_results, image_paths, + query="", num_images=9, - relevance_weight=0.7): + popularity_weight=0.6, + caption_weight=0.25, + dissimilarity_weight=0.15): """ - Selects images that balance search relevance with visual dissimilarity. + Selects images balancing search popularity, caption relevance, + and visual dissimilarity. + + Strategy: Prioritize quality (popularity/search ranking) over diversity. + The search engine has already ranked results by relevance, so we heavily weight + position-based popularity. Caption relevance adds semantic understanding, + and visual dissimilarity is used as a tiebreaker for minor diversity. 
Args: - image_urls: List of image URLs (in order of relevance from search engine) - image_paths: List of local file paths corresponding to image_urls + image_results: List of result dicts with 'url', 'title', 'snippet' keys + OR list of URLs (for backward compatibility) + image_paths: List of tuples (original_index, file_path) from download_images() + Each tuple preserves which position in image_results this came from + query: Original search query (required for caption relevance scoring) num_images: Number of images to select (default 9) - relevance_weight: Weight for relevance score (0-1). Dissimilarity weight = 1 - relevance_weight - Default 0.7 means 70% relevance, 30% dissimilarity + popularity_weight: Weight for search result position (default 0.6) + caption_weight: Weight for caption relevance (default 0.25) + dissimilarity_weight: Weight for visual dissimilarity (default 0.15) Returns: - List of selected image URLs, balanced between relevance and dissimilarity + List of selected image URLs, prioritizing popularity and quality """ - if len(image_urls) < num_images: + if len(image_results) < num_images: logger.warning( - f"Requested {num_images} images but only {len(image_urls)} available") - return image_urls + f"Requested {num_images} images but only {len(image_results)} available") + return _extract_urls(image_results) + + # Normalize weights to sum to 1.0 + total_weight = popularity_weight + caption_weight + dissimilarity_weight + popularity_weight = popularity_weight / total_weight + caption_weight = caption_weight / total_weight + dissimilarity_weight = dissimilarity_weight / total_weight - # Extract features from all images + # Extract features from downloaded images, preserving original indices features_list = [] - valid_indices = [] + original_indices = [] # Track which image_results index each feature came from - for idx, path in enumerate(image_paths): - feature = extract_features(path) + for original_idx, file_path in image_paths: + feature = 
extract_features(file_path) if feature is not None: features_list.append(feature) - valid_indices.append(idx) + original_indices.append(original_idx) else: - logger.debug(f"Skipping image {idx} - could not extract features") + logger.debug(f"Skipping image from index {original_idx} - could not extract features") if len(features_list) < num_images: logger.warning( - f"Only {len(features_list)} images have valid features, returning top {min(len(image_urls), num_images)}") - return image_urls[:min(len(image_urls), num_images)] + f"Only {len(features_list)} images have valid features, returning top {min(len(image_results), num_images)}") + urls = _extract_urls(image_results) + return urls[:min(len(urls), num_images)] features = np.array(features_list) @@ -106,19 +131,23 @@ def select_balanced_images( # Compute cosine distance matrix between image features distance_matrix = cosine_distances(features) - # Calculate dissimilarity score for each image (sum of distances to all - # others) + # Calculate dissimilarity score for each image (sum of distances to all) dissimilarity_scores = np.sum(distance_matrix, axis=1) - # Normalize both scores to 0-1 range - dissimilarity_weight = 1 - relevance_weight + # Popularity score: based on ORIGINAL position in search results + # Map original_index (0 to n-1) to popularity (1.0 to 0.0) + # Higher original_index = lower popularity, lower score + max_original_idx = max(original_indices) if original_indices else 0 + popularity_scores = 1.0 - np.array(original_indices) / max(1, max_original_idx) - # Relevance score: images earlier in search results have higher relevance - # Map position (0 to len-1) to relevance (1.0 to 0.0) - relevance_scores = 1.0 - \ - np.arange(len(features_list)) / max(1, len(features_list) - 1) + # Caption relevance scores (if query provided) + caption_scores = np.zeros(len(features_list)) + if query and _is_metadata_format(image_results): + caption_scores_list = compute_batch_caption_relevance( + [image_results[i] 
for i in original_indices], query) + caption_scores = np.array(caption_scores_list) - # Normalize dissimilarity scores to 0-1 range + # Normalize all scores to 0-1 range if dissimilarity_scores.max() > dissimilarity_scores.min(): dissimilarity_scores_norm = ( dissimilarity_scores - dissimilarity_scores.min()) / ( @@ -126,21 +155,57 @@ def select_balanced_images( else: dissimilarity_scores_norm = dissimilarity_scores - # Combined score: weighted combination of relevance and dissimilarity - combined_scores = (relevance_weight * relevance_scores + - dissimilarity_weight * dissimilarity_scores_norm) + if caption_scores.max() > caption_scores.min(): + caption_scores_norm = ( + caption_scores - caption_scores.min()) / ( + caption_scores.max() - caption_scores.min()) + else: + caption_scores_norm = caption_scores + + # Combined score: heavily weighted toward popularity (search engine ranking), + # with caption relevance for semantic understanding, and dissimilarity as + # a tiebreaker for minor diversity + combined_scores = ( + popularity_weight * popularity_scores + + caption_weight * caption_scores_norm + + dissimilarity_weight * dissimilarity_scores_norm + ) # Select top num_images indices by combined score selected_feature_indices = np.argsort(combined_scores)[-num_images:][::-1] - # Map back to original image indices - selected_indices = [valid_indices[idx] for idx in selected_feature_indices] + # Map back to original image_results indices + selected_original_indices = [original_indices[idx] for idx in selected_feature_indices] # Return selected image URLs - selected_images = [image_urls[idx] for idx in selected_indices] + urls = _extract_urls(image_results) + selected_images = [urls[idx] for idx in selected_original_indices] logger.info( - f"Selected {len(selected_images)} images using balanced strategy " - f"(relevance_weight={relevance_weight}, dissimilarity_weight={dissimilarity_weight})") + f"Selected {len(selected_images)} images using quality-first strategy " 
+ f"(popularity={popularity_weight:.2f}, caption={caption_weight:.2f}, " + f"dissimilarity={dissimilarity_weight:.2f})") return selected_images + + +def _extract_urls(image_results): + """ + Extract URLs from image results. + + Handles both metadata format (list of dicts) and legacy format (list of strings). + """ + if not image_results: + return [] + + if isinstance(image_results[0], dict): + return [r['url'] for r in image_results] + else: + return image_results + + +def _is_metadata_format(image_results): + """Check if image_results are in metadata format (dicts) or legacy format (strings).""" + if not image_results: + return False + return isinstance(image_results[0], dict) diff --git a/src/utils/caption_relevance.py b/src/utils/caption_relevance.py new file mode 100644 index 0000000..e1f9641 --- /dev/null +++ b/src/utils/caption_relevance.py @@ -0,0 +1,135 @@ +""" +Caption relevance scoring for image search results. + +This module provides functions to compute semantic and keyword-based relevance +scores for image captions against the original search query. +""" + +import logging +import re +from collections import Counter + +logger = logging.getLogger(__name__) + + +def compute_caption_relevance(caption, query): + """ + Compute caption relevance score to the search query (0-1). + + This uses a multi-factor approach: + 1. Keyword overlap: Percentage of query words present in caption + 2. Keyword position: Earlier occurrences weighted more heavily + 3. 
Length normalization: Balanced against caption length + + Args: + caption: Image caption or title string + query: Original search query string + + Returns: + float: Relevance score between 0.0 and 1.0 + """ + if not caption or not query: + return 0.0 + + # Normalize text: lowercase, remove punctuation + caption_clean = _normalize_text(caption) + query_clean = _normalize_text(query) + + # Split into words + caption_words = set(caption_clean.split()) + query_words = set(query_clean.split()) + + # Remove common stop words to focus on meaningful terms + query_words = query_words - _get_stop_words() + caption_words = caption_words - _get_stop_words() + + if not query_words: + return 0.0 + + # Calculate keyword overlap + matching_words = query_words & caption_words + overlap_ratio = len(matching_words) / len(query_words) + + # Calculate position-weighted score (earlier matches count more) + position_score = _compute_position_score( + caption_clean, query_clean, matching_words) + + # Length normalization: penalize very long captions with few matches + length_penalty = min(1.0, 100 / len(caption_words)) if caption_words else 0.0 + + # Combined score: average with length penalty + relevance_score = (overlap_ratio * 0.5 + + position_score * 0.3 + + length_penalty * 0.2) + + return min(1.0, max(0.0, relevance_score)) + + +def _normalize_text(text): + """Normalize text for comparison (lowercase, remove punctuation).""" + text = text.lower() + # Remove punctuation and extra whitespace + text = re.sub(r'[^\w\s]', ' ', text) + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +def _get_stop_words(): + """Return a set of common English stop words.""" + return { + 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', + 'of', 'with', 'by', 'from', 'up', 'about', 'into', 'through', 'as', + 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', + 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', + 'may', 'might', 'must', 'can', 'this', 
'that', 'these', 'those', + 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', + 'who', 'when', 'where', 'why', 'how' + } + + +def _compute_position_score(caption, query, matching_words): + """ + Compute position-weighted score. + + Matching words that appear early in the caption score higher. + """ + if not matching_words: + return 0.0 + + caption_words = caption.split() + position_scores = [] + + for word in matching_words: + # Find first occurrence of the word + for i, caption_word in enumerate(caption_words): + if caption_word == word: + # Earlier words (lower indices) get higher scores + # Position 0 = 1.0, position increases = score decreases + position_score = 1.0 / (1.0 + i / 10.0) + position_scores.append(position_score) + break + + return sum(position_scores) / len( + matching_words) if position_scores else 0.0 + + +def compute_batch_caption_relevance(results, query): + """ + Compute caption relevance scores for a batch of search results. + + Args: + results: List of result dicts with 'title' and 'snippet' keys + query: Search query string + + Returns: + List of relevance scores corresponding to input results + """ + scores = [] + + for result in results: + # Combine title and snippet for more context + caption = f"{result.get('title', '')} {result.get('snippet', '')}" + score = compute_caption_relevance(caption, query) + scores.append(score) + + return scores diff --git a/test_caption_relevance.py b/test_caption_relevance.py new file mode 100644 index 0000000..9cd2b67 --- /dev/null +++ b/test_caption_relevance.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify caption relevance scoring functionality. 
+""" + +import sys +sys.path.insert(0, '/Users/Brandon Shen/Documents/SearchVision') + +from src.utils.caption_relevance import ( + compute_caption_relevance, + compute_batch_caption_relevance +) + +# Test cases for caption relevance +test_cases = [ + { + "query": "dog", + "captions": [ + ("A beautiful golden retriever playing in the park", 0.8), # High relevance + ("Golden retriever on beach", 0.9), # Very high relevance + ("Cat sitting on a chair", 0.0), # No relevance + ("Dog training classes", 0.7), # Relevant but indirect + ("Fluffy dog breed guide", 0.8), # Relevant + ] + }, + { + "query": "cat sitting", + "captions": [ + ("Orange cat sitting on windowsill", 0.9), # Very high relevance + ("How to teach your cat to sit", 0.8), # Relevant + ("Dogs and their behavior", 0.0), # No relevance + ("Sitting meditation techniques", 0.2), # Partial relevance + ("Cat toys for active cats", 0.6), # Somewhat relevant + ] + }, +] + +print("=" * 70) +print("Caption Relevance Scoring Test") +print("=" * 70) + +for test_case in test_cases: + query = test_case["query"] + captions = test_case["captions"] + + print(f"\nQuery: '{query}'") + print("-" * 70) + + for caption, expected_approx in captions: + score = compute_caption_relevance(caption, query) + status = "✓" if abs(score - expected_approx) < 0.15 else "~" + print(f"{status} Caption: {caption[:50]:50} | Score: {score:.2f}") + +print("\n" + "=" * 70) +print("Batch Caption Relevance Test") +print("=" * 70) + +batch_results = [ + {"url": "http://example.com/1", "title": "Beautiful Golden Retriever", "snippet": "A friendly golden retriever dog"}, + {"url": "http://example.com/2", "title": "Cat Sleeping", "snippet": "Orange cat taking a nap"}, + {"url": "http://example.com/3", "title": "Dog Training Guide", "snippet": "Learn how to train your dog"}, +] + +query = "dog" +scores = compute_batch_caption_relevance(batch_results, query) + +print(f"\nQuery: '{query}'") +print("-" * 70) +for result, score in zip(batch_results, 
scores): + print(f"Title: {result['title']:40} | Score: {score:.2f}") + +print("\n" + "=" * 70) +print("All tests completed successfully!") +print("=" * 70)