Source: examples/redisvl-vector-search
RedisVL + SkyPilot: Vector Search at Scale#
Distributed vector search over 1M research papers using RedisVL and SkyPilot.
📖 Read the full blog post.
Features#
Distributed GPU-accelerated embedding generation
Vector search with RedisVL over 1M research papers
Automatic failover and retry with SkyPilot managed jobs
Direct streaming to Redis (no intermediate storage)
Cost-effective: ~$0.85 to embed the entire 1M-paper dataset (5 parallel T4 spot instances on GCP, ~80 minutes; quick cost arithmetic below)
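A quick back-of-the-envelope from the numbers above: 5 instances × 80 minutes ≈ 6.7 instance-hours, so ~$0.85 works out to roughly $0.13 per T4 spot instance-hour (actual spot pricing varies by region and over time).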
Setup#
pip install -r requirements.txt
Create a .env file:
REDIS_HOST=your-redis-host.redislabs.com
REDIS_PORT=12345
REDIS_USER=default
REDIS_PASSWORD=your-password
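Optionally, sanity-check the credentials before launching anything. A minimal sketch using redis-py (assumes the values from .env are exported in your shell; this script is not part of the example files):
import os
import redis

# Connect with the same credentials the jobs will use.
client = redis.Redis(host=os.environ["REDIS_HOST"],
                     port=int(os.environ["REDIS_PORT"]),
                     username=os.environ["REDIS_USER"],
                     password=os.environ["REDIS_PASSWORD"],
                     decode_responses=True)
client.ping()  # raises an exception if the host, port, or credentials are wrong
print("Redis connection OK")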
Launch Jobs#
Batch launcher (recommended):
python batch_embedding_launcher.py --num-jobs 5 --env-file .env
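The launcher splits the record range evenly across jobs and submits each slice as its own SkyPilot managed job (see calculate_job_range in batch_embedding_launcher.py below). A sketch of the split with the defaults (1,000,000 records, 5 jobs):
# Each job receives a contiguous 200,000-record slice. The real launcher also
# distributes any remainder; with 1,000,000 / 5 the split is exact.
total_records, num_jobs = 1_000_000, 5
chunk = total_records // num_jobs
print([(i * chunk, (i + 1) * chunk) for i in range(num_jobs)])
# [(0, 200000), (200000, 400000), (400000, 600000), (600000, 800000), (800000, 1000000)]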
Single managed job:
sky jobs launch embedding_job.yaml \
--env JOB_START_IDX=0 \
--env JOB_END_IDX=200000 \
--env-file .env
Monitor:
sky jobs queue
sky dashboard
Run Search API#
Option 1: Deploy with SkyPilot#
Deploy the search API and Streamlit UI to the cloud:
sky launch -c redisvl-search-api search_api.yaml --env-file .env
Access the services:
FastAPI:
export API_ENDPOINT=$(sky status --endpoint 8001 redisvl-search-api)
echo $API_ENDPOINT
curl -X POST "http://$API_ENDPOINT/search" \
-H "Content-Type: application/json" \
-d '{"query": "neural networks", "k": 5}'
API in action:
Streamlit UI:
export STREAMLIT_ENDPOINT=$(sky status --endpoint 8501 redisvl-search-api)
echo $STREAMLIT_ENDPOINT
Streamlit app interface:
Tear down when done:
sky down redisvl-search-api
Option 2: Run Locally#
source .env
python app.py
Test#
curl -X POST "http://localhost:8000/search" \
-H "Content-Type: application/json" \
-d '{"query": "neural networks", "k": 5}'
Files#
app.py - FastAPI search service
batch_embedding_launcher.py - Job launcher
compute_embeddings.py - Embedding generation
embedding_job.yaml - SkyPilot job config
streamlit_app.py - Search UI
Included files#
from contextlib import asynccontextmanager
from datetime import datetime
import logging
import os
from typing import Any, Dict, List, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic import Field
import redis
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
redis_client = None
index = None
model = None
class SearchRequest(BaseModel):
query: str
k: int = Field(default=10, ge=1, le=100)
filters: Optional[Dict[str, Any]] = None
class PaperResult(BaseModel):
id: str
title: str
abstract: str
authors: str
venue: str
year: int
n_citation: int
score: float
class SearchResponse(BaseModel):
results: List[PaperResult]
total: int
time_ms: float
@asynccontextmanager
async def lifespan(app: FastAPI):
global redis_client, index, model
redis_client = redis.Redis(host=os.getenv("REDIS_HOST"),
port=int(os.getenv("REDIS_PORT")),
username=os.getenv("REDIS_USER"),
password=os.getenv("REDIS_PASSWORD"),
decode_responses=True)
redis_client.ping()
logger.info("Connected to Redis")
index = SearchIndex.from_yaml("config/redis_schema_papers.yaml",
redis_client=redis_client)
try:
index.create(overwrite=False)
except Exception as e:
if "already exists" not in str(e).lower():
raise
model = SentenceTransformer(os.getenv("MODEL_NAME"), trust_remote_code=True)
logger.info("Ready")
yield
if redis_client:
redis_client.close()
app = FastAPI(title="Paper Search API", lifespan=lifespan)
app.add_middleware(CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"])
@app.get("/health")
async def health():
try:
redis_client.ping()
return {"status": "healthy", "papers": index.info().get("num_docs", 0)}
except:
return {"status": "unhealthy"}
@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
start = datetime.now()
embedding = model.encode(request.query, normalize_embeddings=True).tolist()
query = VectorQuery(vector=embedding,
vector_field_name="paper_embedding",
return_fields=[
"title", "abstract", "authors", "venue", "year",
"n_citation"
],
num_results=request.k)
if request.filters:
filters = []
for field, value in request.filters.items():
if field == "year" and isinstance(value, dict):
if "min" in value and "max" in value:
filters.append(f"@year:[{value['min']} {value['max']}]")
elif field == "venue":
filters.append(f"@venue:{{{value}}}")
if filters:
query.set_filter(" ".join(filters))
results = index.query(query)
papers = []
for r in results:
abstract = r.get('abstract', '')
if len(abstract) > 200:
abstract = abstract[:200] + "..."
papers.append(
PaperResult(id=r.get('id', ''),
title=r.get('title', ''),
abstract=abstract,
authors=r.get('authors', ''),
venue=r.get('venue', ''),
year=int(r.get('year', 0)),
n_citation=int(r.get('n_citation', 0)),
score=round(1 - float(r.get('vector_score', 0)), 3)))
time_ms = (datetime.now() - start).total_seconds() * 1000
return SearchResponse(results=papers,
total=len(papers),
time_ms=round(time_ms, 2))
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
batch_embedding_launcher.py
import argparse
import os
import sky
def calculate_job_range(total_records, job_rank, total_jobs):
chunk_size = total_records // total_jobs
remainder = total_records % total_jobs
job_start = job_rank * chunk_size + min(job_rank, remainder)
if job_rank < remainder:
chunk_size += 1
return job_start, job_start + chunk_size
def load_env_file(env_file):
env_vars = {}
if not os.path.exists(env_file):
return env_vars
with open(env_file) as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
if line.startswith('export '):
line = line[7:]
if '=' in line:
key, value = line.split('=', 1)
env_vars[key] = value.strip('"').strip("'")
return env_vars
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--total-records', type=int, default=1000000)
parser.add_argument('--num-jobs', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=10240)
parser.add_argument('--env-file', type=str, default='.env')
args = parser.parse_args()
env_vars = load_env_file(args.env_file)
required = ['REDIS_HOST', 'REDIS_PORT', 'REDIS_USER', 'REDIS_PASSWORD']
missing = [var for var in required if var not in env_vars]
if missing:
raise ValueError(f"Missing required environment variables: {missing}")
task = sky.Task.from_yaml('embedding_job.yaml')
for i in range(args.num_jobs):
start, end = calculate_job_range(args.total_records, i, args.num_jobs)
task_envs = task.update_envs({
'JOB_START_IDX': str(start),
'JOB_END_IDX': str(end),
'BATCH_SIZE': str(args.batch_size),
'REDIS_HOST': env_vars['REDIS_HOST'],
'REDIS_PORT': env_vars['REDIS_PORT'],
'REDIS_USER': env_vars['REDIS_USER'],
'REDIS_PASSWORD': env_vars['REDIS_PASSWORD']
})
sky.jobs.launch(task_envs, name=f'embeddings-{start}-{end}')
print(f"Launched job for records {start}-{end}")
print(f"\n{args.num_jobs} jobs launched successfully!")
print("Monitor with: sky jobs queue")
if __name__ == '__main__':
main()
compute_embeddings.py
import argparse
import logging
from queue import Queue
from threading import Thread
import numpy as np
import pandas as pd
import redis
from redisvl.index import SearchIndex
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--input-file', required=True)
parser.add_argument('--schema-file', required=True)
parser.add_argument('--start-idx', type=int, required=True)
parser.add_argument('--end-idx', type=int, required=True)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--model-name',
default='nomic-ai/nomic-embed-text-v2-moe')
parser.add_argument('--redis-host', required=True)
parser.add_argument('--redis-port', type=int, required=True)
parser.add_argument('--redis-user', required=True)
parser.add_argument('--redis-password', required=True)
return parser.parse_args()
def create_redis_client(host, port, username, password):
client = redis.Redis(host=host,
port=port,
username=username,
password=password,
decode_responses=True)
client.ping()
return client
def init_index(schema_file, redis_client):
index = SearchIndex.from_yaml(schema_file, redis_client=redis_client)
try:
index.create(overwrite=False)
logger.info("Created new Redis index")
except Exception as e:
if "already exists" not in str(e).lower():
raise
logger.info("Using existing Redis index")
return index
def process_paper(row, embedding):
return {
'id': f"paper:{row['id']}",
'title': row['title'],
'abstract': row['abstract'],
'authors': row['authors'],
'venue': row['venue'],
'year': safe_int(row['year'], 2000),
'n_citation': safe_int(row['n_citation'], 0),
'paper_embedding': np.array(embedding, dtype=np.float32).tobytes()
}
def safe_int(value, default):
if pd.notna(value) and str(value).isdigit():
return int(value)
return default
def redis_writer(queue, index):
while True:
batch = queue.get()
if batch is None:
break
try:
index.load(batch, id_field='id')
logger.info(f"Streamed {len(batch)} papers to Redis")
except Exception as e:
logger.error(f"Failed to load batch: {e}")
queue.task_done()
def main():
args = parse_args()
df = pd.read_csv(args.input_file,
encoding='utf-8',
encoding_errors='replace')
total_records = len(df)
if args.start_idx >= total_records:
logger.warning(
f"Start index {args.start_idx} >= total records {total_records}")
return
end_idx = min(args.end_idx, total_records)
logger.info(
f"Processing records {args.start_idx}-{end_idx} of {total_records}")
redis_client = create_redis_client(args.redis_host, args.redis_port,
args.redis_user, args.redis_password)
index = init_index(args.schema_file, redis_client)
redis_queue = Queue(maxsize=10)
writer_thread = Thread(target=redis_writer, args=(redis_queue, index))
writer_thread.start()
model = SentenceTransformer(args.model_name, trust_remote_code=True)
df_partition = df.iloc[args.start_idx:end_idx]
batch_docs = []
texts_batch = []
rows_batch = []
with tqdm(total=len(df_partition),
desc=f"Records {args.start_idx}-{end_idx}") as pbar:
for _, row in df_partition.iterrows():
texts_batch.append(f"{row['title']} {row['abstract']}")
rows_batch.append(row)
if len(texts_batch) >= args.batch_size:
embeddings = model.encode(texts_batch,
normalize_embeddings=True)
for row, embedding in zip(rows_batch, embeddings):
batch_docs.append(process_paper(row, embedding))
texts_batch = []
rows_batch = []
pbar.update(args.batch_size)
if len(batch_docs) >= args.batch_size:
redis_queue.put(batch_docs.copy())
batch_docs = []
if texts_batch:
embeddings = model.encode(texts_batch, normalize_embeddings=True)
for row, embedding in zip(rows_batch, embeddings):
batch_docs.append(process_paper(row, embedding))
pbar.update(len(texts_batch))
if batch_docs:
redis_queue.put(batch_docs.copy())
redis_queue.join()
redis_queue.put(None)
writer_thread.join()
logger.info(f"Completed records {args.start_idx}-{end_idx}")
if __name__ == "__main__":
main()
config/redis_schema_papers.yaml
version: "0.1.0"
index:
name: paper_search
prefix: "paper:"
fields:
- name: paper_embedding
type: vector
attrs:
algorithm: HNSW
dims: 768
distance_metric: COSINE
initial_cap: 950
ef_construction: 200
m: 16
- name: title
type: text
attrs:
weight: 2.0
- name: abstract
type: text
attrs:
weight: 1.5
- name: authors
type: text
- name: venue
type: tag
- name: year
type: numeric
attrs:
sortable: true
- name: n_citation
type: numeric
attrs:
sortable: true
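The dims: 768 above must match the output dimension of the embedding model used throughout this example (nomic-ai/nomic-embed-text-v2-moe). A quick optional check, not part of the example files:
from sentence_transformers import SentenceTransformer

# Should print 768, matching `dims: 768` in the schema above.
model = SentenceTransformer("nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True)
print(model.get_sentence_embedding_dimension())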
embedding_job.yaml
name: embeddings-job
resources:
accelerators:
T4: 1 # Cheapest option that works
L4: 1 # Backup if T4 unavailable
A100: 1 # Last resort
any_of:
- use_spot: true # Try spot first
- use_spot: false # Fallback to on-demand
num_nodes: 1
envs:
# These env vars are required; pass them in at launch time (via --env / --env-file)
JOB_START_IDX: ''
JOB_END_IDX: ''
BATCH_SIZE: 10240
MODEL_NAME: "nomic-ai/nomic-embed-text-v2-moe"
REDIS_HOST: ''
REDIS_PORT: ''
REDIS_USER: ''
REDIS_PASSWORD: ''
workdir: .
setup: |
uv pip install -r requirements.txt
sudo apt install unzip -y
# Download and unzip the dataset if not already present
if [ ! -f data/dblp-v10.csv ]; then
echo "Downloading research papers dataset..."
mkdir -p data
curl -L -o data/research-papers-dataset.zip https://www.kaggle.com/api/v1/datasets/download/nechbamohammed/research-papers-dataset
echo "Unzipping dataset..."
unzip -j data/research-papers-dataset.zip -d data/
rm data/research-papers-dataset.zip
echo "Dataset ready at data/dblp-v10.csv"
else
echo "Dataset already exists at data/dblp-v10.csv"
fi
run: |
echo "Processing papers: Records ${JOB_START_IDX}-${JOB_END_IDX}"
python compute_embeddings.py \
--input-file data/dblp-v10.csv \
--schema-file config/redis_schema_papers.yaml \
--start-idx ${JOB_START_IDX} \
--end-idx ${JOB_END_IDX} \
--model-name ${MODEL_NAME} \
--batch-size ${BATCH_SIZE} \
--redis-host ${REDIS_HOST} \
--redis-port ${REDIS_PORT} \
--redis-user ${REDIS_USER} \
--redis-password ${REDIS_PASSWORD}
echo "Job completed: Records ${JOB_START_IDX}-${JOB_END_IDX}"
requirements.txt
redisvl>=0.8.1
redis>=6.4.0
skypilot-nightly[aws,gcp]>=1.0.0.dev20250814
fastapi>=0.116.1
uvicorn>=0.35.0
pydantic>=2.11.7
streamlit>=1.48.1
torch>=2.2.2
transformers>=4.55.0
numpy==1.26.0
sentence-transformers>=5.1.0
pandas>=2.3.1
tqdm>=4.67.1
requests>=2.32.4
einops>=0.8.1
search_api.yaml
name: redisvl-search-api
resources:
infra: gcp
cpus: 4+
memory: 8+
ports:
- 8001 # FastAPI
- 8501 # Streamlit
envs:
REDIS_HOST:
REDIS_PORT:
REDIS_USER:
REDIS_PASSWORD:
MODEL_NAME: "nomic-ai/nomic-embed-text-v2-moe"
workdir: .
setup: |
uv pip install -r requirements.txt
# Pre-download the embedding model
python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('${MODEL_NAME}', trust_remote_code=True)"
run: |
# Start both FastAPI and Streamlit in background
uvicorn app:app --host 0.0.0.0 --port 8001 --log-level info &
streamlit run streamlit_app.py --server.port 8501 --server.address 0.0.0.0 &
wait
streamlit_app.py
import json
import requests
import streamlit as st
st.set_page_config(page_title="Research Paper Search", layout="wide")
API_BASE = "http://localhost:8001"
st.title("Research Paper Search")
st.write("Search through 1M research papers using semantic similarity")
with st.sidebar:
st.header("Filters")
num_results = st.slider("Results", 1, 50, 10)
year_filter = st.checkbox("Filter by year")
if year_filter:
min_year = st.number_input("From", 1900, 2024, 2000)
max_year = st.number_input("To", 1900, 2024, 2024)
venue_filter = st.text_input("Venue", placeholder="exact match")
st.divider()
try:
resp = requests.get(f"{API_BASE}/health", timeout=3)
if resp.status_code == 200:
data = resp.json()
if data["status"] == "healthy":
st.success("System online")
papers = data.get('papers', 0)
if papers:
st.metric("Papers", f"{papers:,}")
else:
st.error("System offline")
else:
st.error("API error")
except:
st.error("Connection failed")
def search_papers(query):
if not query:
return
filters = {}
if year_filter:
filters["year"] = {"min": min_year, "max": max_year}
if venue_filter.strip():
filters["venue"] = venue_filter.strip()
search_data = {
"query": query,
"k": num_results,
"filters": filters if filters else None
}
with st.spinner("Searching..."):
try:
response = requests.post(
f"{API_BASE}/search",
headers={"Content-Type": "application/json"},
data=json.dumps(search_data),
timeout=30)
if response.status_code == 200:
results = response.json()
st.info(
f"Found {results['total']} results ({results['time_ms']:.0f}ms)"
)
if results.get("results"):
for i, paper in enumerate(results["results"], 1):
with st.container():
col1, col2 = st.columns([3, 1])
with col1:
st.subheader(f"{i}. {paper['title']}")
st.write(f"**Authors:** {paper['authors']}")
st.write(paper['abstract'])
with col2:
st.metric("Score", f"{paper['score']:.3f}")
st.write(f"Year: {paper['year']}")
st.write(f"Venue: {paper['venue']}")
st.write(f"Citations: {paper['n_citation']:,}")
st.markdown("---")
else:
st.warning("No results found")
else:
st.error(f"Search failed: {response.status_code}")
except requests.exceptions.Timeout:
st.error("Search timeout")
except Exception as e:
st.error(f"Error: {str(e)}")
with st.form("search_form"):
search_query = st.text_input(
"Search",
placeholder="neural networks, machine learning...",
label_visibility="collapsed")
search_submitted = st.form_submit_button("Search", type="primary")
if search_submitted:
search_papers(search_query)
if not search_query:
st.write("**Try searching for:**")
queries = [
"neural networks", "machine learning", "computer vision",
"natural language processing", "deep learning", "reinforcement learning"
]
cols = st.columns(3)
for i, q in enumerate(queries):
with cols[i % 3]:
if st.button(q, key=f"ex_{i}"):
search_papers(q)
st.divider()
st.caption("Built with RedisVL + SkyPilot | Data from Kaggle Research Papers")