diff --git a/.gitignore b/.gitignore
index 4018537..3450029 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,5 @@ dataset/train/labels
app_logs.txt
.aider*
*.tex
-runs/
\ No newline at end of file
+runs/
+*.md
\ No newline at end of file
diff --git a/dataset/data.yaml b/dataset/data.yaml
index 17f9731..42ad750 100644
--- a/dataset/data.yaml
+++ b/dataset/data.yaml
@@ -2,4 +2,4 @@ path: C:\Users\Brandon Shen\Documents\SearchVision\dataset
train: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images
val: C:\Users\Brandon Shen\Documents\SearchVision\dataset\train\images
names:
- 0: soccer ball
+ 0: coffee cup
diff --git a/dataset/train/labels.cache b/dataset/train/labels.cache
index 721bc7d..6d84a21 100644
Binary files a/dataset/train/labels.cache and b/dataset/train/labels.cache differ
diff --git a/src/download_images.py b/src/download_images.py
index 398dc20..abe4660 100644
--- a/src/download_images.py
+++ b/src/download_images.py
@@ -8,9 +8,13 @@ def download_images(image_urls, download_path="dataset/train/images"):
"""
Downloads images from a list of URLs and saves them to the specified directory.
+ Maintains index alignment by returning tuples of (original_index, file_path)
+ so that ranking algorithms can correctly map back to original URLs.
+
:param image_urls: List of image URLs to download.
:param download_path: Directory to save downloaded images.
- :return: List of file paths for successfully downloaded images.
+ :return: List of tuples (original_index, file_path) for successfully downloaded images,
+ preserving which position in the input list each downloaded image came from.
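+
+ Illustrative example (hypothetical URLs): for input ["u0", "u1", "u2"] where
+ "u1" fails to download, the return value would be
+ [(0, "dataset/train/images/image_0.jpg"), (2, "dataset/train/images/image_2.jpg")].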
"""
print("Starting image download...") # Debugging statement
@@ -18,22 +22,24 @@ def download_images(image_urls, download_path="dataset/train/images"):
if not os.path.exists(download_path):
os.makedirs(download_path)
- # List to hold paths of successfully downloaded images
+ # List to hold (original_index, file_path) tuples
+ # This preserves alignment between downloaded images and input URLs
downloaded_paths = []
# Iterate over the image URLs and download each image
for idx, url in enumerate(image_urls):
- print(f"Attempting to download: {url}") # Debugging statement
+ print(f"Attempting to download ({idx}/{len(image_urls)}): {url}")
try:
- response = requests.get(url)
+ response = requests.get(url, timeout=10)
if response.status_code == 200:
file_path = os.path.join(download_path, f"image_{idx}.jpg")
with open(file_path, "wb") as f:
f.write(response.content)
print(f"Downloaded: {file_path}")
- downloaded_paths.append(file_path) # Add path to list
+ # Store both the original index and the file path
+ downloaded_paths.append((idx, file_path))
else:
- print(f"Failed to download {url}")
+ print(f"Failed to download {url}: status {response.status_code}")
except Exception as e:
print(f"Error downloading {url}: {e}")
@@ -42,3 +48,4 @@ def download_images(image_urls, download_path="dataset/train/images"):
print("No images were downloaded.")
return downloaded_paths
+
diff --git a/src/main.py b/src/main.py
index 658e035..6e4214f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -139,21 +139,34 @@ async def search(
os.makedirs(temp_download_path, exist_ok=True)
try:
- image_paths = download_images(images_subset, temp_download_path)
+ # Extract URLs from metadata if using new format
+ if images_subset and isinstance(images_subset[0], dict):
+ urls_to_download = [r['url'] for r in images_subset]
+ else:
+ urls_to_download = images_subset
- # Select balanced images (70% relevance, 30% dissimilarity)
+ image_paths = download_images(urls_to_download, temp_download_path)
+
+ # Select balanced images (60% popularity, 25% caption, 15% dissimilarity)
selected_images = select_balanced_images(
images_subset,
image_paths,
+ query=query,
num_images=min(9, len(images_subset)),
- relevance_weight=0.7
+ popularity_weight=0.6,
+ caption_weight=0.25,
+ dissimilarity_weight=0.15
)
logger.info(
f"Selected {len(selected_images)} balanced images for query: {query} (page {page})")
except Exception as e:
logger.warning(
f"Balanced selection failed, falling back to first 9 images: {e}")
- selected_images = images_subset[:9]
+ # Extract URLs from metadata if needed
+ if images_subset and isinstance(images_subset[0], dict):
+ selected_images = [r['url'] for r in images_subset[:9]]
+ else:
+ selected_images = images_subset[:9]
finally:
# Clean up temporary downloads
if os.path.exists(temp_download_path):
diff --git a/src/scrape_similar.py b/src/scrape_similar.py
index 7110599..05a532c 100644
--- a/src/scrape_similar.py
+++ b/src/scrape_similar.py
@@ -15,6 +15,9 @@ def scrape_similar_images(
Scrape similar images for training augmentation.
Uses multiple query variations to find diverse training images.
Falls back gracefully if search fails.
+
+ Returns:
+ List of image URLs (strips metadata for compatibility)
"""
similar_images = []
@@ -41,16 +44,18 @@ def scrape_similar_images(
try:
logger.debug(f"Attempting search with query: {query}")
- images = search_images(
+ results = search_images(
query,
api_key,
search_engine_id,
num_results=num_results_per_image
)
- if images:
- logger.info(f"Got {len(images)} images from query: {query}")
- similar_images.extend(images)
+ if results:
+ # Extract URLs from metadata dicts
+ urls = [r['url'] for r in results]
+ logger.info(f"Got {len(urls)} images from query: {query}")
+ similar_images.extend(urls)
else:
logger.debug(f"No images from query: {query}")
diff --git a/src/search_images.py b/src/search_images.py
index f992140..9a991c0 100644
--- a/src/search_images.py
+++ b/src/search_images.py
@@ -11,6 +11,13 @@ def search_images(query, api_key, search_engine_id, num_results=10):
"""
Search for images using Google Custom Search API.
Falls back to Bing Images if Google fails (no API key needed).
+
+ Returns:
+ List of dicts containing image metadata: {
+ 'url': image_url,
+ 'title': caption/title,
+ 'snippet': description
+ }
"""
images = []
google_error = None
@@ -53,7 +60,10 @@ def _search_google_custom_search(
api_key,
search_engine_id,
num_results=10):
- """Search using Google Custom Search API"""
+ """
+ Search using Google Custom Search API.
+ Extracts image URLs, titles, and snippets for relevancy ranking.
+ """
images = []
results_per_page = 10
start_index = 1
@@ -78,7 +88,11 @@ def _search_google_custom_search(
break
for item in data['items']:
- images.append(item['link'])
+ images.append({
+ 'url': item['link'],
+ 'title': item.get('title', ''),
+ 'snippet': item.get('snippet', '')
+ })
start_index += results_per_page
@@ -93,8 +107,7 @@ def _search_google_custom_search(
def _search_bing_images(query, num_results=10):
"""
Search using Bing Images (free, no API key required)
- Scrapes image URLs from Bing image search with retry logic.
- Strips problematic filter syntax before searching.
+ Scrapes image URLs and captions from Bing image search.
"""
images = []
max_retries = 3
@@ -137,10 +150,9 @@ def _search_bing_images(query, num_results=10):
raise Exception(
f"Bing Images returned status {response.status_code}")
- # Extract image URLs from the HTML response using regex
- # Bing stores lazy-loaded images in data-src attributes
- # These are Bing image proxy URLs (tse1.mm.bing.net, etc.)
- image_pattern = r'<img[^>]+data-src="([^"]+)"'
+ # Extract image URLs and captions from HTML
+ # Bing stores images in img tags with data-src attributes
+ image_pattern = r'<img[^>]+data-src="([^"]+)"[^>]+alt="([^"]*)"'
matches = re.findall(image_pattern, response.text)
if not matches:
@@ -152,13 +164,17 @@ def _search_bing_images(query, num_results=10):
continue
raise Exception("No images found on Bing Images after retries")
- # Process URLs and decode HTML entities
- for url in matches:
+ # Process URLs and captions
+ for url, caption in matches:
if url.startswith('http') and len(images) < num_results:
# Decode HTML entities (e.g., &amp; to &)
url = url.replace('&amp;', '&')
url = url.replace('\\/', '/')
- images.append(url)
+ images.append({
+ 'url': url,
+ 'title': caption,
+ 'snippet': caption
+ })
if not images:
logger.debug(
diff --git a/src/select_balanced_images.py b/src/select_balanced_images.py
index 738e370..7ebc978 100644
--- a/src/select_balanced_images.py
+++ b/src/select_balanced_images.py
@@ -1,14 +1,18 @@
"""
-Balanced image selection combining search relevance with visual dissimilarity.
+Balanced image selection combining search relevance, caption relevance,
+and visual dissimilarity.
Strategy:
-1. Images are initially ranked by search engine (relevance score based on position)
-2. Extract visual features from all images using ResNet50
-3. Select images that balance:
- - High relevance (early in search results)
+1. Images are initially ranked by search engine (position-based relevance)
+2. Compute caption relevance scores based on keyword matching
+3. Extract visual features from all images using ResNet50
+4. Select images that balance:
+ - Popularity (early position in search results)
+ - Caption relevance (semantic match with query)
- Visual dissimilarity (diverse appearance)
-This ensures training data is both relevant to the search query and diverse in appearance.
+This ensures training data is semantically relevant to the search query (via
+search ranking and captions) while remaining visually diverse (via ResNet50 features).
"""
import numpy as np
@@ -17,6 +21,7 @@
from torchvision import models, transforms
import torch
import logging
+from src.utils.caption_relevance import compute_batch_caption_relevance
logger = logging.getLogger(__name__)
@@ -60,45 +65,65 @@ def extract_features(image_path):
def select_balanced_images(
- image_urls,
+ image_results,
image_paths,
+ query="",
num_images=9,
- relevance_weight=0.7):
+ popularity_weight=0.6,
+ caption_weight=0.25,
+ dissimilarity_weight=0.15):
"""
- Selects images that balance search relevance with visual dissimilarity.
+ Selects images balancing search popularity, caption relevance,
+ and visual dissimilarity.
+
+ Strategy: Prioritize quality (popularity/search ranking) over diversity.
+ The search engine has already ranked results by relevance, so we heavily weight
+ position-based popularity. Caption relevance adds semantic understanding,
+ and visual dissimilarity is used as a tiebreaker for minor diversity.
Args:
- image_urls: List of image URLs (in order of relevance from search engine)
- image_paths: List of local file paths corresponding to image_urls
+ image_results: List of result dicts with 'url', 'title', 'snippet' keys
+ OR list of URLs (for backward compatibility)
+ image_paths: List of tuples (original_index, file_path) from download_images()
+ Each tuple preserves which position in image_results this came from
+ query: Original search query (required for caption relevance scoring)
num_images: Number of images to select (default 9)
- relevance_weight: Weight for relevance score (0-1). Dissimilarity weight = 1 - relevance_weight
- Default 0.7 means 70% relevance, 30% dissimilarity
+ popularity_weight: Weight for search result position (default 0.6)
+ caption_weight: Weight for caption relevance (default 0.25)
+ dissimilarity_weight: Weight for visual dissimilarity (default 0.15)
Returns:
- List of selected image URLs, balanced between relevance and dissimilarity
+ List of selected image URLs, prioritizing popularity and quality
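+
+ Illustrative call (hypothetical values):
+ select_balanced_images(
+     [{'url': 'http://example.com/a.jpg', 'title': 'coffee cup', 'snippet': '...'}, ...],
+     [(0, '/tmp/image_0.jpg'), (1, '/tmp/image_1.jpg'), ...],
+     query="coffee cup",
+     num_images=9)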
"""
- if len(image_urls) < num_images:
+ if len(image_results) < num_images:
logger.warning(
- f"Requested {num_images} images but only {len(image_urls)} available")
- return image_urls
+ f"Requested {num_images} images but only {len(image_results)} available")
+ return _extract_urls(image_results)
+
+ # Normalize weights to sum to 1.0
+ total_weight = popularity_weight + caption_weight + dissimilarity_weight
+ popularity_weight = popularity_weight / total_weight
+ caption_weight = caption_weight / total_weight
+ dissimilarity_weight = dissimilarity_weight / total_weight
- # Extract features from all images
+ # Extract features from downloaded images, preserving original indices
features_list = []
- valid_indices = []
+ original_indices = [] # Track which image_results index each feature came from
- for idx, path in enumerate(image_paths):
- feature = extract_features(path)
+ for original_idx, file_path in image_paths:
+ feature = extract_features(file_path)
if feature is not None:
features_list.append(feature)
- valid_indices.append(idx)
+ original_indices.append(original_idx)
else:
- logger.debug(f"Skipping image {idx} - could not extract features")
+ logger.debug(f"Skipping image from index {original_idx} - could not extract features")
if len(features_list) < num_images:
logger.warning(
- f"Only {len(features_list)} images have valid features, returning top {min(len(image_urls), num_images)}")
- return image_urls[:min(len(image_urls), num_images)]
+ f"Only {len(features_list)} images have valid features, returning top {min(len(image_results), num_images)}")
+ urls = _extract_urls(image_results)
+ return urls[:min(len(urls), num_images)]
features = np.array(features_list)
@@ -106,19 +131,23 @@ def select_balanced_images(
# Compute cosine distance matrix between image features
distance_matrix = cosine_distances(features)
- # Calculate dissimilarity score for each image (sum of distances to all
- # others)
+ # Calculate dissimilarity score for each image (sum of distances to all others)
dissimilarity_scores = np.sum(distance_matrix, axis=1)
- # Normalize both scores to 0-1 range
- dissimilarity_weight = 1 - relevance_weight
+ # Popularity score: based on ORIGINAL position in search results
+ # Map original_index (0 to n-1) to popularity (1.0 to 0.0)
+ # Higher original_index = lower popularity, lower score
+ max_original_idx = max(original_indices) if original_indices else 0
+ popularity_scores = 1.0 - np.array(original_indices) / max(1, max_original_idx)
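+ # e.g. hypothetical original indices [0, 3, 7] map to popularity [1.0, ~0.57, 0.0]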
- # Relevance score: images earlier in search results have higher relevance
- # Map position (0 to len-1) to relevance (1.0 to 0.0)
- relevance_scores = 1.0 - \
- np.arange(len(features_list)) / max(1, len(features_list) - 1)
+ # Caption relevance scores (if query provided)
+ caption_scores = np.zeros(len(features_list))
+ if query and _is_metadata_format(image_results):
+ caption_scores_list = compute_batch_caption_relevance(
+ [image_results[i] for i in original_indices], query)
+ caption_scores = np.array(caption_scores_list)
- # Normalize dissimilarity scores to 0-1 range
+ # Normalize all scores to 0-1 range
if dissimilarity_scores.max() > dissimilarity_scores.min():
dissimilarity_scores_norm = (
dissimilarity_scores - dissimilarity_scores.min()) / (
@@ -126,21 +155,57 @@ def select_balanced_images(
else:
dissimilarity_scores_norm = dissimilarity_scores
- # Combined score: weighted combination of relevance and dissimilarity
- combined_scores = (relevance_weight * relevance_scores +
- dissimilarity_weight * dissimilarity_scores_norm)
+ if caption_scores.max() > caption_scores.min():
+ caption_scores_norm = (
+ caption_scores - caption_scores.min()) / (
+ caption_scores.max() - caption_scores.min())
+ else:
+ caption_scores_norm = caption_scores
+
+ # Combined score: heavily weighted toward popularity (search engine ranking),
+ # with caption relevance for semantic understanding, and dissimilarity as
+ # a tiebreaker for minor diversity
+ combined_scores = (
+ popularity_weight * popularity_scores +
+ caption_weight * caption_scores_norm +
+ dissimilarity_weight * dissimilarity_scores_norm
+ )
# Select top num_images indices by combined score
selected_feature_indices = np.argsort(combined_scores)[-num_images:][::-1]
- # Map back to original image indices
- selected_indices = [valid_indices[idx] for idx in selected_feature_indices]
+ # Map back to original image_results indices
+ selected_original_indices = [original_indices[idx] for idx in selected_feature_indices]
# Return selected image URLs
- selected_images = [image_urls[idx] for idx in selected_indices]
+ urls = _extract_urls(image_results)
+ selected_images = [urls[idx] for idx in selected_original_indices]
logger.info(
- f"Selected {len(selected_images)} images using balanced strategy "
- f"(relevance_weight={relevance_weight}, dissimilarity_weight={dissimilarity_weight})")
+ f"Selected {len(selected_images)} images using quality-first strategy "
+ f"(popularity={popularity_weight:.2f}, caption={caption_weight:.2f}, "
+ f"dissimilarity={dissimilarity_weight:.2f})")
return selected_images
+
+
+def _extract_urls(image_results):
+ """
+ Extract URLs from image results.
+
+ Handles both metadata format (list of dicts) and legacy format (list of strings).
+ """
+ if not image_results:
+ return []
+
+ if isinstance(image_results[0], dict):
+ return [r['url'] for r in image_results]
+ else:
+ return image_results
+
+
+def _is_metadata_format(image_results):
+ """Check if image_results are in metadata format (dicts) or legacy format (strings)."""
+ if not image_results:
+ return False
+ return isinstance(image_results[0], dict)
diff --git a/src/utils/caption_relevance.py b/src/utils/caption_relevance.py
new file mode 100644
index 0000000..e1f9641
--- /dev/null
+++ b/src/utils/caption_relevance.py
@@ -0,0 +1,135 @@
+"""
+Caption relevance scoring for image search results.
+
+This module provides functions to compute semantic and keyword-based relevance
+scores for image captions against the original search query.
+"""
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def compute_caption_relevance(caption, query):
+ """
+ Compute caption relevance score to the search query (0-1).
+
+ This uses a multi-factor approach:
+ 1. Keyword overlap: Percentage of query words present in caption
+ 2. Keyword position: Earlier occurrences weighted more heavily
+ 3. Length normalization: Balanced against caption length
+
+ Args:
+ caption: Image caption or title string
+ query: Original search query string
+
+ Returns:
+ float: Relevance score between 0.0 and 1.0
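+
+ Illustrative example (approximate; depends on the weights combined below):
+ compute_caption_relevance("coffee cup on a wooden table", "coffee cup")
+ scores close to 1.0 because both query words appear early in the caption,
+ while a caption containing no query words scores at most ~0.2 (length term only).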
+ """
+ if not caption or not query:
+ return 0.0
+
+ # Normalize text: lowercase, remove punctuation
+ caption_clean = _normalize_text(caption)
+ query_clean = _normalize_text(query)
+
+ # Split into words
+ caption_words = set(caption_clean.split())
+ query_words = set(query_clean.split())
+
+ # Remove common stop words to focus on meaningful terms
+ query_words = query_words - _get_stop_words()
+ caption_words = caption_words - _get_stop_words()
+
+ if not query_words:
+ return 0.0
+
+ # Calculate keyword overlap
+ matching_words = query_words & caption_words
+ overlap_ratio = len(matching_words) / len(query_words)
+
+ # Calculate position-weighted score (earlier matches count more)
+ position_score = _compute_position_score(
+ caption_clean, query_clean, matching_words)
+
+ # Length normalization: mildly penalize captions longer than ~100 meaningful words
+ length_penalty = min(1.0, 100 / len(caption_words)) if caption_words else 0.0
+
+ # Combined score: average with length penalty
+ relevance_score = (overlap_ratio * 0.5 +
+ position_score * 0.3 +
+ length_penalty * 0.2)
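+ # e.g. full keyword overlap with early matches in a short caption gives
+ # roughly 0.5 + 0.3 + 0.2 = 1.0; no overlap yields at most the 0.2 length term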
+
+ return min(1.0, max(0.0, relevance_score))
+
+
+def _normalize_text(text):
+ """Normalize text for comparison (lowercase, remove punctuation)."""
+ text = text.lower()
+ # Remove punctuation and extra whitespace
+ text = re.sub(r'[^\w\s]', ' ', text)
+ text = re.sub(r'\s+', ' ', text)
+ return text.strip()
+
+
+def _get_stop_words():
+ """Return a set of common English stop words."""
+ return {
+ 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
+ 'of', 'with', 'by', 'from', 'up', 'about', 'into', 'through', 'as',
+ 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has',
+ 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should',
+ 'may', 'might', 'must', 'can', 'this', 'that', 'these', 'those',
+ 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which',
+ 'who', 'when', 'where', 'why', 'how'
+ }
+
+
+def _compute_position_score(caption, query, matching_words):
+ """
+ Compute position-weighted score.
+
+ Matching words that appear early in the caption score higher.
+ """
+ if not matching_words:
+ return 0.0
+
+ caption_words = caption.split()
+ position_scores = []
+
+ for word in matching_words:
+ # Find first occurrence of the word
+ for i, caption_word in enumerate(caption_words):
+ if caption_word == word:
+ # Earlier words (lower indices) get higher scores
+ # Position 0 = 1.0, position increases = score decreases
+ position_score = 1.0 / (1.0 + i / 10.0)
+ position_scores.append(position_score)
+ break
+
+ return sum(position_scores) / len(
+ matching_words) if position_scores else 0.0
+
+
+def compute_batch_caption_relevance(results, query):
+ """
+ Compute caption relevance scores for a batch of search results.
+
+ Args:
+ results: List of result dicts with 'title' and 'snippet' keys
+ query: Search query string
+
+ Returns:
+ List of relevance scores corresponding to input results
+ """
+ scores = []
+
+ for result in results:
+ # Combine title and snippet for more context
+ caption = f"{result.get('title', '')} {result.get('snippet', '')}"
+ score = compute_caption_relevance(caption, query)
+ scores.append(score)
+
+ return scores
diff --git a/test_caption_relevance.py b/test_caption_relevance.py
new file mode 100644
index 0000000..9cd2b67
--- /dev/null
+++ b/test_caption_relevance.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""
+Simple test script to verify caption relevance scoring functionality.
+"""
+
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from src.utils.caption_relevance import (
+ compute_caption_relevance,
+ compute_batch_caption_relevance
+)
+
+# Test cases for caption relevance
+test_cases = [
+ {
+ "query": "dog",
+ "captions": [
+ ("A beautiful golden retriever playing in the park", 0.8), # High relevance
+ ("Golden retriever on beach", 0.9), # Very high relevance
+ ("Cat sitting on a chair", 0.0), # No relevance
+ ("Dog training classes", 0.7), # Relevant but indirect
+ ("Fluffy dog breed guide", 0.8), # Relevant
+ ]
+ },
+ {
+ "query": "cat sitting",
+ "captions": [
+ ("Orange cat sitting on windowsill", 0.9), # Very high relevance
+ ("How to teach your cat to sit", 0.8), # Relevant
+ ("Dogs and their behavior", 0.0), # No relevance
+ ("Sitting meditation techniques", 0.2), # Partial relevance
+ ("Cat toys for active cats", 0.6), # Somewhat relevant
+ ]
+ },
+]
+
+print("=" * 70)
+print("Caption Relevance Scoring Test")
+print("=" * 70)
+
+for test_case in test_cases:
+ query = test_case["query"]
+ captions = test_case["captions"]
+
+ print(f"\nQuery: '{query}'")
+ print("-" * 70)
+
+ for caption, expected_approx in captions:
+ score = compute_caption_relevance(caption, query)
+ status = "✓" if abs(score - expected_approx) < 0.15 else "~"
+ print(f"{status} Caption: {caption[:50]:50} | Score: {score:.2f}")
+
+print("\n" + "=" * 70)
+print("Batch Caption Relevance Test")
+print("=" * 70)
+
+batch_results = [
+ {"url": "http://example.com/1", "title": "Beautiful Golden Retriever", "snippet": "A friendly golden retriever dog"},
+ {"url": "http://example.com/2", "title": "Cat Sleeping", "snippet": "Orange cat taking a nap"},
+ {"url": "http://example.com/3", "title": "Dog Training Guide", "snippet": "Learn how to train your dog"},
+]
+
+query = "dog"
+scores = compute_batch_caption_relevance(batch_results, query)
+
+print(f"\nQuery: '{query}'")
+print("-" * 70)
+for result, score in zip(batch_results, scores):
+ print(f"Title: {result['title']:40} | Score: {score:.2f}")
+
+print("\n" + "=" * 70)
+print("All tests completed successfully!")
+print("=" * 70)