Source: llm/rag
Retrieval Augmented Generation with DeepSeek R1#
For the full blog post, please find it here.
Large-Scale Legal Document Search and Analysis#
As legal document collections grow, traditional keyword search becomes insufficient for finding relevant information. Retrieval-Augmented Generation (RAG) combines the power of vector search with large language models to enable semantic search and intelligent answers.
In particular:
Accuracy: RAG systems ground their responses in source documents, reducing hallucination and improving answer reliability.
Context-Awareness: By retrieving relevant documents before generating answers, the system provides responses that consider specific legal contexts.
Traceability: All generated answers can be traced back to source documents, crucial for legal applications.
SkyPilot streamlines the development and deployment of RAG systems on any cloud or Kubernetes cluster by managing infrastructure and enabling efficient, cost-effective use of compute resources.
In this example, we use legal documents from the Pile of Law dataset as example data to demonstrate RAG capabilities. The system processes a collection of legal texts, including case law, statutes, and legal discussions, to enable semantic search and intelligent question answering. This approach can help legal professionals quickly find relevant precedents, analyze complex legal scenarios, and extract insights from large document collections.
We use Alibaba-NLP/gte-Qwen2-7B-instruct for generating document embeddings and the distilled DeepSeek R1 model (deepseek-ai/DeepSeek-R1-Distill-Llama-8B) for generating final answers.
Why SkyPilot: SkyPilot streamlines the process of running such large-scale jobs in the cloud. It abstracts away much of the complexity of managing infrastructure and helps you run compute-intensive tasks efficiently and cost-effectively through managed jobs.
Step 0: Set Up The Environment#
Install the following prerequisites:
SkyPilot: Ensure SkyPilot is installed and sky check succeeds. See installation instructions.
Set up bucket names for storing embeddings and vector database:
export EMBEDDINGS_BUCKET_NAME=sky-rag-embeddings
export VECTORDB_BUCKET_NAME=sky-rag-vectordb
Note that bucket names must be globally unique, so change these to names of your own (for example, by appending your username).
Step 1: Compute Embeddings from Legal Documents#
Convert legal documents into vector representations using Alibaba-NLP/gte-Qwen2-7B-instruct. These embeddings enable semantic search across the document collection.
Launch the embedding computation:
python3 batch_compute_embeddings.py --embedding_bucket_name $EMBEDDINGS_BUCKET_NAME
Here is how the Python script launches vLLM with Alibaba-NLP/gte-Qwen2-7B-instruct for embedding generation, where each worker processes the documents from START_IDX to END_IDX.
SkyPilot YAML for embedding generation
name: compute-legal-embeddings
resources:
accelerators: {L4:1, A100:1}
envs:
START_IDX: 0 # Will be overridden by batch_compute_embeddings.py
END_IDX: 10000 # Will be overridden by batch_compute_embeddings.py
MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings # Bucket name for storing embeddings
file_mounts:
/output:
name: ${EMBEDDINGS_BUCKET_NAME}
mode: MOUNT
setup: |
pip install torch==2.5.1 vllm==0.6.6.post1
...
run: |
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--model $MODEL_NAME \
--max-model-len 3072 \
--task embed &
python scripts/compute_embeddings.py \
--start-idx $START_IDX \
--end-idx $END_IDX \
--chunk-size 2048 \
--chunk-overlap 512 \
--vllm-endpoint http://localhost:8000
This launches one SkyPilot managed job per data partition (controlled by --num-jobs) on L4 GPUs to process documents from the Pile of Law dataset and compute embeddings in batches:
Processing documents: 100%|██████████| 1000/1000 [00:45<00:00, 22.05it/s]
Saving embeddings to: embeddings_0_1000.parquet
...
We leverage SkyPilot’s managed jobs feature to enable parallel processing across multiple regions and cloud providers.
SkyPilot handles job state management and automatic recovery from failures when using spot instances.
Managed jobs are cost-efficient and streamline the processing of the partitioned dataset.
You can check all the jobs by running sky jobs dashboard.
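Under the hood, each worker's scripts/compute_embeddings.py (included below) sends batches of chunk text to the local vLLM OpenAI-compatible embeddings endpoint. A minimal sketch of such a request, assuming the vLLM server from compute_embeddings.yaml is running locally with the model downloaded to /tmp/model:
import requests

# Sketch only: assumes the vLLM embedding server from compute_embeddings.yaml
# is serving the model downloaded to /tmp/model on localhost:8000.
payload = {
    "model": "/tmp/model",
    "input": ["First chunk of legal text.", "Second chunk of legal text."],
}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload, timeout=60)
resp.raise_for_status()
embeddings = [item["embedding"] for item in resp.json()["data"]]
print(f"Computed {len(embeddings)} embeddings of dimension {len(embeddings[0])}")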
Step 2: Build RAG with Vector Database#
After computing embeddings, construct a ChromaDB vector database for efficient similarity search:
sky launch build_rag.yaml --env EMBEDDINGS_BUCKET_NAME=$EMBEDDINGS_BUCKET_NAME --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME
The process builds the database in batches:
Loading embeddings from: embeddings_0_1000.parquet
Adding vectors to ChromaDB: 100%|██████████| 1000/1000 [00:12<00:00, 81.97it/s]
...
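To sanity-check the result, you can open the persisted collection directly with the ChromaDB client. A minimal sketch, assuming the vector database bucket is mounted at /vectordb as in build_rag.yaml:
import chromadb

# Open the database persisted by scripts/build_rag.py and inspect a few entries.
client = chromadb.PersistentClient(path="/vectordb/chroma")
collection = client.get_collection(name="legal_docs")
print("Documents in collection:", collection.count())

sample = collection.peek(limit=3)
for doc_id, meta in zip(sample["ids"], sample["metadatas"]):
    print(doc_id, meta["source"], meta["name"])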
Step 3: Serve the RAG#
Deploy the RAG service to handle queries and generate answers:
sky launch -c legal-rag serve_rag.yaml --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME
Or use SkyServe for managed deployment:
sky serve up -n legal-rag serve_rag.yaml --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME
To query the system, get the endpoint:
sky serve status legal-rag --endpoint
You can visit the website and input your query there! A few queries to try out:
I want to break my lease. My landlord doesn't allow me to do that.
My employer has not provided the final paycheck after termination.
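You can also query the service programmatically through the /rag endpoint defined in scripts/serve_rag.py. A minimal sketch, where ENDPOINT is a placeholder for the address returned by sky serve status legal-rag --endpoint:
import requests

# Placeholder: replace with the endpoint returned by `sky serve status legal-rag --endpoint`.
ENDPOINT = "http://<your-endpoint>"

payload = {
    "query": "My employer has not provided the final paycheck after termination.",
    "n_results": 3,
    "temperature": 0.7,
}
resp = requests.post(f"{ENDPOINT}/rag", json=payload, timeout=300)
resp.raise_for_status()
result = resp.json()
print("Answer:\n", result["answer"])
for src in result["sources"]:
    print("-", src["name"], f"(similarity {src['similarity']:.2f})")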
Disclaimer#
This document provides instructions for building a RAG system with SkyPilot. The system and its outputs should not be considered legal advice. Please consult qualified legal professionals for any legal matters.
Included files#
batch_compute_embeddings.py
"""
Use skypilot to launch managed jobs that will run the embedding calculation for RAG.
This script is responsible for splitting the input dataset up among several workers,
then using skypilot to launch managed jobs for each worker. We use compute_embeddings.yaml
to define the managed job info.
"""
#!/usr/bin/env python3
import argparse
import sky
def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
total_jobs: int) -> tuple[int, int]:
"""Calculate the range of indices this job should process.
Args:
start_idx: Global start index
end_idx: Global end index
job_rank: Current job's rank (0-based)
total_jobs: Total number of jobs
Returns:
Tuple of [job_start_idx, job_end_idx)
"""
total_range = end_idx - start_idx
chunk_size = total_range // total_jobs
remainder = total_range % total_jobs
# Distribute remainder across first few jobs
job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
if job_rank < remainder:
chunk_size += 1
job_end = job_start + chunk_size
return job_start, job_end
def main():
parser = argparse.ArgumentParser(
description='Launch batch RAG embedding computation jobs')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Global start index in dataset')
parser.add_argument(
'--end-idx',
type=int,
# this is the last index of the reddit post dataset
default=109740,
help='Global end index in dataset, not inclusive')
parser.add_argument('--num-jobs',
type=int,
default=1,
help='Number of jobs to partition the work across')
parser.add_argument("--embedding_bucket_name",
type=str,
default="sky-rag-embeddings",
help="Name of the bucket to store embeddings")
args = parser.parse_args()
# Load the task template
task = sky.Task.from_yaml('compute_embeddings.yaml')
# Launch jobs for each partition
for job_rank in range(args.num_jobs):
# Calculate index range for this job
job_start, job_end = calculate_job_range(args.start_idx, args.end_idx,
job_rank, args.num_jobs)
# Update environment variables for this job
task_copy = task.update_envs({
'START_IDX': job_start,
'END_IDX': job_end,
'EMBEDDINGS_BUCKET_NAME': args.embedding_bucket_name,
})
sky.jobs.launch(task_copy, name=f'rag-compute-{job_start}-{job_end}')
if __name__ == '__main__':
main()
build_rag.yaml
name: build-legal-rag
workdir: .
resources:
memory: 32+ # Need more memory for merging embeddings
cloud: aws
envs:
EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings
VECTORDB_BUCKET_NAME: sky-rag-vectordb
file_mounts:
/embeddings:
name: ${EMBEDDINGS_BUCKET_NAME}
# this needs to be the same as the output in compute_embeddings.yaml
mode: MOUNT
/vectordb:
name: ${VECTORDB_BUCKET_NAME}
mode: MOUNT
setup: |
pip install chromadb pandas tqdm pyarrow
run: |
python scripts/build_rag.py \
--collection-name legal_docs \
--persist-dir /vectordb/chroma \
--embeddings-dir /embeddings \
--batch-size 1000
compute_embeddings.yaml
name: compute-law-embeddings
workdir: .
resources:
accelerators:
L4: 1
memory: 32+
any_of:
- use_spot: true
- use_spot: false
envs:
START_IDX: 0 # Will be overridden by batch_compute_embeddings.py
END_IDX: 10000 # Will be overridden by batch_compute_embeddings.py
MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings # Bucket name for storing embeddings
file_mounts:
/output:
name: ${EMBEDDINGS_BUCKET_NAME}
mode: MOUNT
setup: |
# Install dependencies for vLLM
pip install transformers==4.48.1 vllm==0.6.6.post1
# Install dependencies for embedding computation
pip install numpy pandas requests tqdm datasets
pip install nltk hf_transfer
run: |
# Initialize and download the model
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/model $MODEL_NAME
# Start vLLM service in background
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--model /tmp/model \
--max-model-len 3072 \
--task embed &
# Wait for vLLM to be ready by checking the health endpoint
echo "Waiting for vLLM service to be ready..."
while ! curl -s http://localhost:8000/health > /dev/null; do
sleep 5
echo "Still waiting for vLLM service..."
done
echo "vLLM service is ready!"
# Process the assigned range of documents
echo "Processing documents from $START_IDX to $END_IDX"
python scripts/compute_embeddings.py \
--output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
--start-idx $START_IDX \
--end-idx $END_IDX \
--chunk-size 2048 \
--chunk-overlap 512 \
--vllm-endpoint http://localhost:8000 \
--batch-size 32
# Clean up vLLM service
pkill -f "python -m vllm.entrypoints.openai.api_server"
echo "vLLM service has been stopped"
scripts/build_rag.py
"""
This script is responsible for building the vector database from the mounted bucket and saving
it to another mounted bucket.
"""
import argparse
import base64
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
import glob
import logging
import multiprocessing
import os
import pickle
import shutil
import tempfile
import chromadb
import numpy as np
import pandas as pd
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def list_local_parquet_files(mount_path: str, prefix: str) -> list:
"""List all parquet files in the mounted directory."""
search_path = os.path.join(mount_path, prefix, '**/*.parquet')
parquet_files = glob.glob(search_path, recursive=True)
return parquet_files
def process_parquet_file(args):
"""Process a single parquet file and return the processed data."""
parquet_file, batch_size = args
try:
results = []
df = pd.read_parquet(parquet_file)
# Process in batches
for i in range(0, len(df), batch_size):
batch_df = df.iloc[i:i + batch_size]
# Extract data from DataFrame
ids = [str(idx) for idx in batch_df['id']]
embeddings = [pickle.loads(emb) for emb in batch_df['embedding']]
documents = batch_df['content'].tolist(
) # Content goes to documents
# Create metadata from the available fields (excluding content)
metadatas = [{
'name': row['name'],
'split': row['split'],
'source': row['source'],
} for _, row in batch_df.iterrows()]
results.append((ids, embeddings, documents, metadatas))
return results
except Exception as e:
logger.error(f'Error processing file {parquet_file}: {str(e)}')
return None
def main():
parser = argparse.ArgumentParser(
description='Build ChromaDB from mounted parquet files')
parser.add_argument('--collection-name',
type=str,
default='rag_embeddings',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory to persist ChromaDB')
parser.add_argument(
'--batch-size',
type=int,
default=1000,
help='Batch size for processing, this needs to fit in memory')
parser.add_argument('--embeddings-dir',
type=str,
default='/embeddings',
help='Path to mounted bucket containing parquet files')
parser.add_argument(
'--prefix',
type=str,
default='',
help='Prefix path within mounted bucket to search for parquet files')
args = parser.parse_args()
# Create a temporary directory for building the database. The
# mounted bucket does not support append operation, so build in
# the tmpdir and then copy it to the final location.
with tempfile.TemporaryDirectory() as temp_dir:
logger.info(f'Using temporary directory: {temp_dir}')
# Initialize ChromaDB in temporary directory
client = chromadb.PersistentClient(path=temp_dir)
# Create or get collection for chromadb
# it attempts to create a collection with the same name
# if it already exists, it will get the collection
try:
collection = client.create_collection(
name=args.collection_name,
metadata={'description': 'RAG embeddings from legal documents'})
logger.info(f'Created new collection: {args.collection_name}')
except ValueError:
collection = client.get_collection(name=args.collection_name)
logger.info(f'Using existing collection: {args.collection_name}')
# List parquet files from mounted directory
parquet_files = list_local_parquet_files(args.embeddings_dir,
args.prefix)
logger.info(f'Found {len(parquet_files)} parquet files')
# Process files in parallel
max_workers = max(1,
multiprocessing.cpu_count() - 1) # Leave one CPU free
logger.info(f'Processing files using {max_workers} workers')
with ProcessPoolExecutor(max_workers=max_workers) as executor:
# Submit all files for processing
future_to_file = {
executor.submit(process_parquet_file, (file, args.batch_size)):
file for file in parquet_files
}
# Process results as they complete
for future in tqdm(as_completed(future_to_file),
total=len(parquet_files),
desc='Processing files'):
file = future_to_file[future]
try:
results = future.result()
if results:
for ids, embeddings, documents, metadatas in results:
collection.add(ids=list(ids),
embeddings=list(embeddings),
documents=list(documents),
metadatas=list(metadatas))
except Exception as e:
logger.error(f'Error processing file {file}: {str(e)}')
continue
logger.info('Vector database build complete!')
logger.info(f'Total documents in collection: {collection.count()}')
# Copy the completed database to the final location
logger.info(f'Copying database to final location: {args.persist_dir}')
if os.path.exists(args.persist_dir):
logger.info('Removing existing database directory')
shutil.rmtree(args.persist_dir)
shutil.copytree(temp_dir, args.persist_dir)
logger.info('Database copy complete!')
if __name__ == '__main__':
main()
scripts/compute_embeddings.py
"""
Script to compute embeddings for Pile of Law dataset using `Alibaba-NLP/gte-Qwen2-7B-instruct` through vLLM.
"""
import argparse
import logging
import os
from pathlib import Path
import pickle
import shutil
import time
from typing import Dict, List, Tuple
from datasets import load_dataset
import nltk
import numpy as np
import pandas as pd
import requests
from tqdm import tqdm
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize NLTK to chunk documents
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
logger.info('Downloading NLTK punkt tokenizer...')
nltk.download('punkt')
nltk.download('punkt_tab')
logger.info('Download complete')
def load_law_documents(start_idx: int = 0, end_idx: int = 1000) -> List[Dict]:
"""Load documents from Pile of Law dataset.
Args:
start_idx: Starting index in dataset
end_idx: Ending index in dataset
Returns:
List of documents
"""
dataset = load_dataset('pile-of-law/pile-of-law',
'all',
split='train',
streaming=True,
trust_remote_code=True)
documents = []
for idx, doc in enumerate(
dataset.skip(start_idx).take(end_idx - start_idx)):
documents.append({
'id': f"{idx + start_idx}",
'name': doc['url'],
'text': doc['text'],
'split': 'train',
'source': 'r_legaladvice',
'created_timestamp': doc['created_timestamp'],
'downloaded_timestamp': doc['downloaded_timestamp'],
'url': doc['url']
})
if (idx + 1) % 100 == 0:
logger.info(f'Loaded {idx + 1} documents')
return documents
def chunk_document(document,
chunk_size=512,
overlap=50,
start_chunk_idx=0) -> Tuple[List[Dict], int]:
"""Split document into overlapping chunks using sentence-aware splitting.
Args:
document: The document to chunk
chunk_size: Maximum size of each chunk in characters
overlap: Number of characters to overlap between chunks
start_chunk_idx: Starting index for global chunk counting
Returns:
List of chunks and the next available chunk index
"""
text = document['text']
chunks = []
chunk_idx = start_chunk_idx
# Split into sentences first
sentences = nltk.sent_tokenize(text)
current_chunk = []
current_length = 0
for sentence in sentences:
sentence_len = len(sentence)
# If adding this sentence would exceed chunk size, save current chunk
if current_length + sentence_len > chunk_size and current_chunk:
chunk_text = ' '.join(current_chunk)
chunks.append({
'id': document['id'] + '_' + str(chunk_idx),
'name': document['name'],
'content': document['text'], # Store full document content
'chunk_text': chunk_text.strip(), # Store the specific chunk
'chunk_start': len(' '.join(
current_chunk[:-(2 if overlap > 0 else 0)])) if overlap > 0
else 0, # Approximate start position
'split': document['split'],
'source': document['source'],
'document_id': document['id'],
'document_url': document['url'],
'document_created_timestamp': document['created_timestamp'],
'document_downloaded_timestamp':
document['downloaded_timestamp']
})
chunk_idx += 1
# Keep last few sentences for overlap
overlap_text = ' '.join(current_chunk[-2:]) # Keep last 2 sentences
current_chunk = [overlap_text] if overlap > 0 else []
current_length = len(overlap_text) if overlap > 0 else 0
current_chunk.append(sentence)
current_length += sentence_len + 1 # +1 for space
# Add the last chunk if it's not empty
if current_chunk:
chunk_text = ' '.join(current_chunk)
chunks.append({
'id': document['id'] + '_' + str(chunk_idx),
'name': document['name'],
'content': document['text'], # Store full document content
'chunk_text': chunk_text.strip(), # Store the specific chunk
'chunk_start': len(' '.join(
current_chunk[:-(2 if overlap > 0 else 0)]))
if overlap > 0 else 0, # Approximate start position
'split': document['split'],
'source': document['source'],
'document_id': document['id'],
'document_url': document['url'],
'document_created_timestamp': document['created_timestamp'],
'document_downloaded_timestamp': document['downloaded_timestamp']
})
chunk_idx += 1
return chunks, chunk_idx
def compute_embeddings_batch(chunks: List[Dict],
vllm_endpoint: str,
output_path: str,
batch_size: int = 32,
partition_size: int = 1000) -> None:
"""Compute embeddings for document chunks using DeepSeek R1 and save in partitions.
Args:
chunks: List of document chunks
vllm_endpoint: Endpoint for vLLM service
output_path: Path to save embeddings
"""
current_partition = []
partition_counter = 0
# Process in batches
for i in tqdm(range(0, len(chunks), batch_size),
desc='Computing embeddings'):
batch = chunks[i:i + batch_size]
# Create prompt for each chunk - simplified prompt
prompts = [chunk['content'] for chunk in batch]
try:
# Print request payload for debugging
request_payload = {
"model": "/tmp/model",
# because this is loaded from the mounted directory
"input": prompts
}
response = requests.post(f"{vllm_endpoint}/v1/embeddings",
json=request_payload,
timeout=60)
response.raise_for_status()
# Extract embeddings - updated response parsing
result = response.json()
if 'data' not in result:
raise ValueError(f"Unexpected response format: {result}")
embeddings = [item['embedding'] for item in result['data']]
# Combine embeddings with metadata
for chunk, embedding in zip(batch, embeddings):
current_partition.append({
'id': chunk['id'],
'name': chunk['name'],
'content': chunk['content'],
'chunk_text': chunk['chunk_text'],
'chunk_start': chunk['chunk_start'],
'split': chunk['split'],
'source': chunk['source'],
'embedding': pickle.dumps(np.array(embedding)),
# Include document metadata
'document_id': chunk['document_id'],
'document_url': chunk['document_url'],
'document_created_timestamp':
chunk['document_created_timestamp'],
'document_downloaded_timestamp':
chunk['document_downloaded_timestamp']
})
# Save partition when it reaches the desired size
if len(current_partition) >= partition_size:
save_partition(current_partition, output_path,
partition_counter)
partition_counter += 1
current_partition = []
except Exception as e:
logger.error(f"Error computing embeddings for batch: {str(e)}")
time.sleep(5)
continue
# Save any remaining embeddings in the final partition
if current_partition:
save_partition(current_partition, output_path, partition_counter)
def save_partition(results: List[Dict], output_path: str,
partition_counter: int) -> None:
"""Save a partition of embeddings to a parquet file with atomic write.
Args:
results: List of embeddings
output_path: Path to save embeddings
partition_counter: Partition counter
"""
if not results:
return
df = pd.DataFrame(results)
final_path = f'{output_path}_part_{partition_counter}.parquet'
temp_path = f'/tmp/embeddings_{partition_counter}.tmp'
# Write to temporary file first
df.to_parquet(temp_path, engine='pyarrow', index=False)
# Copy from temp to final destination
os.makedirs(os.path.dirname(final_path), exist_ok=True)
shutil.copy2(temp_path, final_path)
os.remove(temp_path) # Clean up temp file
logger.info(
f'Saved partition {partition_counter} to {final_path} with {len(df)} rows'
)
def main():
parser = argparse.ArgumentParser(
description='Compute embeddings for Pile of Law dataset')
parser.add_argument('--output-path',
type=str,
required=True,
help='Path to save embeddings parquet file')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Starting index in dataset')
parser.add_argument('--end-idx',
type=int,
default=1000,
help='Ending index in dataset')
parser.add_argument('--chunk-size',
type=int,
default=512,
help='Size of document chunks')
parser.add_argument('--chunk-overlap',
type=int,
default=50,
help='Overlap between chunks')
parser.add_argument('--vllm-endpoint',
type=str,
required=True,
help='Endpoint for vLLM service')
parser.add_argument('--batch-size',
type=int,
default=32,
help='Batch size for computing embeddings')
parser.add_argument('--partition-size',
type=int,
default=1000,
help='Number of embeddings per partition file')
args = parser.parse_args()
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# Load documents
logger.info('Loading documents from Pile of Law dataset...')
documents = load_law_documents(args.start_idx, args.end_idx)
logger.info(f'Loaded {len(documents)} documents')
# Chunk documents with global counter
logger.info('Chunking documents...')
chunks = []
next_chunk_idx = 0 # Initialize global chunk counter
for doc in documents:
doc_chunks, next_chunk_idx = chunk_document(doc, args.chunk_size,
args.chunk_overlap,
next_chunk_idx)
chunks.extend(doc_chunks)
logger.info(f'Created {len(chunks)} chunks')
# Compute embeddings and save in partitions
logger.info('Computing embeddings...')
compute_embeddings_batch(chunks, args.vllm_endpoint, args.output_path,
args.batch_size, args.partition_size)
logger.info('Finished computing and saving embeddings')
if __name__ == '__main__':
main()
scripts/serve_rag.py
"""
Script to serve RAG system combining vector search with DeepSeek R1.
"""
import argparse
import logging
import os
import pickle
import time
from typing import Any, Dict, List, Optional
import uuid
import chromadb
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
import numpy as np
from pydantic import BaseModel
import requests
import torch
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title='RAG System with DeepSeek R1')
# Global variables
collection = None
generator_endpoint = None # For text generation
embed_endpoint = None # For embeddings
# Dictionary to store in-progress LLM queries
active_requests = {}
class QueryRequest(BaseModel):
query: str
n_results: Optional[int] = 3
temperature: Optional[float] = 0.7
class DocumentsOnlyRequest(BaseModel):
query: str
n_results: Optional[int] = 3
class StartLLMRequest(BaseModel):
request_id: str
temperature: Optional[float] = 0.7
class SearchResult(BaseModel):
content: str
name: str
split: str
source: str
similarity: float
class RAGResponse(BaseModel):
answer: str
sources: List[SearchResult]
thinking_process: str # Add thinking process to response
class DocumentsOnlyResponse(BaseModel):
sources: List[SearchResult]
request_id: str
class LLMStatusResponse(BaseModel):
status: str # "pending", "completed", "error"
answer: Optional[str] = None
thinking_process: Optional[str] = None
error: Optional[str] = None
def encode_query(query: str) -> np.ndarray:
"""Encode query text using vLLM embeddings endpoint."""
global embed_endpoint
try:
response = requests.post(f"{embed_endpoint}/v1/embeddings",
json={
"model": "/tmp/embedding_model",
"input": [query]
},
timeout=30)
response.raise_for_status()
result = response.json()
if 'data' not in result:
raise ValueError(f"Unexpected response format: {result}")
return np.array(result['data'][0]['embedding'])
except Exception as e:
logger.error(f"Error computing query embedding: {str(e)}")
raise HTTPException(status_code=500,
detail="Error computing query embedding")
def query_collection(query_embedding: np.ndarray,
n_results: int = 10) -> List[SearchResult]:
"""Query the collection and return top matches."""
global collection
# Request more results initially to account for potential duplicates
max_results = min(n_results * 2, 20) # Get more results but cap at 20
results = collection.query(query_embeddings=[query_embedding.tolist()],
n_results=max_results,
include=['metadatas', 'distances', 'documents'])
# Get results
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
# Convert distances to similarities
similarities = [1 - (d / 2) for d in distances]
# Create a set to track unique content
seen_content = set()
unique_results = []
for doc, meta, sim in zip(documents, metadatas, similarities):
# Use content as the uniqueness key
if doc not in seen_content:
seen_content.add(doc)
logger.info(f"Found {meta} with similarity {sim}")
logger.info(f"Content: {doc}")
unique_results.append((doc, meta, sim))
# Break if we have enough unique results
if len(unique_results) >= n_results:
break
return [
SearchResult(content=doc,
name=meta['name'],
split=meta['split'],
source=meta['source'],
similarity=sim) for doc, meta, sim in unique_results
]
def generate_prompt(query: str, context_docs: List[SearchResult]) -> str:
"""Generate prompt for DeepSeek R1."""
# Format context with clear document boundaries
context = "\n\n".join([
f"[Document {i+1} begin]\nSource: {doc.source}\nContent: {doc.content}\n[Document {i+1} end]"
for i, doc in enumerate(context_docs)
])
return f"""# The following contents are search results from legal documents and related discussions:
{context}
You are a helpful AI assistant analyzing legal documents and related content. When responding, please follow these guidelines:
- In the search results provided, each document is formatted as [Document X begin]...[Document X end], where X represents the numerical index of each document.
- Cite your documents using [citation:X] format where X is the document number, placing citations immediately after the relevant information.
- Include citations throughout your response, not just at the end.
- If information comes from multiple documents, use multiple citations like [citation:1][citation:2].
- Not all search results may be relevant - evaluate and use only pertinent information.
- Structure longer responses into clear paragraphs or sections for readability.
- If you cannot find the answer in the provided documents, say so - do not make up information.
- Some documents may be informal discussions or reddit posts - adjust your interpretation accordingly.
- Put citation as much as possible in your response. If you mention two Documents, mention as Document X and Document Y, instead of Document X and Y.
First, explain your thinking process between <think> tags.
Then provide your final answer after the thinking process.
# Question:
{query}
Let's approach this step by step:"""
async def query_llm(prompt: str, temperature: float = 0.7) -> tuple[str, str]:
"""Query DeepSeek R1 through vLLM endpoint and return thinking process and answer."""
global generator_endpoint
try:
response = requests.post(f"{generator_endpoint}/v1/chat/completions",
json={
"model": "/tmp/generation_model",
"messages": [{
"role": "user",
"content": prompt
}],
"temperature": temperature,
"max_tokens": 2048,
"stop": None
},
timeout=120)
response.raise_for_status()
logger.info(f"Response: {response.json()}")
full_response = response.json(
)['choices'][0]['message']['content'].strip()
# Split response into thinking process and answer
parts = full_response.split("</think>")
if len(parts) > 1:
thinking = parts[0].replace("<think>", "").strip()
answer = parts[1].strip()
else:
thinking = ""
answer = full_response
return thinking, answer
except Exception as e:
logger.error(f"Error querying LLM: {str(e)}")
raise HTTPException(status_code=500,
detail="Error querying language model")
@app.post('/rag', response_model=RAGResponse)
async def rag_query(request: QueryRequest):
"""RAG endpoint combining vector search with DeepSeek R1."""
try:
# Encode query
query_embedding = encode_query(request.query)
# Get relevant documents
results = query_collection(query_embedding, request.n_results)
# Generate prompt
prompt = generate_prompt(request.query, results)
# Get LLM response
thinking, answer = await query_llm(prompt, request.temperature)
return RAGResponse(answer=answer,
sources=results,
thinking_process=thinking)
except Exception as e:
logger.error(f"Error processing RAG query: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get('/health')
async def health_check():
"""Health check endpoint."""
return {
'status': 'healthy',
'collection_size': collection.count() if collection else 0
}
@app.get('/', response_class=HTMLResponse)
async def get_search_page():
"""Serve a simple search interface."""
template_path = os.path.join(os.path.dirname(__file__), 'templates',
'index.html')
try:
with open(template_path, 'r') as f:
return f.read()
except FileNotFoundError:
raise HTTPException(
status_code=500,
detail=f"Template file not found at {template_path}")
@app.post('/documents', response_model=DocumentsOnlyResponse)
async def get_documents(request: DocumentsOnlyRequest):
"""Get relevant documents for a query without LLM processing."""
try:
# Encode query
query_embedding = encode_query(request.query)
# Get relevant documents
results = query_collection(query_embedding, request.n_results)
# Generate a unique request ID
request_id = str(uuid.uuid4())
# Store the request data for later LLM processing
active_requests[request_id] = {
"query": request.query,
"results": results,
"status": "documents_ready",
"timestamp": time.time()
}
# Clean up old requests (older than 30 minutes)
current_time = time.time()
expired_requests = [
req_id for req_id, data in active_requests.items()
if current_time - data["timestamp"] > 1800
]
for req_id in expired_requests:
active_requests.pop(req_id, None)
return DocumentsOnlyResponse(sources=results, request_id=request_id)
except Exception as e:
logger.error(f"Error retrieving documents: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post('/process_llm', response_model=LLMStatusResponse)
async def process_llm(request: StartLLMRequest):
"""Process a query with the LLM using previously retrieved documents."""
request_id = request.request_id
# Check if the request exists and is ready for LLM processing
if request_id not in active_requests or active_requests[request_id][
"status"] != "documents_ready":
raise HTTPException(status_code=404,
detail="Request not found or documents not ready")
# Mark the request as in progress
active_requests[request_id]["status"] = "llm_processing"
try:
# Get stored data
query = active_requests[request_id]["query"]
results = active_requests[request_id]["results"]
# Generate prompt
prompt = generate_prompt(query, results)
# Get LLM response
thinking, answer = await query_llm(prompt, request.temperature)
# Store the response and mark as completed
active_requests[request_id]["status"] = "completed"
active_requests[request_id]["thinking"] = thinking
active_requests[request_id]["answer"] = answer
active_requests[request_id]["timestamp"] = time.time()
return LLMStatusResponse(status="completed",
answer=answer,
thinking_process=thinking)
except Exception as e:
# Mark as error
active_requests[request_id]["status"] = "error"
active_requests[request_id]["error"] = str(e)
active_requests[request_id]["timestamp"] = time.time()
logger.error(f"Error processing LLM request: {str(e)}")
return LLMStatusResponse(status="error", error=str(e))
@app.get('/llm_status/{request_id}', response_model=LLMStatusResponse)
async def get_llm_status(request_id: str):
"""Get the status of an LLM request."""
if request_id not in active_requests:
raise HTTPException(status_code=404, detail="Request not found")
request_data = active_requests[request_id]
if request_data["status"] == "completed":
return LLMStatusResponse(status="completed",
answer=request_data["answer"],
thinking_process=request_data["thinking"])
elif request_data["status"] == "error":
return LLMStatusResponse(status="error",
error=request_data.get("error",
"Unknown error"))
else:
return LLMStatusResponse(status="pending")
def main():
parser = argparse.ArgumentParser(description='Serve RAG system')
parser.add_argument('--host',
type=str,
default='0.0.0.0',
help='Host to serve on')
parser.add_argument('--port',
type=int,
default=8000,
help='Port to serve on')
parser.add_argument('--collection-name',
type=str,
default='legal_docs',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory where ChromaDB is persisted')
parser.add_argument('--generator-endpoint',
type=str,
required=True,
help='Endpoint for text generation service')
parser.add_argument('--embed-endpoint',
type=str,
required=True,
help='Endpoint for embeddings service')
args = parser.parse_args()
# Initialize global variables
global collection, generator_endpoint, embed_endpoint
# Set endpoints
generator_endpoint = args.generator_endpoint.rstrip('/')
embed_endpoint = args.embed_endpoint.rstrip('/')
# Initialize ChromaDB
logger.info(f'Connecting to ChromaDB at {args.persist_dir}')
client = chromadb.PersistentClient(path=args.persist_dir)
try:
collection = client.get_collection(name=args.collection_name)
logger.info(f'Connected to collection: {args.collection_name}')
logger.info(f'Total documents in collection: {collection.count()}')
except ValueError as e:
logger.error(f'Error: {str(e)}')
logger.error(
'Make sure the collection exists and the persist_dir is correct.')
raise
# Start server
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == '__main__':
main()
scripts/templates/index.html
<!DOCTYPE html>
<html>
<head>
<title>SkyPilot Legal RAG</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 1.6;
background-color: #f5f5f5;
color: #333;
min-height: 100vh;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
.search-container {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
text-align: center;
}
h1 {
color: #2c3e50;
margin-bottom: 1.5rem;
font-size: 2.5rem;
}
.search-box {
display: flex;
gap: 10px;
max-width: 800px;
margin: 0 auto;
}
input {
flex: 1;
padding: 12px 20px;
border: 2px solid #e0e0e0;
border-radius: 25px;
font-size: 16px;
transition: all 0.3s ease;
}
input:focus {
outline: none;
border-color: #3498db;
box-shadow: 0 0 5px rgba(52, 152, 219, 0.3);
}
button {
padding: 12px 30px;
background: #3498db;
color: white;
border: none;
border-radius: 25px;
cursor: pointer;
font-size: 16px;
transition: background 0.3s ease;
}
button:hover {
background: #2980b9;
}
.results-container {
display: grid;
gap: 2rem;
}
.result-section {
background: white;
border-radius: 10px;
padding: 1.5rem;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.section-title {
color: #2c3e50;
margin-bottom: 1rem;
font-size: 1.5rem;
border-bottom: 2px solid #e0e0e0;
padding-bottom: 0.5rem;
}
.source-document {
background: #f8f9fa;
padding: 1rem;
margin-bottom: 1rem;
border-radius: 5px;
border-left: 4px solid #3498db;
white-space: pre-wrap;
}
.source-header {
font-weight: bold;
color: #2c3e50;
margin-bottom: 0.5rem;
}
.source-url {
color: #3498db;
text-decoration: underline;
word-break: break-all;
margin-bottom: 0.5rem;
}
.thinking-process {
background: #fff3e0;
padding: 1rem;
border-radius: 5px;
border-left: 4px solid #ff9800;
white-space: pre-wrap;
}
.final-answer {
background: #e8f5e9;
padding: 1rem;
border-radius: 5px;
border-left: 4px solid #4caf50;
white-space: pre-wrap;
}
#loading {
display: none;
text-align: center;
margin: 2rem 0;
font-size: 1.2rem;
color: #666;
}
.similarity-score {
color: #666;
font-size: 0.9rem;
margin-top: 0.5rem;
}
.citation {
color: #3498db;
cursor: pointer;
text-decoration: underline;
}
.citation:hover {
color: #2980b9;
}
.highlighted-source {
animation: highlight 2s;
}
@keyframes highlight {
0% { background-color: #fff3cd; }
100% { background-color: #f8f9fa; }
}
.disclaimer {
color: #666;
font-size: 1rem;
margin-bottom: 2rem;
font-style: italic;
}
.thinking-indicator {
text-align: center;
background: #e0f7fa;
padding: 1.5rem;
border-radius: 10px;
margin-bottom: 1rem;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
border-left: 4px solid #00acc1;
}
.thinking-indicator h3 {
margin-bottom: 0.5rem;
color: #00838f;
}
.loading-bar {
height: 6px;
width: 200px;
background-color: #e0e0e0;
border-radius: 3px;
overflow: hidden;
position: relative;
}
.loading-bar::after {
content: '';
position: absolute;
left: -100%;
height: 100%;
width: 50%;
background-color: #00acc1;
animation: loading 1.5s infinite ease-in-out;
}
@keyframes loading {
0% { left: -50% }
100% { left: 150% }
}
.github-corner {
position: absolute;
top: 0;
right: 0;
border: 0;
z-index: 999;
}
.github-corner svg {
fill: #24292e;
color: #fff;
width: 80px;
height: 80px;
}
.github-corner:hover .octo-arm {
animation: octocat-wave 560ms ease-in-out;
}
@keyframes octocat-wave {
0%, 100% { transform: rotate(0); }
20%, 60% { transform: rotate(-25deg); }
40%, 80% { transform: rotate(10deg); }
}
@media (max-width: 500px) {
.github-corner svg {
width: 60px;
height: 60px;
}
.github-corner:hover .octo-arm {
animation: none;
}
.github-corner .octo-arm {
animation: octocat-wave 560ms ease-in-out;
}
}
</style>
</head>
<body>
<!-- GitHub Corner -->
<a href="https://github.com/skypilot-org/skypilot/tree/master/llm/rag" class="github-corner" aria-label="Star us on GitHub">
<svg viewBox="0 0 250 250" aria-hidden="true">
<path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path>
<path d="M128.3,109.0 C113.8,99.7 119.0,89.6 119.0,89.6 C122.0,82.7 120.5,78.6 120.5,78.6 C119.2,72.0 123.4,76.3 123.4,76.3 C127.3,80.9 125.5,87.3 125.5,87.3 C122.9,97.6 130.6,101.9 134.4,103.2" fill="currentColor" style="transform-origin: 130px 106px;" class="octo-arm"></path>
<path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z" fill="currentColor" class="octo-body"></path>
</svg>
</a>
<div class="container">
<div class="search-container">
<h1>SkyPilot Legal RAG</h1>
<div class="search-box">
<input type="text" id="searchInput" placeholder="Ask a question about legal documents..."
onkeypress="if(event.key === 'Enter') search()">
<button onclick="search()">Ask</button>
</div>
</div>
<div id="document-search-indicator" class="thinking-indicator" style="display: none;">
<h3>Searching for documents...</h3>
<div class="loading-bar"></div>
</div>
<div id="thinking-indicator" class="thinking-indicator" style="display: none;">
<h3>DeepSeek is thinking...</h3>
<div class="loading-bar"></div>
</div>
<div id="results" class="results-container"></div>
</div>
<script>
function escapeHtml(unsafe) {
return unsafe
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
function highlightSource(docNumber) {
// Remove previous highlights
document.querySelectorAll('.highlighted-source').forEach(el => {
el.classList.remove('highlighted-source');
});
// Add highlight to clicked source
const sourceElement = document.querySelector(`[data-doc-number="${docNumber}"]`);
if (sourceElement) {
sourceElement.classList.add('highlighted-source');
sourceElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
}
}
function processCitations(text) {
// Handle both [citation:X] and Document X formats
return text
.replace(/\[citation:(\d+)\]/g, (match, docNumber) => {
return `<span class="citation" onclick="highlightSource(${docNumber})">[${docNumber}]</span>`;
})
.replace(/Document (\d+)/g, (match, docNumber) => {
return `<span class="citation" onclick="highlightSource(${docNumber})">Document ${docNumber}</span>`;
});
}
async function search() {
const searchInput = document.getElementById('searchInput');
const resultsDiv = document.getElementById('results');
const thinkingIndicator = document.getElementById('thinking-indicator');
const documentSearchIndicator = document.getElementById('document-search-indicator');
if (!searchInput.value.trim()) return;
// Clear previous results and show document search indicator
resultsDiv.innerHTML = '';
documentSearchIndicator.style.display = 'flex';
thinkingIndicator.style.display = 'none';
// Step 1: Get documents first
try {
// First call to get the documents
const docsResponse = await fetch('/documents', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json'
},
body: JSON.stringify({
query: searchInput.value.trim(),
n_results: 10
})
});
if (!docsResponse.ok) {
const errorData = await docsResponse.json();
throw new Error(errorData.detail || 'Failed to retrieve documents');
}
const docsResult = await docsResponse.json();
const requestId = docsResult.request_id;
// Hide document search indicator and show DeepSeek indicator
documentSearchIndicator.style.display = 'none';
thinkingIndicator.style.display = 'flex';
// Display the documents first
let sourcesHtml = '<div class="result-section"><h2 class="section-title">Source Documents</h2>';
docsResult.sources.forEach((source, index) => {
sourcesHtml += `
<div class="source-document" data-doc-number="${index + 1}">
<div class="source-header">Source: ${escapeHtml(source.source)}</div>
<div class="source-url">URL: ${escapeHtml(source.name)}</div>
<div>${escapeHtml(source.content)}</div>
<div class="similarity-score">Similarity: ${(source.similarity * 100).toFixed(1)}%</div>
</div>
`;
});
sourcesHtml += '</div>';
// Display sources
resultsDiv.innerHTML = sourcesHtml;
// Step 2: Start the LLM reasoning process in the background
const llmResponse = await fetch('/process_llm', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json'
},
body: JSON.stringify({
request_id: requestId,
temperature: 0.7
})
});
if (!llmResponse.ok) {
const errorData = await llmResponse.json();
throw new Error(errorData.detail || 'LLM processing failed');
}
const llmResult = await llmResponse.json();
// Handle different response statuses
if (llmResult.status === "completed") {
// Hide thinking indicator
thinkingIndicator.style.display = 'none';
// Add thinking process and answer at the top
const thinkingHtml = `
<div class="result-section">
<h2 class="section-title">Thinking Process</h2>
<div class="thinking-process">${processCitations(escapeHtml(llmResult.thinking_process))}</div>
</div>
`;
const answerHtml = `
<div class="result-section">
<h2 class="section-title">Final Answer</h2>
<div class="final-answer">${processCitations(escapeHtml(llmResult.answer)).replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')}</div>
</div>
`;
// Insert before the sources section
const sourcesSection = document.querySelector('.result-section');
sourcesSection.insertAdjacentHTML('beforebegin', answerHtml + thinkingHtml);
} else if (llmResult.status === "error") {
// Handle error case
thinkingIndicator.style.display = 'none';
resultsDiv.insertAdjacentHTML('afterbegin', `
<div class="result-section" style="color: #e74c3c;">
<h2 class="section-title">Error</h2>
<p>${llmResult.error || "An error occurred while processing the query"}</p>
</div>
`);
} else {
// Handle if status is still pending (should not happen with direct call)
pollForResults(requestId);
}
} catch (error) {
documentSearchIndicator.style.display = 'none';
thinkingIndicator.style.display = 'none';
resultsDiv.innerHTML = `
<div class="result-section" style="color: #e74c3c;">
<h2 class="section-title">Error</h2>
<p>${error.message}</p>
</div>
`;
}
}
// Function to poll for results if needed
async function pollForResults(requestId) {
const maxAttempts = 60; // 5 minutes at 5-second intervals
let attempts = 0;
const thinkingIndicator = document.getElementById('thinking-indicator');
const poll = async () => {
if (attempts >= maxAttempts) {
thinkingIndicator.style.display = 'none';
const errorHtml = `
<div class="result-section" style="color: #e74c3c;">
<h2 class="section-title">Timeout</h2>
<p>Request timed out after 5 minutes. Please try again.</p>
</div>
`;
document.getElementById('results').insertAdjacentHTML('afterbegin', errorHtml);
return;
}
attempts++;
try {
const response = await fetch(`/llm_status/${requestId}`);
if (!response.ok) {
throw new Error("Failed to retrieve status");
}
const result = await response.json();
if (result.status === "completed") {
// Hide thinking indicator
thinkingIndicator.style.display = 'none';
// Add thinking process and answer
const thinkingHtml = `
<div class="result-section">
<h2 class="section-title">Thinking Process</h2>
<div class="thinking-process">${processCitations(escapeHtml(result.thinking_process))}</div>
</div>
`;
const answerHtml = `
<div class="result-section">
<h2 class="section-title">Final Answer</h2>
<div class="final-answer">${processCitations(escapeHtml(result.answer)).replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')}</div>
</div>
`;
// Insert at the beginning of results
const sourcesSection = document.querySelector('.result-section');
sourcesSection.insertAdjacentHTML('beforebegin', answerHtml + thinkingHtml);
} else if (result.status === "error") {
// Handle error
thinkingIndicator.style.display = 'none';
const errorHtml = `
<div class="result-section" style="color: #e74c3c;">
<h2 class="section-title">Error</h2>
<p>${result.error || "An error occurred during processing"}</p>
</div>
`;
document.getElementById('results').insertAdjacentHTML('afterbegin', errorHtml);
} else {
// Still processing, wait and try again
setTimeout(poll, 5000); // Check again after 5 seconds
}
} catch (error) {
console.error("Error polling for results:", error);
setTimeout(poll, 5000); // Try again after 5 seconds
}
};
// Start polling
poll();
}
</script>
</body>
</html>
serve_rag.yaml
name: serve-legal-rag
workdir: .
resources:
accelerators: {L4:4, L40S:4}
memory: 32+
ports:
- 8000
any_of:
- use_spot: true
- use_spot: false
envs:
EMBEDDING_MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
GENERATION_MODEL_NAME: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
VECTORDB_BUCKET_NAME: sky-rag-vectordb
VECTORDB_BUCKET_ROOT: /vectordb
file_mounts:
${VECTORDB_BUCKET_ROOT}:
name: ${VECTORDB_BUCKET_NAME}
# this needs to be the same as in build_rag.yaml
mode: MOUNT
setup: |
# Install dependencies for RAG service
pip install numpy pandas sentence-transformers requests tqdm
pip install fastapi uvicorn pydantic chromadb
# Install dependencies for vLLM
pip install transformers==4.48.1 vllm==0.6.6.post1 hf_transfer
run: |
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/generation_model $GENERATION_MODEL_NAME
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/embedding_model $EMBEDDING_MODEL_NAME
# Start vLLM generation service in background
CUDA_VISIBLE_DEVICES=0,1,2 python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8002 \
--model /tmp/generation_model \
--max-model-len 28816 \
--tensor-parallel-size 2 \
--task generate &
# Wait for vLLM to start
echo "Waiting for vLLM service to be ready..."
while ! curl -s http://localhost:8002/health > /dev/null; do
sleep 5
echo "Still waiting for vLLM service..."
done
echo "vLLM service is ready!"
# Start vLLM embeddings service in background
CUDA_VISIBLE_DEVICES=3 python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8003 \
--model /tmp/embedding_model \
--max-model-len 4096 \
--task embed &
# Wait for vLLM embeddings service to start
echo "Waiting for vLLM embeddings service to be ready..."
while ! curl -s http://localhost:8003/health > /dev/null; do
sleep 5
echo "Still waiting for vLLM embeddings service..."
done
echo "vLLM embeddings service is ready!"
# Start RAG service
python scripts/serve_rag.py \
--collection-name legal_docs \
--persist-dir /vectordb/chroma \
--generator-endpoint http://localhost:8002 \
--embed-endpoint http://localhost:8003
service:
replicas: 1
readiness_probe:
path: /health