Source: examples/vector_database
Building Large-Scale Image Search using VectorDB & OpenAI CLIP#
Large-Scale Image Search#
As the volume of image data grows, the need for efficient and powerful search methods becomes critical. Traditional keyword-based or metadata-based search often fails to capture the full semantic meaning in images. A vector database enables semantic search: you can find images that conceptually match a query (e.g., “a photo of a cloud”) rather than relying on textual tags.
In particular:
Scalability: Modern apps can deal with millions or billions of images, making typical database solutions slower or harder to manage.
Flexibility: Storing image embeddings as vectors allows you to adapt to different search use-cases, from “find similar products” to “find images with specific objects or styles.”
Performance: Vector databases are optimized for nearest neighbor queries in high-dimensional spaces, enabling real-time or near real-time search on large datasets.
SkyPilot streamlines the process of running such large-scale jobs in the cloud. It abstracts away much of the complexity of managing infrastructure and helps you run compute-intensive tasks efficiently and cost-effectively through managed jobs.
Please find the complete blog post here
Step 0: Set Up The Environment#
Install the following prerequisites:
SkyPilot: Make sure you have SkyPilot installed and that
sky check
succeeds. Refer to SkyPilot's documentation for instructions.
Hugging Face token: To download the dataset from the Hugging Face Hub, you will need your access token. Follow the steps below to configure it.
Set the Hugging Face token in ~/.env
HF_TOKEN=hf_xxxxx
or set the environment variable HF_TOKEN.
Step 1: Compute Vectors from Image Data with OpenAI CLIP#
You need to convert images into vector representations (embeddings) so they can be stored in a vector database. Models like CLIP by OpenAI learn powerful representations that map images and text into the same embedding space. This allows for semantic similarity calculations, making queries like “a photo of a cloud” match relevant images.
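For intuition, here is a minimal, self-contained sketch (not one of the included files) of how an image and a text query land in the same embedding space with open_clip. It assumes a local placeholder image cat.jpg and uses the smaller ViT-B-32 checkpoint for a quick test; the managed jobs below use ViT-bigG-14.
# Minimal sketch: embed one image and two text prompts into CLIP's shared space.
# Assumes open_clip_torch and Pillow are installed; cat.jpg is a placeholder path.
import torch
import open_clip
from PIL import Image

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

image = preprocess(Image.open('cat.jpg')).unsqueeze(0)         # [1, 3, H, W]
texts = tokenizer(['a photo of a cloud', 'a photo of a cat'])  # [2, 77]

with torch.no_grad():
    img_emb = model.encode_image(image)
    txt_emb = model.encode_text(texts)
    # Normalize so a dot product equals cosine similarity.
    img_emb /= img_emb.norm(dim=-1, keepdim=True)
    txt_emb /= txt_emb.norm(dim=-1, keepdim=True)

similarity = (img_emb @ txt_emb.T).squeeze(0)
print(similarity)  # the "cat" prompt should score higher for a cat image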
Use the following command to launch a job that processes your image dataset and computes the CLIP embeddings:
python3 batch_compute_vectors.py
This will automatically find available machines to compute the vectors. Expect output similar to:
...
(clip-batch-compute-vectors, pid=2523) 2025-01-27 23:57:27,387 - root - INFO - Saved partition 2 to /output/embeddings_90000_100000.parquet_part_2/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-27 23:59:39,720 - root - INFO - Saved partition 3 to /output/embeddings_90000_100000.parquet_part_3/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:01:56,707 - root - INFO - Saved partition 4 to /output/embeddings_90000_100000.parquet_part_4/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:04:12,200 - root - INFO - Saved partition 5 to /output/embeddings_90000_100000.parquet_part_5/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:06:25,009 - root - INFO - Saved partition 6 to /output/embeddings_90000_100000.parquet_part_6/data.parquet
...
You can also use sky jobs queue and sky jobs dashboard to check the status of the jobs; the dashboard shows the jobs launched across different regions.
Step 2: Construct the Vector Database from Computed Embeddings#
Once you have the image embeddings, you need a specialized engine to perform rapid similarity searches at scale. In this example, we use ChromaDB to store and query the embeddings. This step ingests the embeddings from Step 1 into a vector database to enable real-time or near real-time search over millions of vectors.
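Under the hood, the ingestion boils down to adding (id, embedding, document) triples to a persistent ChromaDB collection. The sketch below illustrates that pattern with dummy vectors; the persist directory and collection name mirror the included files, but everything else is illustrative.
# Minimal sketch of the ChromaDB pattern used by scripts/build_vectordb.py.
# The dummy 1280-d vectors stand in for real CLIP embeddings (ViT-bigG-14 outputs 1280-d).
import chromadb

client = chromadb.PersistentClient(path='/vectordb/chroma')
collection = client.get_or_create_collection(
    name='clip_embeddings',
    metadata={'description': 'CLIP embeddings from dataset'})

# ids are dataset indices, embeddings are CLIP vectors, and documents store
# the relative image paths so query results can be mapped back to files.
collection.add(ids=['0', '1'],
               embeddings=[[0.1] * 1280, [0.2] * 1280],
               documents=['0000/0.jpg', '0000/1.jpg'])

# Nearest-neighbor query with a (dummy) query embedding.
results = collection.query(query_embeddings=[[0.1] * 1280], n_results=2)
print(results['documents'][0])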
To construct the database from embeddings:
sky jobs launch build_vectordb.yaml
This processes the generated CLIP embeddings in batches, producing output like:
(vectordb-build, pid=2457) INFO:__main__:Processing /clip_embeddings/embeddings_0_500.parquet_part_0/data.parquet
Processing batches: 100%|██████████| 1/1 [00:00<00:00, 1.19it/s]
Processing files: 92%|█████████▏| 11/12 [00:02<00:00, 5.36it/s]INFO:__main__:Processing /clip_embeddings/embeddings_500_1000.parquet_part_0/data.parquet
Processing batches: 100%|██████████| 1/1 [00:02<00:00, 2.39s/it]
Processing files: 100%|██████████| 12/12 [00:05<00:00, 2.04it/s]/1 [00:00<?, ?it/s]
Step 3: Serve the Constructed Vector Database#
To serve the constructed database, you expose an API endpoint that other applications (or your local client) can call to perform semantic search. Querying allows you to confirm that your database is working and retrieve semantic matches for a given text query. You can integrate this endpoint into larger applications (like an image search engine or recommendation system).
To serve the constructed database:
sky launch -c vecdb_serve serve_vectordb.yaml
which runs the hosted vector database service. Alternatively, you can run
sky serve up serve_vectordb.yaml -n vectordb
This will deploy your vector database as a service on a cloud instance and allow you to interact with it via a public endpoint. Sky Serve facilitates automatic health checks and scaling of the service.
To query the constructed database:
If you deployed with sky launch
, run
sky status --ip vecdb_serve
to get the IP address of the deployed cluster.
If you deployed with sky serve
, run
sky serve status vectordb --endpoint
to get the endpoint address of the service.
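Once you have the address, you can hit the /search route directly. Below is a minimal client sketch using the requests library with a placeholder address; the request and response fields match scripts/serve_vectordb.py.
# Minimal client sketch for the FastAPI service in scripts/serve_vectordb.py.
# Replace ENDPOINT with the cluster IP plus :8000, or with the address returned
# by `sky serve status vectordb --endpoint`.
import requests

ENDPOINT = 'http://1.2.3.4:8000'  # placeholder

resp = requests.post(f'{ENDPOINT}/search',
                     json={'text': 'a photo of a cloud', 'n_results': 5})
resp.raise_for_status()
for hit in resp.json():
    # Each result has a relative image path and a cosine-similarity score.
    print(f"{hit['similarity']:.3f}  {ENDPOINT}/image/{hit['image_path']}")
You can also open the endpoint in a browser: the service serves a simple search page at /.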
Included files#
batch_compute_vectors.py
"""
Use skypilot to launch managed jobs that will run the embedding calculation.
This script is responsible for splitting the input dataset up among several workers,
then using skypilot to launch managed jobs for each worker. We use compute_vectors.yaml
to define the managed job info.
"""
#!/usr/bin/env python3
import argparse
import os
import sky
def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
total_jobs: int) -> tuple[int, int]:
"""Calculate the range of indices this job should process.
Args:
start_idx: Global start index
end_idx: Global end index
job_rank: Current job's rank (0-based)
total_jobs: Total number of jobs
Returns:
Tuple of [job_start_idx, job_end_idx)
"""
total_range = end_idx - start_idx
chunk_size = total_range // total_jobs
remainder = total_range % total_jobs
# Distribute remainder across first few jobs
job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
if job_rank < remainder:
chunk_size += 1
job_end = job_start + chunk_size
return job_start, job_end
def main():
parser = argparse.ArgumentParser(
description='Launch batch CLIP inference jobs')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Global start index in dataset')
parser.add_argument('--end-idx',
type=int,
default=1000000,
help='Global end index in dataset, not inclusive')
parser.add_argument('--num-jobs',
type=int,
default=100,
help='Number of jobs to partition the work across')
parser.add_argument('--env-path',
type=str,
default='~/.env',
help='Path to the environment file')
args = parser.parse_args()
# Try to get HF_TOKEN from environment first, then ~/.env file
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
env_path = os.path.expanduser(args.env_path)
if os.path.exists(env_path):
with open(env_path) as f:
for line in f:
if line.startswith('HF_TOKEN='):
hf_token = line.strip().split('=')[1]
break
if not hf_token:
raise ValueError("HF_TOKEN not found in ~/.env or environment variable")
# Load the task template
task = sky.Task.from_yaml('compute_vectors.yaml')
# Launch jobs for each partition
for job_rank in range(args.num_jobs):
# Calculate index range for this job
job_start, job_end = calculate_job_range(args.start_idx, args.end_idx,
job_rank, args.num_jobs)
# Update environment variables for this job
task_copy = task.update_envs({
'START_IDX': job_start,
'END_IDX': job_end,
'HF_TOKEN': hf_token,
})
sky.jobs.launch(
task_copy,
name=f'vector-compute-{job_start}-{job_end}',
)
if __name__ == '__main__':
main()
build_vectordb.yaml
name: vectordb-build
workdir: .
file_mounts:
/clip_embeddings:
name: sky-demo-embedding
# this needs to be the same as the source in the compute_vectors.yaml
mode: MOUNT
/vectordb:
name: sky-vectordb
# this needs to be the same as the source in the serve_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in compute_vectors.yaml
mode: MOUNT
setup: |
pip install chromadb pandas tqdm pyarrow
run: |
python scripts/build_vectordb.py \
--collection-name clip_embeddings \
--persist-dir /vectordb/chroma \
--embeddings-dir /clip_embeddings \
--batch-size 1000
compute_vectors.yaml
name: clip-batch-compute-vectors
workdir: .
resources:
accelerators:
# ordered by pricing (cheapest to most expensive)
T4: 1
L4: 1
A10G: 1
A10: 1
V100: 1
memory: 32+
any_of:
- use_spot: true
- use_spot: false
num_nodes: 1
file_mounts:
/output:
name: sky-demo-embedding
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
envs:
# These env vars are required but should be passed in at launch time.
HF_TOKEN: ''
START_IDX: ''
END_IDX: ''
setup: |
pip install numpy==1.26.4
pip install torch==2.5.1 torchvision==0.20.1 ftfy regex tqdm
pip install datasets webdataset requests Pillow open_clip_torch
pip install fastapi uvicorn aiohttp pandas pyarrow tenacity
run: |
python scripts/compute_vectors.py \
--output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
--start-idx ${START_IDX} \
--end-idx ${END_IDX} \
--batch-size 64 \
--checkpoint-size 1000
echo "Processing complete. Results saved in node-specific files under /output/"
scripts/build_vectordb.py
"""
This script is responsible for building the vector database from the mounted bucket and saving
it to another mounted bucket.
"""
import argparse
import base64
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
import glob
import logging
import multiprocessing
import os
import pickle
import shutil
import tempfile
import chromadb
import numpy as np
import pandas as pd
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def list_local_parquet_files(mount_path: str, prefix: str) -> list:
"""List all parquet files in the mounted S3 directory."""
search_path = os.path.join(mount_path, prefix, '**/*.parquet')
parquet_files = glob.glob(search_path, recursive=True)
return parquet_files
def process_parquet_file(args):
"""Process a single parquet file and return the processed data."""
parquet_file, batch_size = args
try:
results = []
df = pd.read_parquet(parquet_file)
# Process in batches
for i in range(0, len(df), batch_size):
batch_df = df.iloc[i:i + batch_size]
# Extract data from DataFrame and unpack the pickled data
ids = [str(idx) for idx in batch_df['idx']]
unpacked_data = [pickle.loads(row) for row in batch_df['output']]
images_base64, embeddings = zip(*unpacked_data)
results.append((ids, embeddings, images_base64))
return results
except Exception as e:
logger.error(f'Error processing file {parquet_file}: {str(e)}')
return None
def main():
parser = argparse.ArgumentParser(
description='Build ChromaDB from mounted S3 parquet files')
parser.add_argument('--collection-name',
type=str,
default='clip_embeddings',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory to persist ChromaDB')
parser.add_argument(
'--batch-size',
type=int,
default=1000,
help='Batch size for processing, this needs to fit in memory')
parser.add_argument('--embeddings-dir',
type=str,
default='/clip_embeddings',
help='Path to mounted bucket containing parquet files')
parser.add_argument(
'--prefix',
type=str,
default='',
help='Prefix path within mounted bucket to search for parquet files')
args = parser.parse_args()
# Create a temporary directory for building the database. The
# mounted bucket does not support append operation, so build in
# the tmpdir and then copy it to the final location.
with tempfile.TemporaryDirectory() as temp_dir:
logger.info(f'Using temporary directory: {temp_dir}')
# Initialize ChromaDB in temporary directory
client = chromadb.PersistentClient(path=temp_dir)
# Create or get collection for chromadb
# it attempts to create a collection with the same name
# if it already exists, it will get the collection
try:
collection = client.create_collection(
name=args.collection_name,
metadata={'description': 'CLIP embeddings from dataset'})
logger.info(f'Created new collection: {args.collection_name}')
except ValueError:
collection = client.get_collection(name=args.collection_name)
logger.info(f'Using existing collection: {args.collection_name}')
# List parquet files from mounted directory
parquet_files = list_local_parquet_files(args.embeddings_dir,
args.prefix)
logger.info(f'Found {len(parquet_files)} parquet files')
# Process files in parallel
max_workers = max(1,
multiprocessing.cpu_count() - 1) # Leave one CPU free
logger.info(f'Processing files using {max_workers} workers')
with ProcessPoolExecutor(max_workers=max_workers) as executor:
# Submit all files for processing
future_to_file = {
executor.submit(process_parquet_file, (file, args.batch_size)):
file for file in parquet_files
}
# Process results as they complete
for future in tqdm(as_completed(future_to_file),
total=len(parquet_files),
desc='Processing files'):
file = future_to_file[future]
try:
results = future.result()
if results:
for ids, embeddings, images_paths in results:
collection.add(ids=list(ids),
embeddings=list(embeddings),
documents=list(images_paths))
except Exception as e:
logger.error(f'Error processing file {file}: {str(e)}')
continue
logger.info('Vector database build complete!')
logger.info(f'Total documents in collection: {collection.count()}')
# Copy the completed database to the final location
logger.info(f'Copying database to final location: {args.persist_dir}')
if os.path.exists(args.persist_dir):
logger.info('Removing existing database directory')
shutil.rmtree(args.persist_dir)
shutil.copytree(temp_dir, args.persist_dir)
logger.info('Database copy complete!')
if __name__ == '__main__':
main()
scripts/compute_vectors.py
"""
This script is responsible for computing the embeddings for the ImageNet dataset.
"""
import abc
import asyncio
import base64
from io import BytesIO
import logging
import os
from pathlib import Path
import pickle
import shutil
from typing import (Any, AsyncIterator, Dict, Generic, List, Optional, Tuple,
TypeVar)
import numpy as np
import pandas as pd
from PIL import Image
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
class BatchProcessor():
"""Process ImageNet images with CLIP.
This script is responsible for computing the embeddings for the ImageNet dataset.
1. setup_model initializes the model
2. get_dataset_iterator will yield individual items from the dataset
3. do_data_loading will get an item from the dataset iterator and do any preprocessing
4. the loaded items will be batched and handed to do_batch_processing for the ultimate processing
"""
def __init__(self,
output_path: str,
images_path: str = '/images',
model_name: str = 'ViT-bigG-14',
dataset_name: str = 'ILSVRC/imagenet-1k',
pretrained: str = 'laion2b_s39b_b160k',
device: Optional[str] = None,
split: str = 'train',
streaming: bool = True,
batch_size: int = 32,
checkpoint_size: int = 100,
start_idx: int = 0,
end_idx: Optional[int] = None):
self.output_path = Path(output_path) # Convert to Path object
self.images_path = Path(images_path) # Path to store images
self.batch_size = batch_size
self.checkpoint_size = checkpoint_size
self.start_idx = start_idx
self.end_idx = end_idx
self._current_batch = []
# Create images directory if it doesn't exist
self.images_path.mkdir(parents=True, exist_ok=True)
# CLIP-specific attributes
self.model_name = model_name
self.pretrained = pretrained
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.dataset_name = dataset_name
self.split = split
self.streaming = streaming
self.model = None
self.preprocess = None
self.partition_counter = 0
async def setup_model(self):
"""Set up the CLIP model."""
import open_clip
model, _, preprocess = open_clip.create_model_and_transforms(
self.model_name, pretrained=self.pretrained, device=self.device)
self.model = model
self.preprocess = preprocess
async def get_dataset_iterator(self) -> AsyncIterator[Tuple[int, Any]]:
"""Load data from a HuggingFace dataset."""
from datasets import load_dataset
dataset = load_dataset(self.dataset_name,
streaming=self.streaming,
trust_remote_code=True)[self.split]
if self.start_idx > 0:
dataset = dataset.skip(self.start_idx)
for idx, item in enumerate(dataset, start=self.start_idx):
if self.end_idx and idx >= self.end_idx:
break
yield idx, item
async def do_data_loading(
self) -> AsyncIterator[Tuple[int, Tuple[torch.Tensor, Any]]]:
"""Load and preprocess ImageNet images."""
async for idx, item in self.get_dataset_iterator():
try:
# ImageNet provides PIL Images directly
tensor = self.preprocess(item['image'])
if tensor is not None:
# Pass through both the tensor and original image
yield idx, (tensor, item['image'])
except Exception as e:
logging.debug(
f'Error preprocessing image at index {idx}: {str(e)}')
def save_image(self, idx: int, image: Image.Image) -> str:
"""Save image to the mounted bucket and return its path."""
# Create a subdirectory based on the first few digits of the index to avoid too many files in one directory
subdir = str(idx // 100000).zfill(4)
save_dir = self.images_path / subdir
save_dir.mkdir(parents=True, exist_ok=True)
# Save image with index as filename
image_path = save_dir / f'{idx}.jpg'
image.save(image_path, format='JPEG', quality=95)
# Return relative path from images root
return str(Path(subdir) / f'{idx}.jpg')
async def do_batch_processing(
self, batch: List[Tuple[int, Tuple[torch.Tensor, Any]]]
) -> List[Tuple[int, bytes]]:
"""Process a batch of images through CLIP."""
if self.model is None:
await self.setup_model()
# Unpack the batch
indices, batch_data = zip(*batch)
model_inputs, original_images = zip(*batch_data)
# Stack inputs into a batch
batch_tensor = torch.stack(model_inputs).to(self.device)
# Run inference
with torch.no_grad():
features = self.model.encode_image(batch_tensor)
features /= features.norm(dim=-1, keepdim=True)
# Convert to numpy arrays
embeddings = features.cpu().numpy()
# Save images and store their paths
image_paths = {}
for idx, img in zip(indices, original_images):
image_path = self.save_image(idx, img)
image_paths[idx] = image_path
# Return both embeddings and image paths
return [(idx, pickle.dumps((image_paths[idx], arr)))
for idx, arr in zip(indices, embeddings)]
async def find_existing_progress(self) -> Tuple[int, int]:
"""
Find the highest processed index and partition counter from existing files.
Returns:
Tuple[int, int]: (highest_index, next_partition_number)
"""
if not self.output_path.parent.exists():
self.output_path.parent.mkdir(parents=True, exist_ok=True)
return self.start_idx, 0
partition_files = list(
self.output_path.parent.glob(
                # Match files saved as f'{output_path}_part_N.parquet' (note the
                # '.parquet' suffix kept by .name, which .stem would strip).
                f'{self.output_path.name}_part_*.parquet'))
print(f'Partition files: {partition_files}')
if not partition_files:
return self.start_idx, 0
max_idx = self.start_idx
max_partition = -1
for file in partition_files:
# Extract partition number from filename
try:
partition_num = int(file.stem.split('_part_')[1])
max_partition = max(max_partition, partition_num)
# Read the file and find highest index
df = pd.read_parquet(file)
if not df.empty:
max_idx = max(max_idx, df['idx'].max())
except Exception as e:
logging.warning(f'Error processing file {file}: {e}')
return max_idx, max_partition + 1
def save_results_to_parquet(self, results: list):
"""Save results to a parquet file with atomic write."""
if not results:
return
df = pd.DataFrame(results, columns=['idx', 'output'])
final_path = f'{self.output_path}_part_{self.partition_counter}.parquet'
temp_path = f'/tmp/{self.partition_counter}.tmp'
# Write to temporary file first
df.to_parquet(temp_path, engine='pyarrow', index=False)
# Copy from temp to final destination
shutil.copy2(temp_path, final_path)
os.remove(temp_path) # Clean up temp file
logging.info(
f'Saved partition {self.partition_counter} to {final_path} with {len(df)} rows'
)
self.partition_counter += 1
async def run(self):
"""
Run the batch processing pipeline with recovery support.
"""
# Initialize the model
if self.model is None:
await self.setup_model()
# Find existing progress
resume_idx, self.partition_counter = await self.find_existing_progress()
self.start_idx = max(self.start_idx, resume_idx + 1)
logging.info(
f'Starting processing from index {self.start_idx} (partition {self.partition_counter})'
)
results = []
async for idx, input_data in self.do_data_loading():
self._current_batch.append((idx, input_data))
if len(self._current_batch) >= self.batch_size:
batch_results = await self.do_batch_processing(
self._current_batch)
results.extend(batch_results)
self._current_batch = []
if len(results) >= self.checkpoint_size:
self.save_results_to_parquet(results)
results.clear()
# Process any remaining items in the batch
if self._current_batch:
batch_results = await self.do_batch_processing(self._current_batch)
results.extend(batch_results)
# Write the final partition if there are any leftover results
if results:
self.save_results_to_parquet(results)
async def main():
"""Example usage of the batch processing framework."""
import argparse
# Parse command line arguments
parser = argparse.ArgumentParser(
description='Run CLIP batch processing on ImageNet')
parser.add_argument('--output-path',
type=str,
default='embeddings.parquet',
help='Path to output parquet file')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Starting index in dataset')
parser.add_argument('--end-idx',
type=int,
default=10000,
help='Ending index in dataset')
parser.add_argument('--batch-size',
type=int,
default=50,
help='Batch size for processing')
parser.add_argument('--checkpoint-size',
type=int,
default=100,
help='Number of results before checkpointing')
parser.add_argument('--model-name',
type=str,
default='ViT-bigG-14',
help='CLIP model name')
parser.add_argument('--images-path',
type=str,
default='/images',
help='Path to store images')
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Initialize processor
processor = BatchProcessor(output_path=args.output_path,
start_idx=args.start_idx,
end_idx=args.end_idx,
batch_size=args.batch_size,
checkpoint_size=args.checkpoint_size,
model_name=args.model_name,
images_path=args.images_path)
# Run processing
await processor.run()
if __name__ == '__main__':
asyncio.run(main())
scripts/serve_vectordb.py
"""
This script is responsible for serving the vector database.
"""
import argparse
import base64
import logging
import os
from pathlib import Path
from typing import List, Optional
import chromadb
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import numpy as np
import open_clip
from pydantic import BaseModel
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title='Vector Database Search API')
# Global variables for model and database
model = None
tokenizer = None
collection = None
device = None
images_dir = None
class SearchQuery(BaseModel):
text: str
n_results: Optional[int] = 5
class SearchResult(BaseModel):
image_path: str
similarity: float
def encode_text(text: str, model_name: str = 'ViT-bigG-14') -> np.ndarray:
"""Encode text using CLIP model."""
global model, tokenizer, device
# Tokenize and encode
text_tokens = tokenizer([text]).to(device)
with torch.no_grad():
text_features = model.encode_text(text_tokens)
# Normalize the features
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy()
def query_collection(query_embedding: np.ndarray,
n_results: int = 5) -> List[SearchResult]:
"""Query the collection and return top matches with scores."""
global collection
results = collection.query(query_embeddings=query_embedding.tolist(),
n_results=n_results,
include=['metadatas', 'distances', 'documents'])
# Get image paths and distances
image_paths = results['documents'][0]
distances = results['distances'][0]
# Convert distances to similarities (cosine similarity = 1 - distance/2)
similarities = [1 - (d / 2) for d in distances]
return [
SearchResult(image_path=img_path, similarity=similarity)
for img_path, similarity in zip(image_paths, similarities)
]
@app.post('/search', response_model=List[SearchResult])
async def search(query: SearchQuery):
"""Search endpoint that takes a text query and returns similar images."""
try:
# Encode the query text
query_embedding = encode_text(query.text)
# Query the collection
results = query_collection(query_embedding, query.n_results)
return results
except Exception as e:
logger.error(f'Error processing query: {str(e)}')
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/{subpath:path}')
async def get_image(subpath: str):
"""Serve an image from the mounted bucket."""
image_path = os.path.join(images_dir, subpath)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail='Image not found')
return FileResponse(image_path, media_type='image/jpeg')
@app.get('/health')
async def health_check():
"""Health check endpoint."""
return {
'status': 'healthy',
'collection_size': collection.count() if collection else 0
}
@app.get('/', response_class=HTMLResponse)
async def get_search_page():
"""Serve a simple search interface."""
return """
<html>
<head>
<title>Image Search</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 1.6;
background-color: #f5f5f5;
color: #333;
min-height: 100vh;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
.search-container {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
text-align: center;
}
h1 {
color: #2c3e50;
margin-bottom: 1.5rem;
font-size: 2.5rem;
}
.search-box {
display: flex;
gap: 10px;
max-width: 600px;
margin: 0 auto;
}
input {
flex: 1;
padding: 12px 20px;
border: 2px solid #e0e0e0;
border-radius: 25px;
font-size: 16px;
transition: all 0.3s ease;
}
input:focus {
outline: none;
border-color: #3498db;
box-shadow: 0 0 5px rgba(52, 152, 219, 0.3);
}
button {
padding: 12px 30px;
background: #3498db;
color: white;
border: none;
border-radius: 25px;
cursor: pointer;
font-size: 16px;
transition: background 0.3s ease;
}
button:hover {
background: #2980b9;
}
.results {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 1.5rem;
padding: 1rem;
}
.result {
background: white;
border-radius: 10px;
overflow: hidden;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease;
}
.result:hover {
transform: translateY(-5px);
}
.result img {
width: 100%;
height: 200px;
object-fit: cover;
}
.result-info {
padding: 1rem;
}
.similarity-score {
color: #2c3e50;
font-weight: 600;
}
#loading {
display: none;
text-align: center;
margin: 2rem 0;
font-size: 1.2rem;
color: #666;
}
</style>
</head>
<body>
<div class="container">
<div class="search-container">
<h1>SkyPilot Image Search</h1>
<div class="search-box">
<input type="text" id="searchInput" placeholder="Enter your search query..."
onkeypress="if(event.key === 'Enter') search()">
<button onclick="search()">Search</button>
</div>
</div>
<div id="loading">Searching...</div>
<div id="results" class="results"></div>
</div>
<script>
async function search() {
const searchInput = document.getElementById('searchInput');
const loading = document.getElementById('loading');
const resultsDiv = document.getElementById('results');
if (!searchInput.value.trim()) return;
loading.style.display = 'block';
resultsDiv.innerHTML = '';
try {
const response = await fetch('/search', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json'
},
body: JSON.stringify({
text: searchInput.value.trim(),
n_results: 12
})
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || 'Search failed');
}
const results = await response.json();
resultsDiv.innerHTML = results.map(result => `
<div class="result">
<img src="/image/${result.image_path}"
alt="Search result">
<div class="result-info">
<p class="similarity-score">
Similarity: ${(result.similarity * 100).toFixed(1)}%
</p>
</div>
</div>
`).join('');
} catch (error) {
resultsDiv.innerHTML = `
<p style="color: #e74c3c; text-align: center; width: 100%;">
Error: ${error.message}
</p>
`;
} finally {
loading.style.display = 'none';
}
}
</script>
</body>
</html>
"""
def main():
parser = argparse.ArgumentParser(
description='Serve Vector Database with FastAPI')
parser.add_argument('--host',
type=str,
default='0.0.0.0',
help='Host to serve on')
parser.add_argument('--port',
type=int,
default=8000,
help='Port to serve on')
parser.add_argument('--collection-name',
type=str,
default='clip_embeddings',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory where ChromaDB is persisted')
parser.add_argument('--images-dir',
type=str,
default='/images',
help='Directory where images are stored')
parser.add_argument('--model-name',
type=str,
default='ViT-bigG-14',
help='CLIP model name')
args = parser.parse_args()
# Initialize global variables
global model, tokenizer, collection, device, images_dir
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f'Using device: {device}')
# Set images directory
images_dir = args.images_dir
# Load the model
import open_clip
model, _, _ = open_clip.create_model_and_transforms(
args.model_name, pretrained='laion2b_s39b_b160k', device=device)
tokenizer = open_clip.get_tokenizer(args.model_name)
# Initialize ChromaDB client
client = chromadb.PersistentClient(path=args.persist_dir)
try:
# Get the collection
collection = client.get_collection(name=args.collection_name)
logger.info(f'Connected to collection: {args.collection_name}')
logger.info(f'Total documents in collection: {collection.count()}')
except ValueError as e:
logger.error(f'Error: {str(e)}')
logger.error(
'Make sure the collection exists and the persist_dir is correct.')
raise
# Start the server
import uvicorn
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == '__main__':
main()
serve_vectordb.yaml
name: vectordb-serve
workdir: .
resources:
accelerators:
# ordered by pricing (cheapest to most expensive)
# skypilot will try to use the cheapest available accelerator
# serve requires a GPU to compute the embeddings
T4: 1
L4: 1
A10G: 1
A10: 1
V100: 1
memory: 32+
ports: 8000
use_spot: true
file_mounts:
/vectordb:
name: sky-vectordb
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
setup: |
pip install numpy==1.26.4
pip install torch==2.5.1 torchvision==0.20.1 ftfy regex tqdm
pip install open_clip_torch chromadb pandas
pip install fastapi uvicorn pydantic
run: |
python scripts/serve_vectordb.py \
--collection-name clip_embeddings \
--persist-dir /vectordb/chroma \
--images-dir /images \
--host 0.0.0.0 \
--port 8000
service:
replicas: 1
readiness_probe:
path: /health