Source: llm/batch_inference
Large-Scale AI Batch Inference: 9x Faster Embedding Generation#
Large-Scale Embedding Generation for Text#
As the volume of text data grows, the need for efficient and powerful embedding generation methods becomes critical. Embedding generation is at the heart of modern AI applications, from recommendation systems to retrieval-augmented generation (RAG).
However, running batch inference for embedding generation on million- to trillion-scale datasets is not trivial. Even after spending days developing a well-tuned batch inference script, acquiring hundreds of GPUs and running jobs on them in parallel remains a huge pain.
In particular:
Scalability: Modern applications can involve millions or billions of text records, making traditional approaches slow or impractical.
Availability: GPU quota and availability constraints in a single region severely limit processing capacity.
Cost: On-demand GPU instances for large-scale processing can be prohibitively expensive.
SkyPilot streamlines the process of running such large-scale jobs in the cloud. It abstracts away much of the complexity of managing infrastructure and helps you run compute-intensive tasks efficiently and cost-effectively through managed jobs across multiple regions.
Performance Highlights#
By leveraging SkyPilot’s multi-region approach, we achieved:
9x More Resources: Access to 406 GPUs across 12 regions (vs. only ~47 in a single region)
10x Faster Processing: Reduced processing time from 20+ hours to just 2 hours
61% Cost Reduction: Lowered costs from $710 to $277.07 through spot instance usage
Enhanced Reliability: Automatic recovery from spot instance preemptions
Compute Embeddings from Text Data with LLM Models#
You need to convert text into vector representations (embeddings) so they can be stored in a vector database.
We use the book partition of the Amazon Reviews 2023 dataset, containing ~30M Amazon book reviews, and generate embeddings for the reviews with Alibaba-NLP/gte-Qwen2-7B-instruct, a state-of-the-art specialized embedding LLM and one of the top embedding models on the MTEB leaderboard.
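Each worker serves the model with vLLM and requests embeddings through vLLM's OpenAI-compatible API (see compute_text_vectors.yaml below). As a rough illustration of what a single embedding request looks like, here is a minimal sketch; it is not one of the included files, and the endpoint URL and /tmp/model path are assumptions that simply mirror the worker YAML:
# Minimal sketch: request embeddings from a vLLM server started with `--task embed`,
# using its OpenAI-compatible /v1/embeddings endpoint.
import requests

VLLM_ENDPOINT = "http://localhost:8000"  # assumed address of the local vLLM server


def embed(texts):
    """Return one embedding vector per input text."""
    resp = requests.post(
        f"{VLLM_ENDPOINT}/v1/embeddings",
        json={"model": "/tmp/model", "input": texts},  # model path mirrors the YAML
        timeout=60,
    )
    resp.raise_for_status()
    return [item["embedding"] for item in resp.json()["data"]]


vectors = embed(["Great book, highly recommended!"])
print(len(vectors[0]))  # embedding dimensionality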
Use the following command to launch a job that processes your text dataset and computes embeddings:
python3 batch_compute_vectors.py
This will automatically find available machines across multiple regions to compute the vectors. The script partitions the workload evenly using a stride approach, ensuring each worker processes documents that are distributed throughout the dataset.
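The stride approach itself is simple: with N workers, worker r handles exactly the items whose index i satisfies (i - global_start) % N == r, so each worker's documents are spread evenly across the dataset instead of forming one contiguous slice. A minimal sketch of the selection logic (illustrative only; the names mirror the WORKER_RANK and TOTAL_WORKERS environment variables used by the worker task below):
# Minimal sketch of strided partitioning: worker `rank` out of `total` workers
# processes every `total`-th item, starting at its own offset.
def strided_indices(global_start, global_end, rank, total):
    """Yield the dataset indices assigned to this worker under stride partitioning."""
    for idx in range(global_start + rank, global_end, total):
        yield idx


# Example: 4 workers over indices [0, 10) -> worker 1 gets 1, 5, 9.
print(list(strided_indices(0, 10, 1, 4)))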
Monitor the progress#
You can use sky jobs queue
and sky jobs dashboard
to see the status of the jobs. Alternatively, you can launch a monitoring dashboard with
sky launch -n monitor monitor_progress.yaml
get its IP address with
export ENDPOINT=$(sky status --ip monitor)
and open http://$ENDPOINT:8000
in your browser.
Learn More#
For a complete walkthrough of this case study, including detailed performance metrics and implementation insights, read our blog post on large-scale embedding generation.
Included files#
batch_compute_vectors.py
"""
Use skypilot to launch managed jobs that will run the embedding calculation.
This script is responsible for:
1. Launching a monitoring service cluster
2. Splitting the input dataset among several workers
3. Launching worker clusters with unique worker IDs
"""
#!/usr/bin/env python3
import argparse
import os
import time
import uuid
import sky
def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
total_jobs: int) -> tuple[int, int]:
"""Calculate the range of indices this job should process.
Args:
start_idx: Global start index
end_idx: Global end index
job_rank: Current job's rank (0-based)
total_jobs: Total number of jobs
Returns:
Tuple of [job_start_idx, job_end_idx)
"""
total_range = end_idx - start_idx
chunk_size = total_range // total_jobs
remainder = total_range % total_jobs
# Distribute remainder across first few jobs
job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
if job_rank < remainder:
chunk_size += 1
job_end = job_start + chunk_size
return job_start, job_end
def main():
parser = argparse.ArgumentParser(
        description='Launch batch embedding generation jobs')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Global start index in dataset')
parser.add_argument(
'--end-idx',
type=int,
default=29475453,
#29475453 is the last index of the review dataset
help='Global end index in dataset, not inclusive')
parser.add_argument('--num-jobs',
type=int,
default=500,
help='Number of jobs to partition the work across')
parser.add_argument('--skip-monitor',
action='store_true',
help='Skip launching the monitoring service')
parser.add_argument('--bucket-name',
type=str,
default='sky-embeddings',
help='Name of the bucket to store embeddings')
parser.add_argument(
'--partition-method',
type=str,
choices=['chunk', 'stride'],
default='stride',
help=
'Method to partition data: chunk (contiguous) or stride (interleaved)')
args = parser.parse_args()
# Load the worker task template
task = sky.Task.from_yaml('compute_text_vectors.yaml')
# Launch jobs for each partition
for job_rank in range(args.num_jobs):
# Create a unique worker ID
worker_id = f"worker_{job_rank}"
# Update environment variables based on partition method
env_vars = {
'WORKER_ID': worker_id,
'PARTITION_METHOD': args.partition_method,
'WORKER_RANK': str(job_rank),
'TOTAL_WORKERS': str(args.num_jobs),
'GLOBAL_START_IDX': str(args.start_idx),
'GLOBAL_END_IDX': str(args.end_idx),
}
# If using chunk method, also provide start_idx and end_idx
if args.partition_method == 'chunk':
job_start, job_end = calculate_job_range(args.start_idx,
args.end_idx, job_rank,
args.num_jobs)
env_vars['START_IDX'] = str(job_start)
env_vars['END_IDX'] = str(job_end)
job_name = f'vector-compute-{job_start}-{job_end}'
else:
# For stride method, we use the global start/end and let the worker handle striding
job_name = f'vector-compute-worker-{job_rank}'
task_copy = task.update_envs(env_vars)
print(f"Launching {job_name} with {worker_id}...")
sky.jobs.launch(
task_copy,
name=job_name,
)
if __name__ == '__main__':
main()
compute_text_vectors.yaml
name: batch-inference-compute-text-vectors
workdir: .
resources:
cpus: 4
accelerators:
L4: 1
cloud: aws
any_of:
- use_spot: true
- use_spot: false
envs:
START_IDX: 0 # Will be overridden by batch launcher script
END_IDX: 10000 # Will be overridden by batch launcher script
MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
DATASET_NAME: "McAuley-Lab/Amazon-Reviews-2023"
DATASET_CONFIG: "raw_review_Books"
EMBEDDINGS_BUCKET_NAME: sky-text-embeddings
WORKER_ID: ''
file_mounts:
/output:
name: ${EMBEDDINGS_BUCKET_NAME}
mode: MOUNT
setup: |
# Install dependencies for vLLM
pip install transformers==4.48.1 vllm==0.6.6.post1
# Install dependencies for embedding computation
pip install numpy pandas requests tqdm datasets nltk
pip install torch torchvision aiohttp
pip install hf_transfer pyarrow
run: |
# Initialize and download the model
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/model $MODEL_NAME
# Create metrics directory for monitoring service
mkdir -p /output/metrics
# Set worker ID for metrics tracking
if [ -z "$WORKER_ID" ]; then
export WORKER_ID="worker_$(date +%s)_$(hostname)"
echo "Generated worker ID: $WORKER_ID"
fi
# Start vLLM service in background with token counting enabled
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--model /tmp/model \
--max-model-len 3072 \
--task embed > /dev/null 2>&1 &
# Wait for vLLM service to be ready by checking the health endpoint
echo "Waiting for vLLM service to be ready..."
while ! curl -s http://localhost:8000/health > /dev/null; do
sleep 5
echo "Still waiting for vLLM service..."
done
echo "vLLM service is ready!"
# Process the assigned range of documents
echo "Processing documents from $START_IDX to $END_IDX"
# Process text documents and track token metrics
python scripts/text_vector_processor.py \
--output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
--start-idx $START_IDX \
--end-idx $END_IDX \
--chunk-size 512 \
--chunk-overlap 50 \
--vllm-endpoint http://localhost:8000 \
--batch-size 32 \
--model-name /tmp/model \
--dataset-name $DATASET_NAME \
--dataset-config $DATASET_CONFIG
# Print tokens statistics summary from metrics
echo "Embedding generation complete. Token statistics saved to metrics."
# Clean up vLLM service
pkill -f "python -m vllm.entrypoints.openai.api_server"
echo "vLLM service has been stopped"
monitor_progress.yaml
name: batch-inference-monitor-progress
workdir: .
resources:
cpus: 2
memory: 8+
cloud: aws
ports:
- 8000
envs:
  # make sure this is the same as EMBEDDINGS_BUCKET_NAME in compute_text_vectors.yaml
EMBEDDINGS_BUCKET_NAME: sky-text-embeddings
file_mounts:
/output:
name: ${EMBEDDINGS_BUCKET_NAME}
    # this needs to be the same as the bucket in compute_text_vectors.yaml
mode: MOUNT
store: s3
setup: |
pip install fastapi uvicorn aiofiles
pip install pandas pyarrow plotly
run: |
python scripts/monitor_progress.py --metrics-dir /output/metrics
scripts/base_vector_processor.py
from abc import ABC
from abc import abstractmethod
import asyncio
import json
import logging
import os
from pathlib import Path
import pickle
import time
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
class BaseVectorProcessor(ABC):
"""Base class for processing data and computing vector embeddings.
This abstract class provides a common framework for computing embeddings
for different types of data (images, text) and datasets.
"""
def __init__(self,
output_path: str,
dataset_name: str,
split: str = 'train',
streaming: bool = True,
batch_size: int = 32,
checkpoint_size: int = 100,
start_idx: int = 0,
end_idx: Optional[int] = None,
max_preprocessing_tasks: int = 10):
"""Initialize the base vector processor.
Args:
output_path: Path to save the computed vectors
dataset_name: Name of the dataset to process
split: Dataset split to use
streaming: Whether to stream the dataset
batch_size: Size of batches for processing
checkpoint_size: Number of items to process before saving
start_idx: Starting index in the dataset
end_idx: Ending index in the dataset
max_preprocessing_tasks: Maximum number of concurrent preprocessing tasks
"""
self.output_path = Path(output_path) # Convert to Path object
self.batch_size = batch_size
self.checkpoint_size = checkpoint_size
self.start_idx = start_idx
self.end_idx = end_idx
self._current_batch = []
# Dataset attributes
self.dataset_name = dataset_name
self.split = split
self.streaming = streaming
# Model attributes
self.model = None
self.partition_counter = 0
# Control parallel preprocessing
self.preprocessing_semaphore = asyncio.Semaphore(
max_preprocessing_tasks)
# Progress tracking
self.metrics_path = Path(output_path).parent / 'metrics'
self.metrics_path.mkdir(exist_ok=True)
self.worker_id = os.getenv('WORKER_ID', 'unknown')
self.metrics_file = self.metrics_path / f'worker_{self.worker_id}.json'
self.metrics_history_file = self.metrics_path / f'worker_{self.worker_id}_history.json'
self.processed_count = 0
self.failed_count = 0
self.start_time = time.time()
self.last_update_time = self.start_time
self.session_id = f"{self.worker_id}_{int(self.start_time)}"
# Load existing history if available
self.metrics_history = self._load_metrics_history()
@abstractmethod
async def setup_model(self):
"""Set up the model for computing embeddings."""
pass
@abstractmethod
async def get_dataset_iterator(self) -> AsyncIterator[Tuple[int, Any]]:
"""Get an iterator over the dataset."""
pass
@abstractmethod
async def _preprocess_input(self, item: Any) -> Optional[Any]:
"""Preprocess a single input item."""
pass
@abstractmethod
async def do_batch_processing(self,
batch: List[Any]) -> List[Tuple[int, Any]]:
"""Process a batch of preprocessed inputs."""
pass
@abstractmethod
def save_results_to_parquet(self, results: List):
"""Save results to parquet file."""
pass
async def do_data_loading(self) -> AsyncIterator[Tuple[int, Any]]:
"""Load and preprocess inputs in parallel."""
if self.model is None:
await self.setup_model()
preprocessing_tasks = []
buffer_size = self.batch_size * 2
async for idx, item in self.get_dataset_iterator():
# Clean up completed tasks when buffer is full
if len(preprocessing_tasks) >= buffer_size:
done, pending = await asyncio.wait(
preprocessing_tasks, return_when=asyncio.FIRST_COMPLETED)
preprocessing_tasks = list(pending)
for task in done:
result = await task
if result is not None:
yield result
# Start new preprocessing task
async def preprocess_with_index(idx, item):
async with self.preprocessing_semaphore:
processed = await self._preprocess_input(item)
if processed is not None:
return (idx, processed)
return None
task = asyncio.create_task(preprocess_with_index(idx, item))
preprocessing_tasks.append(task)
# Wait for and yield remaining results
if preprocessing_tasks:
done = await asyncio.gather(*preprocessing_tasks)
for result in done:
if result is not None:
yield result
def _load_metrics_history(self) -> List[Dict]:
"""Load metrics history from file."""
if self.metrics_history_file.exists():
try:
with open(self.metrics_history_file, 'r') as f:
                    return json.load(f)
except Exception as e:
logging.warning(f"Failed to load metrics history: {e}")
return []
def save_metrics_history(self):
"""Save metrics history to file."""
try:
with open(self.metrics_history_file, 'w') as f:
json.dump(self.metrics_history, f)
except Exception as e:
logging.warning(f"Failed to save metrics history: {e}")
def update_metrics(self):
"""Update processing metrics."""
current_time = time.time()
elapsed = current_time - self.start_time
elapsed_since_update = current_time - self.last_update_time
# Only compute throughput if some time has elapsed
if elapsed_since_update > 0:
processing_rate = self.processed_count / elapsed if elapsed > 0 else 0
metrics = {
'worker_id': self.worker_id,
'session_id': self.session_id,
'processed_count': self.processed_count,
'failed_count': self.failed_count,
'elapsed_seconds': elapsed,
'items_per_second': processing_rate,
'start_idx': self.start_idx,
'current_idx': self.start_idx + self.processed_count,
'end_idx': self.end_idx,
'timestamp': current_time,
'status': 'running',
'partition_counter': self.partition_counter
}
# Add to history
self.metrics_history.append(metrics)
# Save to files
try:
with open(self.metrics_file, 'w') as f:
                    json.dump(metrics, f)
self.save_metrics_history()
except Exception as e:
logging.warning(f"Failed to save metrics: {e}")
self.last_update_time = current_time
async def find_existing_progress(self) -> Tuple[int, int]:
"""Find the latest progress from previous runs."""
max_idx = self.start_idx - 1
partition_counter = 0
# Check metrics file first
if self.metrics_file.exists():
try:
with open(self.metrics_file, 'r') as f:
                    metrics = json.load(f)
if 'current_idx' in metrics:
max_idx = max(max_idx, metrics['current_idx'])
if 'partition_counter' in metrics:
partition_counter = metrics['partition_counter']
logging.info(
f"Recovered progress from metrics: idx={max_idx}, partition={partition_counter}"
)
except Exception as e:
logging.warning(f"Failed to recover from metrics file: {e}")
# Check for existing parquet files as backup
try:
prefix = self.output_path.stem
parent = self.output_path.parent
existing_files = list(parent.glob(f"{prefix}*.parquet"))
for file_path in existing_files:
try:
df = pd.read_parquet(file_path)
if 'idx' in df.columns:
max_idx = max(max_idx, df['idx'].max())
logging.info(
f"Found existing progress in {file_path.name}: max_idx={max_idx}"
)
# Extract partition number from filename
import re
match = re.search(r'_part(\d+)\.parquet$', file_path.name)
if match:
part_num = int(match.group(1))
partition_counter = max(partition_counter, part_num + 1)
except Exception as e:
logging.warning(
f"Failed to read parquet file {file_path}: {e}")
except Exception as e:
logging.warning(f"Failed to check for existing parquet files: {e}")
return max_idx, partition_counter
async def run(self):
"""Run the batch processing pipeline with recovery support."""
try:
# Initialize the model
if self.model is None:
await self.setup_model()
# Find existing progress and recover state
resume_idx, self.partition_counter = await self.find_existing_progress(
)
self.start_idx = max(self.start_idx, resume_idx + 1)
logging.info(
f'Starting processing from index {self.start_idx} (partition {self.partition_counter})'
)
# Record start event in history
start_metrics = {
'worker_id': self.worker_id,
'session_id': self.session_id,
'event': 'start',
'start_idx': self.start_idx,
'end_idx': self.end_idx,
'timestamp': time.time(),
'status': 'starting'
}
self.metrics_history.append(start_metrics)
self.save_metrics_history() # Explicitly save history for recovery
self.update_metrics() # Also save current state
results = []
async for idx, input_data in self.do_data_loading():
self._current_batch.append((idx, input_data))
if len(self._current_batch) >= self.batch_size:
batch_results = await self.do_batch_processing(
self._current_batch)
results.extend(batch_results)
self._current_batch = []
if len(results) >= self.checkpoint_size:
self.save_results_to_parquet(results)
results.clear()
# Save metrics history at each checkpoint for recovery
self.update_metrics()
# Process any remaining items in the batch
if self._current_batch:
batch_results = await self.do_batch_processing(
self._current_batch)
results.extend(batch_results)
# Write the final partition if there are any leftover results
if results:
self.save_results_to_parquet(results)
# Write final metrics
end_metrics = {
'worker_id': self.worker_id,
'session_id': self.session_id,
'event': 'end',
'processed_count': self.processed_count,
'failed_count': self.failed_count,
'start_idx': self.start_idx,
'end_idx': self.end_idx,
'timestamp': time.time(),
'status': 'completed'
}
self.metrics_history.append(end_metrics)
self.update_metrics()
self.save_metrics_history()
logging.info(
f'Completed processing {self.processed_count} items ({self.failed_count} failed)'
)
except Exception as e:
logging.error(f"Error during batch processing: {e}", exc_info=True)
# Record error in metrics
error_metrics = {
'worker_id': self.worker_id,
'session_id': self.session_id,
'event': 'error',
'error': str(e),
'timestamp': time.time(),
'status': 'error'
}
self.metrics_history.append(error_metrics)
self.save_metrics_history()
raise
scripts/monitor_progress.py
"""
Monitoring service for vector computation workers.
Aggregates metrics from all workers and serves a dashboard.
"""
import asyncio
from collections import defaultdict
from datetime import datetime
import json
from pathlib import Path
import time
from typing import DefaultDict, Dict, List
import aiofiles
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
import uvicorn
app = FastAPI()
class MonitoringService:
def __init__(self, metrics_dir: str, history_window: int = 3600):
self.metrics_dir = Path(metrics_dir)
self.last_update = {}
self.worker_metrics = {}
self.history_window = history_window # Keep 1 hour of history by default
# Track historical throughput data
self.throughput_history: DefaultDict[str,
List[Dict]] = defaultdict(list)
self.last_processed_count: Dict[str, int] = {}
# Track token throughput
self.token_throughput_history: DefaultDict[
str, List[Dict]] = defaultdict(list)
self.last_token_count: Dict[str, int] = {}
# Worker history tracking
self.worker_history: DefaultDict[str, List[Dict]] = defaultdict(list)
self.worker_sessions: DefaultDict[str, List[str]] = defaultdict(list)
self.aggregate_metrics = {
'total_processed': 0,
'total_failed': 0,
'total_workers': 0,
'active_workers': 0,
'completed_workers': 0,
'failed_workers': 0,
'overall_progress': 0.0,
'overall_throughput': 0.0,
'recent_throughput': 0.0,
'overall_token_throughput': 0.0,
'recent_token_throughput': 0.0,
'total_tokens': 0,
'estimated_time_remaining': None,
'total_restarts': 0
}
def update_throughput_history(self, worker_id: str, metrics: Dict):
"""Update historical throughput data for a worker."""
current_time = time.time()
current_count = metrics['processed_count']
# Get the session ID from metrics, defaulting to a generated one if not present
session_id = metrics.get('session_id',
f"unknown_{worker_id}_{int(current_time)}")
# Store both the cumulative progress and throughput data
if worker_id in self.last_processed_count:
time_diff = current_time - self.throughput_history[worker_id][-1][
'timestamp']
count_diff = current_count - self.last_processed_count[worker_id]
# Only update if there's a meaningful time difference and count change
if time_diff > 0 and count_diff >= 0: # Avoid division by zero and negative counts
# Calculate throughput from raw metrics data
throughput = count_diff / time_diff
# Calculate overall throughput since the start of this session
session_start_time = None
session_start_count = 0
# Find the start of this session
for point in self.throughput_history[worker_id]:
if point.get('session_id', 'unknown') == session_id:
if session_start_time is None or point[
'timestamp'] < session_start_time:
session_start_time = point['timestamp']
session_start_count = point.get(
'cumulative_count', 0)
overall_throughput = 0
if session_start_time is not None and current_time > session_start_time:
overall_time = current_time - session_start_time
overall_count = current_count - session_start_count
overall_throughput = overall_count / overall_time if overall_time > 0 else 0
# Add new data point with cumulative count and calculated throughputs
self.throughput_history[worker_id].append({
'timestamp': current_time,
'recent_throughput': throughput,
'overall_throughput': overall_throughput,
'cumulative_count': current_count,
'session_id': session_id,
'interval': time_diff,
'count_change': count_diff
})
# Remove old data points outside the history window
cutoff_time = current_time - self.history_window
self.throughput_history[worker_id] = [
point for point in self.throughput_history[worker_id]
if point['timestamp'] > cutoff_time
]
else:
# First data point for this worker - can't calculate throughput yet
self.throughput_history[worker_id].append({
'timestamp': current_time,
'recent_throughput': 0,
'overall_throughput': 0,
'cumulative_count': current_count,
'session_id': session_id,
'interval': 0,
'count_change': 0
})
self.last_processed_count[worker_id] = current_count
# Track token throughput if available in metrics
if 'total_tokens' in metrics:
current_tokens = metrics['total_tokens']
if worker_id in self.last_token_count:
token_diff = current_tokens - self.last_token_count[worker_id]
# Only update if there's a meaningful time difference and token change
if time_diff > 0 and token_diff >= 0:
# Calculate token throughput
token_throughput = token_diff / time_diff
# Calculate overall token throughput since the start of this session
session_token_start = 0
session_start_time = None
# Find the start of this session for tokens
if worker_id in self.token_throughput_history and self.token_throughput_history[
worker_id]:
for point in self.token_throughput_history[worker_id]:
if point.get('session_id', 'unknown') == session_id:
if session_start_time is None or point[
'timestamp'] < session_start_time:
session_start_time = point['timestamp']
session_token_start = point.get(
'cumulative_tokens', 0)
overall_token_throughput = 0
if session_start_time is not None and current_time > session_start_time:
overall_time = current_time - session_start_time
overall_tokens = current_tokens - session_token_start
overall_token_throughput = overall_tokens / overall_time if overall_time > 0 else 0
# Add new data point for token throughput
self.token_throughput_history[worker_id].append({
'timestamp': current_time,
'recent_token_throughput': token_throughput,
'overall_token_throughput': overall_token_throughput,
'cumulative_tokens': current_tokens,
'session_id': session_id,
'interval': time_diff,
'token_change': token_diff
})
# Remove old data points outside history window
self.token_throughput_history[worker_id] = [
point
for point in self.token_throughput_history[worker_id]
if point['timestamp'] > cutoff_time
]
else:
# First token data point
self.token_throughput_history[worker_id].append({
'timestamp': current_time,
'recent_token_throughput': 0,
'overall_token_throughput': 0,
'cumulative_tokens': current_tokens,
'session_id': session_id,
'interval': 0,
'token_change': 0
})
self.last_token_count[worker_id] = current_tokens
async def read_worker_history(self, worker_id: str):
"""Read the complete history for a worker."""
history_file = self.metrics_dir / f'worker_{worker_id}_history.json'
try:
if history_file.exists():
async with aiofiles.open(history_file, 'r') as f:
content = await f.read()
history = json.loads(content)
# Update worker history
self.worker_history[worker_id] = history
# Extract unique session IDs
sessions = set()
for entry in history:
# Skip entries without a session_id
if 'session_id' in entry:
sessions.add(entry['session_id'])
elif 'timestamp' in entry:
# If session_id is missing but there's a timestamp, create a synthetic session ID
synthetic_session_id = f"unknown_{worker_id}_{int(entry['timestamp'])//3600}"
entry[
'session_id'] = synthetic_session_id # Add it to the entry
sessions.add(synthetic_session_id)
else:
# If both are missing, use a generic unknown ID
entry['session_id'] = f"unknown_{worker_id}"
sessions.add(f"unknown_{worker_id}")
self.worker_sessions[worker_id] = sorted(list(sessions))
return history
return []
except Exception as e:
print(f"Error reading history for worker {worker_id}: {e}")
return []
async def update_metrics(self):
"""Read and aggregate metrics from all worker files."""
try:
# Read all worker metric files
worker_files = list(self.metrics_dir.glob('worker_*.json'))
# Filter out history files
worker_files = [f for f in worker_files if '_history' not in f.name]
new_metrics = {}
total_restarts = 0
for file in worker_files:
try:
async with aiofiles.open(file, 'r') as f:
content = await f.read()
metrics = json.loads(content)
worker_id = metrics['worker_id']
# Ensure session_id exists, generate one if missing
if 'session_id' not in metrics and 'timestamp' in metrics:
metrics[
'session_id'] = f"unknown_{worker_id}_{int(metrics['timestamp'])//3600}"
elif 'session_id' not in metrics:
metrics['session_id'] = f"unknown_{worker_id}"
new_metrics[worker_id] = metrics
# Read worker history
await self.read_worker_history(worker_id)
# Count number of sessions as restarts
total_restarts += max(
0,
len(self.worker_sessions[worker_id]) - 1)
self.update_throughput_history(worker_id, metrics)
except Exception as e:
print(f"Error reading metrics from {file}: {e}")
continue
# Update worker metrics
self.worker_metrics = new_metrics
# Calculate aggregate metrics
total_processed = 0
total_failed = 0
active_workers = 0
completed_workers = 0
failed_workers = 0
total_progress = 0
total_items = 0
total_tokens = 0
# Calculate throughput metrics by aggregating from throughput history
total_recent_throughput = 0
total_overall_throughput = 0
total_recent_token_throughput = 0
total_overall_token_throughput = 0
active_worker_count = 0
for worker_id, metrics in self.worker_metrics.items():
total_processed += metrics['processed_count']
total_failed += metrics.get('failed_count', 0)
total_tokens += metrics.get('total_tokens', 0)
if metrics.get('status') == 'running':
active_workers += 1
active_worker_count += 1
elif metrics.get('status') == 'completed':
completed_workers += 1
elif metrics.get('status') == 'failed':
failed_workers += 1
if metrics.get('end_idx'):
total_items += metrics['end_idx'] - metrics['start_idx']
total_progress += metrics['processed_count']
# Get the most recent throughput data for this worker
if worker_id in self.throughput_history and self.throughput_history[
worker_id]:
latest = sorted(self.throughput_history[worker_id],
key=lambda x: x['timestamp'])[-1]
total_recent_throughput += latest.get(
'recent_throughput', 0)
total_overall_throughput += latest.get(
'overall_throughput', 0)
# Get the most recent token throughput data for this worker
if worker_id in self.token_throughput_history and self.token_throughput_history[
worker_id]:
latest_token = sorted(
self.token_throughput_history[worker_id],
key=lambda x: x['timestamp'])[-1]
total_recent_token_throughput += latest_token.get(
'recent_token_throughput', 0)
total_overall_token_throughput += latest_token.get(
'overall_token_throughput', 0)
# Update aggregate metrics
self.aggregate_metrics.update({
'total_processed': total_processed,
'total_failed': total_failed,
'total_workers': len(self.worker_metrics),
'active_workers': active_workers,
'completed_workers': completed_workers,
'failed_workers': failed_workers,
'overall_progress': (total_progress / total_items *
100) if total_items > 0 else 0,
'overall_throughput': total_overall_throughput,
'recent_throughput': total_recent_throughput,
'overall_token_throughput': total_overall_token_throughput,
'recent_token_throughput': total_recent_token_throughput,
'total_tokens': total_tokens,
'estimated_time_remaining':
(total_items - total_progress) / total_overall_throughput
if total_overall_throughput > 0 else None,
'total_restarts': total_restarts
})
except Exception as e:
print(f"Error updating metrics: {e}")
def get_throughput_chart_data(self) -> Dict:
"""Prepare cumulative progress data for Chart.js from the worker history files."""
# We're going to use the worker history data, which already contains the full history
# of each worker's progress over time, including restarts.
# First, make sure all worker histories are loaded
all_histories = []
for worker_id in self.worker_metrics.keys():
if worker_id not in self.worker_history or not self.worker_history[
worker_id]:
continue # Skip workers with no history
all_histories.extend([
(worker_id, entry) for entry in self.worker_history[worker_id]
])
# If we have no history data, return an empty dataset
if not all_histories:
now = int(time.time())
empty_dataset = {
'label': 'No Progress Data',
'data': [{
'x': (now - 3600) * 1000,
'y': 0
}, {
'x': now * 1000,
'y': 0
}],
'borderColor': '#cccccc',
'backgroundColor': '#cccccc20',
'borderWidth': 2,
'fill': 'false'
}
return {'datasets': [empty_dataset]}
# Sort all history entries by timestamp
all_histories.sort(key=lambda x: x[1].get('timestamp', 0))
# The master dataset shows the overall progress across all workers
master_dataset = {
'label': 'Total Progress',
'data': [],
'borderColor': '#000000',
'backgroundColor': '#00000020',
'borderWidth': 3,
'fill': 'false',
'tension': 0.1,
'pointRadius': 0
}
# Start with 0 at the earliest timestamp
first_timestamp = all_histories[0][1].get('timestamp', 0)
master_dataset['data'].append({
'x': first_timestamp * 1000, # Convert to milliseconds for Chart.js
'y': 0
})
# Track the last known processed count for each worker/session
latest_processed = {}
# Process all history events in chronological order
for worker_id, entry in all_histories:
timestamp = entry.get('timestamp', 0)
session_id = entry.get('session_id', 'unknown')
key = f"{worker_id}_{session_id}"
# Update the processed count for this worker session
if 'processed_count' in entry:
latest_processed[key] = entry['processed_count']
# Calculate the total processed count across all worker sessions
total_processed = sum(latest_processed.values())
# Add a data point to the master dataset
master_dataset['data'].append({
'x': timestamp * 1000, # Convert to milliseconds for Chart.js
'y': total_processed
})
# Create individual datasets for each worker to show their contribution
worker_datasets = []
colors = [
'#FF6384', '#36A2EB', '#FFCE56', '#4BC0C0', '#9966FF', '#FF9F40'
]
# Group history by worker and session
worker_sessions = {}
for worker_id, entries in self.worker_history.items():
# Group entries by session
sessions = {}
for entry in entries:
session_id = entry.get('session_id', 'unknown')
if session_id not in sessions:
sessions[session_id] = []
sessions[session_id].append(entry)
# Process each session separately
for session_id, session_entries in sessions.items():
key = f"{worker_id}_{session_id}"
worker_sessions[key] = {
'worker_id': worker_id,
'session_id': session_id,
'entries': sorted(session_entries,
key=lambda e: e.get('timestamp', 0))
}
# Create a dataset for each worker session
for i, (key, session) in enumerate(worker_sessions.items()):
worker_id = session['worker_id']
session_id = session['session_id']
entries = session['entries']
if not entries:
continue # Skip empty sessions
# Assign a color
color_idx = sum(ord(c) for c in worker_id) % len(
colors) # Deterministic color based on worker ID
color = colors[color_idx]
# Adjust brightness for different sessions of the same worker
session_num = sum(1 for k in worker_sessions.keys()
if k.startswith(f"{worker_id}_"))
session_idx = sum(1 for k in worker_sessions.keys()
if k.startswith(f"{worker_id}_") and k <= key)
brightness_adjust = (session_idx - 1) * (100 // max(1, session_num))
session_color = self._adjust_color_brightness(
color, brightness_adjust)
# Create the dataset for this worker session
worker_dataset = {
'label': f"{worker_id} (session {session_id[:8]})",
'data': [],
'borderColor': session_color,
'backgroundColor': session_color + '20',
'borderWidth': 1,
'fill': 'false',
'tension': 0.1,
'pointRadius': 1
}
# Start with 0 at the session's first timestamp
first_entry_time = entries[0].get('timestamp', 0)
worker_dataset['data'].append({
'x': first_entry_time * 1000,
'y': 0
})
# Add each progress data point
for entry in entries:
if 'processed_count' in entry:
worker_dataset['data'].append({
'x': entry.get('timestamp', 0) * 1000,
'y': entry['processed_count']
})
worker_datasets.append(worker_dataset)
# Return all datasets, with the master dataset first
return {'datasets': [master_dataset] + worker_datasets}
def _adjust_color_brightness(self, hex_color, percent):
"""Adjust the brightness of a hex color."""
# Convert hex to RGB
r = int(hex_color[1:3], 16)
g = int(hex_color[3:5], 16)
b = int(hex_color[5:7], 16)
# Increase brightness
r = min(255, r + int(percent))
g = min(255, g + int(percent))
b = min(255, b + int(percent))
# Convert back to hex
return f'#{r:02x}{g:02x}{b:02x}'
def get_session_history_data(self) -> Dict:
"""Prepare session history data for visualization."""
session_data = []
for worker_id, history in self.worker_history.items():
# Group events by session
sessions = {}
for event in history:
session_id = event.get('session_id', 'unknown')
if session_id not in sessions:
sessions[session_id] = {
'worker_id': worker_id,
'session_id': session_id,
'start_time': None,
'end_time': None,
'duration': 0,
'processed': 0,
'failed': 0,
'status': 'unknown',
'termination_reason': None
}
# Update session data
timestamp = event.get('timestamp', 0)
if event.get('event') == 'start' or (
not sessions[session_id]['start_time'] and
event.get('timestamp')):
sessions[session_id]['start_time'] = timestamp
if event.get('status') in ['completed', 'failed']:
sessions[session_id]['end_time'] = timestamp
sessions[session_id]['status'] = event.get('status')
# Track spot VM termination events
if event.get('event') == 'termination' or event.get(
'status') == 'terminated':
sessions[session_id]['end_time'] = timestamp
sessions[session_id]['status'] = 'terminated'
sessions[session_id][
'termination_reason'] = 'Spot VM interruption'
if 'processed_count' in event:
sessions[session_id]['processed'] = max(
sessions[session_id]['processed'],
event['processed_count'])
if 'failed_count' in event:
sessions[session_id]['failed'] = max(
sessions[session_id]['failed'], event['failed_count'])
# Calculate duration and add to data
for session in sessions.values():
if session['start_time']:
if session['end_time']:
session['duration'] = session['end_time'] - session[
'start_time']
else:
# Session might still be running
session['duration'] = time.time(
) - session['start_time']
session['status'] = 'running'
session_data.append(session)
return sorted(session_data, key=lambda x: x.get('start_time', 0))
def get_token_throughput_chart_data(self) -> Dict:
"""Prepare token throughput data for Chart.js."""
# Get the earliest and latest timestamps across all workers
all_timestamps = []
for history in self.token_throughput_history.values():
all_timestamps.extend(point['timestamp'] for point in history)
if not all_timestamps:
# Create empty datasets with placeholder data
now = int(time.time())
empty_bar_dataset = {
'label': 'No Token Throughput Data',
'type': 'bar',
'data': [{
'x': (now - 3600) * 1000,
'y': 0
}, {
'x': now * 1000,
'y': 0
}],
'backgroundColor': '#cccccc80',
'borderColor': '#cccccc',
'borderWidth': 1
}
return {'datasets': [empty_bar_dataset]}
min_time = min(all_timestamps)
max_time = max(all_timestamps)
# Generate datasets for token throughput chart
token_datasets = []
# Add a total token throughput dataset that aggregates all workers
total_token_dataset = {
'label': 'Total Token Throughput',
'type': 'bar', # Use bar chart for histogram-like display
'data': [],
'backgroundColor': '#36A2EB80',
'borderColor': '#36A2EB',
'borderWidth': 1,
'barPercentage': 0.8,
'categoryPercentage': 0.9,
'order': 1 # Lower order means it's drawn first (behind other datasets)
}
# Add a line dataset to show the trend
total_token_trend_dataset = {
'label': 'Throughput Trend',
'type': 'line',
'data': [],
'borderColor': '#FF6384',
'backgroundColor': '#FF638420',
'borderWidth': 2,
'pointRadius': 0,
'fill': 'false',
'tension': 0.4,
'order': 0 # Higher priority, drawn on top
}
colors = [
'#FF6384', '#36A2EB', '#FFCE56', '#4BC0C0', '#9966FF', '#FF9F40'
]
# First collect all data points from all workers - gather both recent and overall throughput
all_token_points = []
for worker_id, history in self.token_throughput_history.items():
for point in history:
# Use recent_token_throughput if available, otherwise try overall_token_throughput
throughput = point.get('recent_token_throughput', 0)
if throughput == 0:
throughput = point.get('overall_token_throughput', 0)
all_token_points.append({
'timestamp': point['timestamp'],
'throughput': throughput,
'worker_id': worker_id,
'session_id': point.get('session_id', 'unknown')
})
# Sort all points by timestamp
all_token_points.sort(key=lambda x: x['timestamp'])
# If we have no data points at all, create a placeholder dataset
if not all_token_points:
# Create empty datasets with placeholder data
now = int(time.time())
empty_bar_dataset = {
'label': 'No Token Throughput Data',
'type': 'bar',
'data': [{
'x': (now - 3600) * 1000,
'y': 0
}, {
'x': now * 1000,
'y': 0
}],
'backgroundColor': '#cccccc80',
'borderColor': '#cccccc',
'borderWidth': 1
}
return {'datasets': [empty_bar_dataset]}
# Group data into time bins (e.g., 1-minute bins) for the histogram
bin_size = 60 # 1 minute bins
time_bins = {}
for point in all_token_points:
# Create a bin key by rounding the timestamp to the nearest bin
bin_key = int(point['timestamp'] // bin_size * bin_size)
if bin_key not in time_bins:
time_bins[bin_key] = []
# Only include non-zero values
if point['throughput'] > 0:
time_bins[bin_key].append(point['throughput'])
# Calculate the average throughput for each time bin
for bin_timestamp, throughputs in sorted(time_bins.items()):
if throughputs: # Only process if there are non-zero values
avg_throughput = sum(throughputs) / len(throughputs)
# Add to the histogram dataset
total_token_dataset['data'].append({
'x': bin_timestamp *
1000, # Convert to milliseconds for Chart.js
'y': avg_throughput
})
# Also add to the trend line
total_token_trend_dataset['data'].append({
'x': bin_timestamp * 1000,
'y': avg_throughput
})
# Ensure there's at least one data point for visualization
if not total_token_dataset['data']:
# Create a single data point at the current time with a value of 0
current_time = int(time.time()) * 1000
total_token_dataset['data'].append({'x': current_time, 'y': 0})
total_token_trend_dataset['data'].append({
'x': current_time,
'y': 0
})
# Add individual worker datasets as lines if there aren't too many workers
if len(self.token_throughput_history
) <= 5: # Only show individual workers if there aren't too many
for i, (worker_id, history) in enumerate(
self.token_throughput_history.items()):
# Skip empty history
if not history:
continue
color = colors[i % len(colors)]
# Group by session ID
sessions = {}
for point in history:
session_id = point.get('session_id', 'unknown')
if session_id not in sessions:
sessions[session_id] = []
sessions[session_id].append(point)
# Create separate datasets for each session
for j, (session_id,
session_points) in enumerate(sessions.items()):
# Skip sessions with no data
if not session_points:
continue
# Sort by timestamp
session_points.sort(key=lambda x: x['timestamp'])
# Create dataset for this session
worker_dataset = {
'label': f"{worker_id} (session {session_id[:8]})",
'type': 'line',
'data': [],
'borderColor': self._adjust_color_brightness(
color, j * 20),
'backgroundColor': 'transparent',
'borderWidth': 1,
'pointRadius': 1,
'fill': 'false',
'order': 2 # Draw on top of the histogram
}
# Add data points - try both recent and overall token throughput
for point in session_points:
timestamp = point['timestamp']
# Convert to chart.js time format (milliseconds since epoch)
timestamp_ms = timestamp * 1000
# Use recent_token_throughput if available, otherwise try overall_token_throughput
throughput = point.get('recent_token_throughput', 0)
if throughput == 0:
throughput = point.get('overall_token_throughput',
0)
worker_dataset['data'].append({
'x': timestamp_ms,
'y': throughput
})
# Only add datasets with actual data
if worker_dataset['data']:
token_datasets.append(worker_dataset)
# Combine all datasets, with the total histogram first
all_datasets = [total_token_dataset, total_token_trend_dataset
] + token_datasets
return {'datasets': all_datasets}
def get_dashboard_html(self) -> str:
"""Generate HTML dashboard."""
refresh_rate = 5 # seconds
# Convert metrics to human-readable format
metrics = self.aggregate_metrics.copy()
metrics['overall_progress'] = f"{metrics['overall_progress']:.2f}%"
metrics[
'overall_throughput'] = f"{metrics['overall_throughput']:.2f} requests/s"
metrics[
'recent_throughput'] = f"{metrics['recent_throughput']:.2f} requests/s"
metrics[
'overall_token_throughput'] = f"{metrics['overall_token_throughput']:.2f} tokens/s"
metrics[
'recent_token_throughput'] = f"{metrics['recent_token_throughput']:.2f} tokens/s"
metrics['total_tokens_formatted'] = f"{metrics['total_tokens']:,}"
if metrics['estimated_time_remaining']:
hours = metrics['estimated_time_remaining'] / 3600
metrics['estimated_time_remaining'] = f"{hours:.1f} hours"
else:
metrics['estimated_time_remaining'] = "N/A"
# Generate worker status table
worker_rows = []
for worker_id, worker in self.worker_metrics.items():
progress = (
(worker.get('current_idx', 0) - worker.get('start_idx', 0)) /
(worker.get('end_idx', 1) - worker.get('start_idx', 0)) *
100 if worker.get('end_idx') else 0)
# Count sessions/restarts for this worker
session_count = len(self.worker_sessions.get(worker_id, []))
restart_count = max(0, session_count - 1)
# Determine status class for styling
status_class = ''
if worker.get('status') == 'running':
status_class = 'status-running'
elif worker.get('status') == 'completed':
status_class = 'status-completed'
elif worker.get('status') == 'failed' or worker.get(
'status') == 'terminated':
status_class = 'status-failed'
# Get token throughput if available
token_throughput = "N/A"
if worker_id in self.token_throughput_history and self.token_throughput_history[
worker_id]:
latest_token = sorted(self.token_throughput_history[worker_id],
key=lambda x: x['timestamp'])[-1]
token_throughput = f"{latest_token.get('recent_token_throughput', 0):.2f} tokens/s"
row = f"""
<tr>
<td>{worker_id}</td>
<td class="{status_class}">{worker.get('status', 'unknown')}</td>
<td>{progress:.2f}%</td>
<td>{worker.get('processed_count', 0)}</td>
<td>{worker.get('failed_count', 0)}</td>
<td>{worker.get('items_per_second', 0):.2f}</td>
<td>{token_throughput}</td>
<td>{restart_count}</td>
<td>{datetime.fromtimestamp(worker.get('timestamp', worker.get('last_update', 0))).strftime('%Y-%m-%d %H:%M:%S')}</td>
</tr>
"""
worker_rows.append(row)
# Generate session history table
session_data = self.get_session_history_data()
session_rows = []
for session in session_data:
start_time = datetime.fromtimestamp(session.get(
'start_time', 0)).strftime('%Y-%m-%d %H:%M:%S') if session.get(
'start_time') else 'N/A'
end_time = datetime.fromtimestamp(session.get(
'end_time', 0)).strftime('%Y-%m-%d %H:%M:%S') if session.get(
'end_time') else 'N/A'
duration_mins = session.get('duration', 0) / 60
# Determine status class for styling
status_class = ''
if session.get('status') == 'running':
status_class = 'status-running'
elif session.get('status') == 'completed':
status_class = 'status-completed'
elif session.get('status') in ['failed', 'terminated']:
status_class = 'status-failed'
# Show termination reason for spot VM interruptions
status_text = session.get('status', 'unknown')
if session.get('termination_reason'):
status_text += f" ({session.get('termination_reason')})"
row = f"""
<tr>
<td>{session.get('worker_id', 'unknown')}</td>
<td>{session.get('session_id', 'unknown')[-8:]}</td>
<td class="{status_class}">{status_text}</td>
<td>{start_time}</td>
<td>{end_time if session.get('end_time') else 'Running'}</td>
<td>{duration_mins:.1f} min</td>
<td>{session.get('processed', 0)}</td>
<td>{session.get('failed', 0)}</td>
</tr>
"""
session_rows.append(row)
# Get chart data
chart_data = json.dumps(self.get_throughput_chart_data())
# Add token throughput chart initialization
token_throughput_data = self.get_token_throughput_chart_data()
token_throughput_chart_json = json.dumps(token_throughput_data)
# Initialize charts
charts_js = f"""
// Throughput Chart (Cumulative Progress)
var throughputCtx = document.getElementById('throughputChart').getContext('2d');
var throughputChart = new Chart(throughputCtx, {{
type: 'line',
data: {chart_data},
options: {{
responsive: true,
            maintainAspectRatio: false,
scales: {{
x: {{
type: 'time',
time: {{
unit: 'minute',
tooltipFormat: 'MMM dd, HH:mm:ss',
displayFormats: {{
minute: 'HH:mm'
}}
}},
title: {{
display: true,
text: 'Time'
}}
}},
y: {{
title: {{
display: true,
text: 'Cumulative Items Processed'
}},
beginAtZero: true,
min: 0 // Force the minimum to be exactly 0
}}
}},
plugins: {{
title: {{
display: true,
text: 'Cumulative Progress Over Time'
}},
tooltip: {{
callbacks: {{
label: function(context) {{
var label = context.dataset.label || '';
if (label) {{
label += ': ';
}}
if (context.parsed.y !== null) {{
label += context.parsed.y.toFixed(0) + ' items';
}}
return label;
}}
}}
}}
}}
}}
}});
// Token Throughput Chart (Histogram)
var tokenThroughputCtx = document.getElementById('tokenThroughputChart').getContext('2d');
var tokenThroughputData = {token_throughput_chart_json};
// Debug output to check if data is empty
console.log("Token Throughput Data:", tokenThroughputData);
var tokenThroughputChart = new Chart(tokenThroughputCtx, {{
type: 'bar', // Default type is bar (histogram)
data: tokenThroughputData,
options: {{
responsive: true,
            maintainAspectRatio: false,
scales: {{
x: {{
type: 'time',
time: {{
unit: 'minute',
tooltipFormat: 'MMM dd, HH:mm:ss',
displayFormats: {{
minute: 'HH:mm'
}}
}},
title: {{
display: true,
text: 'Time'
}}
}},
y: {{
title: {{
display: true,
text: 'Tokens/second'
}},
beginAtZero: true,
min: 0 // Force the minimum to be exactly 0
}}
}},
plugins: {{
title: {{
display: true,
text: 'Token Throughput Over Time'
}},
tooltip: {{
callbacks: {{
label: function(context) {{
var label = context.dataset.label || '';
if (label) {{
label += ': ';
}}
if (context.parsed.y !== null) {{
label += context.parsed.y.toFixed(2) + ' tokens/s';
}}
return label;
}}
}}
}}
}}
}}
}});
"""
return f"""
<!DOCTYPE html>
<html>
<head>
    <title>Vector Computation Progress</title>
<meta id="refresh-meta" http-equiv="refresh" content="{refresh_rate}">
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<script src="https://cdn.jsdelivr.net/npm/[email protected]"></script>
<script src="https://cdn.jsdelivr.net/npm/[email protected]"></script>
<script>
// Enable mixed chart types (bar + line)
Chart.defaults.set('plugins.legend', {{
position: 'top',
labels: {{
usePointStyle: true,
padding: 15
}}
}});
</script>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.metrics-grid {{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 20px;
margin-bottom: 30px;
}}
.metric-card {{
background: #f5f5f5;
padding: 15px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.metric-value {{ font-size: 24px; font-weight: bold; margin: 10px 0; }}
.metric-label {{ color: #666; }}
table {{
width: 100%;
border-collapse: collapse;
margin-top: 20px;
}}
th, td {{
padding: 12px;
text-align: left;
border-bottom: 1px solid #ddd;
}}
th {{ background-color: #f5f5f5; }}
.chart-container {{
width: 100%;
height: 400px;
margin: 30px 0;
}}
.status-running {{ color: green; }}
.status-completed {{ color: blue; }}
.status-failed {{ color: red; }}
.tabs {{
margin-top: 30px;
border-bottom: 1px solid #ccc;
display: flex;
}}
.tab {{
padding: 10px 15px;
cursor: pointer;
background: #f5f5f5;
margin-right: 5px;
border-radius: 5px 5px 0 0;
}}
.tab.active {{
background: #ddd;
}}
.tab-content {{
display: none;
padding: 20px 0;
}}
.tab-content.active {{
display: block;
}}
.toggle-refresh {{
margin-top: 10px;
padding: 8px 15px;
background-color: #4CAF50;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
}}
.toggle-refresh.paused {{
background-color: #f44336;
}}
.header-controls {{
display: flex;
justify-content: space-between;
align-items: center;
}}
.refresh-status {{
margin-left: 10px;
font-size: 14px;
color: #666;
}}
</style>
</head>
<body>
<div class="header-controls">
    <h1>Vector Computation Progress</h1>
<div>
<button id="toggle-refresh" class="toggle-refresh">Pause Auto-Refresh</button>
<span id="refresh-status" class="refresh-status">Auto-refreshing every {refresh_rate}s</span>
</div>
</div>
<div class="metrics-grid">
<div class="metric-card">
<div class="metric-label">Overall Progress</div>
<div class="metric-value">{metrics['overall_progress']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Overall Speed</div>
<div class="metric-value">{metrics['overall_throughput']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Recent Speed</div>
<div class="metric-value">{metrics['recent_throughput']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Estimated Time Remaining</div>
<div class="metric-value">{metrics['estimated_time_remaining']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Total Processed</div>
<div class="metric-value">{metrics['total_processed']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Failed Requests</div>
<div class="metric-value">{metrics['total_failed']}</div>
</div>
<div class="metric-card">
<div class="metric-label">VM Restarts</div>
<div class="metric-value">{metrics['total_restarts']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Total Tokens</div>
<div class="metric-value">{metrics['total_tokens_formatted']}</div>
</div>
<div class="metric-card">
<div class="metric-label">Token Throughput</div>
<div class="metric-value">{metrics['overall_token_throughput']}</div>
</div>
</div>
<h2>Cumulative Progress</h2>
<div class="chart-container">
<canvas id="throughputChart"></canvas>
</div>
<h2>Token Throughput</h2>
<div class="chart-container">
<canvas id="tokenThroughputChart"></canvas>
</div>
<div class="tabs">
<div class="tab active" onclick="showTab('currentStatus')">Current Status</div>
<div class="tab" onclick="showTab('sessionHistory')">Session History</div>
</div>
<div id="currentStatus" class="tab-content active">
<h2>Worker Status</h2>
<table>
<thead>
<tr>
<th>Worker ID</th>
<th>Status</th>
<th>Progress</th>
<th>Processed</th>
<th>Failed</th>
<th>Speed (requests/s)</th>
<th>Token Speed</th>
<th>Restarts</th>
<th>Last Update</th>
</tr>
</thead>
<tbody>
{''.join(worker_rows)}
</tbody>
</table>
</div>
<div id="sessionHistory" class="tab-content">
<h2>Session History</h2>
<table>
<thead>
<tr>
<th>Worker ID</th>
<th>Session ID</th>
<th>Status</th>
<th>Start Time</th>
<th>End Time</th>
<th>Duration</th>
<th>Processed</th>
<th>Failed</th>
</tr>
</thead>
<tbody>
{''.join(session_rows)}
</tbody>
</table>
</div>
<script>
// Chart setup
const ctx = document.getElementById('throughputChart');
const chartData = {chart_data};
const chart = new Chart(ctx, {{
type: 'line',
data: chartData,
options: {{
responsive: true,
            maintainAspectRatio: false,
scales: {{
x: {{
type: 'time',
time: {{
unit: 'minute',
tooltipFormat: 'MMM dd, HH:mm:ss'
}},
title: {{
display: true,
text: 'Time'
}}
}},
y: {{
beginAtZero: true,
min: 0, // Force the minimum to be exactly 0
title: {{
display: true,
text: 'Cumulative Items Processed'
}}
}}
}},
plugins: {{
title: {{
display: true,
text: 'Cumulative Progress Over Time'
}},
tooltip: {{
callbacks: {{
label: function(context) {{
var label = context.dataset.label || '';
if (label) {{
label += ': ';
}}
if (context.parsed.y !== null) {{
label += context.parsed.y.toFixed(0) + ' items';
}}
return label;
}}
}}
}}
}}
}}
}});
// Tab switching function
function showTab(tabId) {{
// Hide all tab contents
document.querySelectorAll('.tab-content').forEach(content => {{
content.classList.remove('active');
}});
// Deactivate all tabs
document.querySelectorAll('.tab').forEach(tab => {{
tab.classList.remove('active');
}});
// Show the selected tab content
document.getElementById(tabId).classList.add('active');
// Activate the clicked tab
event.currentTarget.classList.add('active');
}}
// Auto-refresh toggle
let autoRefreshEnabled = true;
const toggleButton = document.getElementById('toggle-refresh');
const refreshStatus = document.getElementById('refresh-status');
const refreshMeta = document.getElementById('refresh-meta');
toggleButton.addEventListener('click', function() {{
autoRefreshEnabled = !autoRefreshEnabled;
if (autoRefreshEnabled) {{
// Enable auto-refresh
refreshMeta.setAttribute('content', '{refresh_rate}');
toggleButton.textContent = 'Pause Auto-Refresh';
toggleButton.classList.remove('paused');
refreshStatus.textContent = 'Auto-refreshing every {refresh_rate}s';
}} else {{
// Disable auto-refresh
refreshMeta.setAttribute('content', '');
toggleButton.textContent = 'Resume Auto-Refresh';
toggleButton.classList.add('paused');
refreshStatus.textContent = 'Auto-refresh paused';
}}
}});
// Manual refresh button
document.addEventListener('keydown', function(event) {{
// If F5 or Ctrl+R is pressed and auto-refresh is disabled, don't prevent refresh
if (!autoRefreshEnabled) {{
return;
}}
// Prevent F5 or Ctrl+R from refreshing when auto-refresh is enabled
if (event.key === 'F5' || (event.ctrlKey && event.key === 'r')) {{
event.preventDefault();
alert('Auto-refresh is enabled. To manually refresh, first click "Pause Auto-Refresh".');
}}
}});
</script>
</body>
</html>
"""
monitoring_service = None
@app.on_event("startup")
async def startup_event():
global monitoring_service
    metrics_dir = "/output/metrics"  # This should match the metrics directory created in compute_text_vectors.yaml
monitoring_service = MonitoringService(metrics_dir)
# Load initial data immediately
await monitoring_service.update_metrics()
# Start background task to update metrics
asyncio.create_task(periodic_metrics_update())
async def periodic_metrics_update():
while True:
await monitoring_service.update_metrics()
await asyncio.sleep(5) # Update every 5 seconds
@app.get("/", response_class=HTMLResponse)
async def get_dashboard():
if not monitoring_service:
raise HTTPException(status_code=503,
detail="Monitoring service not initialized")
return monitoring_service.get_dashboard_html()
@app.get("/api/metrics")
async def get_metrics():
if not monitoring_service:
raise HTTPException(status_code=503,
detail="Monitoring service not initialized")
return {
"aggregate_metrics": monitoring_service.aggregate_metrics,
"worker_metrics": monitoring_service.worker_metrics
}
def main():
uvicorn.run(app, host="0.0.0.0", port=8000)
if __name__ == "__main__":
main()
scripts/text_vector_processor.py
import asyncio
import json
import logging
import os
from pathlib import Path
import pickle
import time
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
from base_vector_processor import BaseVectorProcessor
import nltk
import numpy as np
import pandas as pd
import requests
from tqdm import tqdm
# Initialize NLTK for text chunking
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
logging.info('Downloading NLTK punkt tokenizer...')
nltk.download('punkt')
nltk.download('punkt_tab')
logging.info('Download complete')
class TextVectorProcessor(BaseVectorProcessor):
"""Process text data to compute vector embeddings."""
def __init__(self,
output_path: str,
vllm_endpoint: str,
model_name: str = "Alibaba-NLP/gte-Qwen2-7B-instruct",
dataset_name: str = 'pile-of-law/pile-of-law',
dataset_config: str = 'all',
split: str = 'full',
streaming: bool = True,
batch_size: int = 32,
checkpoint_size: int = 100,
start_idx: int = 0,
end_idx: Optional[int] = None,
chunk_size: int = 512,
chunk_overlap: int = 50,
max_preprocessing_tasks: int = 10,
partition_method: str = 'chunk',
worker_rank: int = 0,
total_workers: int = 1,
global_start_idx: int = 0,
global_end_idx: Optional[int] = None):
"""Initialize the text vector processor.
Args:
output_path: Path to save the computed vectors
vllm_endpoint: Endpoint for vLLM service
model_name: Name of the model to use for embeddings
dataset_name: Name of the dataset to process
dataset_config: Dataset configuration
split: Dataset split to use
streaming: Whether to stream the dataset
batch_size: Size of batches for processing
checkpoint_size: Number of items to process before saving
start_idx: Starting index in the dataset
end_idx: Ending index in the dataset
chunk_size: Size of document chunks
chunk_overlap: Overlap between chunks
max_preprocessing_tasks: Maximum number of concurrent preprocessing tasks
partition_method: Method of partitioning ('chunk' or 'stride')
worker_rank: Rank of this worker (0-based)
total_workers: Total number of workers
global_start_idx: Global starting index for all workers
global_end_idx: Global ending index for all workers
"""
# Store the partitioning configuration for this worker
self.partition_method = partition_method
self.worker_rank = worker_rank
self.total_workers = total_workers
self.global_start_idx = global_start_idx
self.global_end_idx = global_end_idx
# For stride method, we'll handle the actual skipping in get_dataset_iterator
# but we keep the original range for BaseVectorProcessor
super().__init__(output_path=output_path,
dataset_name=dataset_name,
split=split,
streaming=streaming,
batch_size=batch_size,
checkpoint_size=checkpoint_size,
start_idx=start_idx,
end_idx=end_idx,
max_preprocessing_tasks=max_preprocessing_tasks)
# Text-specific attributes
self.vllm_endpoint = vllm_endpoint
self.model_name = model_name
self.dataset_config = dataset_config
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.next_chunk_idx = 0 # Global chunk counter
self.chunks = [] # Store preprocessed chunks
# Token tracking attributes
self.total_tokens = 0
self.batch_count = 0
self.token_metrics = {
'total_tokens': 0,
'avg_tokens_per_batch': 0,
'avg_tokens_per_chunk': 0,
'token_count_by_batch': {}
}
# Log partitioning method
if self.partition_method == 'stride':
logging.info(
f"Using strided partitioning: worker {self.worker_rank} of {self.total_workers}, "
f"processing every {self.total_workers}th item starting from {self.global_start_idx}"
)
else:
logging.info(
f"Using chunk partitioning: processing items from {self.start_idx} to {self.end_idx}"
)
async def setup_model(self):
"""Verify vLLM endpoint is accessible."""
try:
response = requests.get(f"{self.vllm_endpoint}/health", timeout=10)
if response.status_code == 200:
logging.info(
f"Successfully connected to vLLM endpoint: {self.vllm_endpoint}"
)
self.model = True # Just a flag to indicate setup is complete
else:
raise ConnectionError(
f"vLLM endpoint returned status code: {response.status_code}"
)
except Exception as e:
logging.error(f"Failed to connect to vLLM endpoint: {str(e)}")
raise
async def get_dataset_iterator(self) -> AsyncIterator[Tuple[int, Any]]:
"""Load data from a HuggingFace dataset."""
from datasets import load_dataset
dataset = load_dataset(self.dataset_name,
self.dataset_config,
split=self.split,
streaming=self.streaming,
trust_remote_code=True)
# Handle different partitioning methods
if self.partition_method == 'stride':
# For stride method, we process every Nth item where N is total_workers
# starting from global_start_idx + worker_rank
start_point = self.global_start_idx + self.worker_rank
item_counter = 0
for idx, item in enumerate(dataset, start=0):
if idx < start_point:
continue
# Only process items that belong to this worker based on the stride
if (idx - start_point) % self.total_workers == 0:
# Check global end condition
if self.global_end_idx and idx >= self.global_end_idx:
break
# Transform item into a document format
document = {
'id': f"{idx}",
'name': item.get('url', f"document_{idx}"),
'text': item['text'],
'split': self.split,
'source': item.get('source', self.dataset_name),
'created_timestamp': item.get('created_timestamp',
None),
'downloaded_timestamp': item.get(
'downloaded_timestamp', None),
'url': item.get('url', None)
}
yield idx, document
item_counter += 1
# Provide some logging feedback
if item_counter % 100 == 0:
logging.info(
f"Worker {self.worker_rank}: Processed {item_counter} items (global idx {idx})"
)
else:
# Original chunk behavior
if self.start_idx > 0:
dataset = dataset.skip(self.start_idx)
for idx, item in enumerate(dataset, start=self.start_idx):
if self.end_idx and idx >= self.end_idx:
break
# Transform item into a document format
document = {
'id': f"{idx}",
'name': item.get('url', f"document_{idx}"),
'text': item['text'],
'split': self.split,
'source': item.get('source', self.dataset_name),
'created_timestamp': item.get('created_timestamp', None),
'downloaded_timestamp': item.get('downloaded_timestamp',
None),
'url': item.get('url', None)
}
yield idx, document
async def _preprocess_input(self, document: Dict) -> Optional[List[Dict]]:
"""Chunk a document into smaller pieces."""
try:
doc_chunks = self.chunk_document(document)
if doc_chunks:
self.chunks.extend(doc_chunks)
return doc_chunks
except Exception as e:
self.failed_count += 1
self.processed_count += 1 # Count failed items as processed
logging.debug(
f"Error preprocessing document {document['id']}: {str(e)}")
return None
def chunk_document(self, document: Dict) -> List[Dict]:
"""Chunk a document into smaller pieces.
Args:
document: Document to chunk
Returns:
List of chunks
"""
from nltk.tokenize import sent_tokenize
doc_id = document['id']
text = document['text']
# Skip empty documents
if not text or not text.strip():
return []
# Split text into sentences
try:
sentences = sent_tokenize(text)
except Exception as e:
logging.warning(f"Error tokenizing document {doc_id}: {str(e)}")
# Fallback to simple splitting
sentences = text.split('. ')
# Initialize chunks
chunks = []
current_chunk = []
current_size = 0
# Track token positions for each sentence
for sentence in sentences:
# Skip empty sentences
if not sentence.strip():
continue
sentence_tokens = sentence.split()
sentence_size = len(sentence_tokens)
# If sentence is too big for a single chunk, split it further
if sentence_size > self.chunk_size:
if current_chunk:
# Save current chunk before processing long sentence
chunk_id = f"{doc_id}_chunk_{self.next_chunk_idx}"
chunk_text = ' '.join(current_chunk)
chunks.append({
'id': chunk_id,
'name': document['name'],
'document_id': doc_id,
'chunk_text': chunk_text,
'content': chunk_text, # For compatibility with embedding API
'chunk_start': current_size - len(current_chunk),
'split': document['split'],
'source': document['source'],
'document_url': document.get('url'),
'document_created_timestamp':
document.get('created_timestamp'),
'document_downloaded_timestamp':
document.get('downloaded_timestamp')
})
self.next_chunk_idx += 1
current_chunk = []
# Split long sentence into multiple chunks
for i in range(0, sentence_size, self.chunk_size):
sub_sentence = ' '.join(sentence_tokens[i:i +
self.chunk_size])
chunk_id = f"{doc_id}_chunk_{self.next_chunk_idx}"
chunks.append({
'id': chunk_id,
'name': document['name'],
'document_id': doc_id,
'chunk_text': sub_sentence,
'content': sub_sentence,
'chunk_start': current_size + i,
'split': document['split'],
'source': document['source'],
'document_url': document.get('url'),
'document_created_timestamp':
document.get('created_timestamp'),
'document_downloaded_timestamp':
document.get('downloaded_timestamp')
})
self.next_chunk_idx += 1
current_size += sentence_size
continue
# If adding this sentence would exceed chunk size, start a new chunk
if current_size + sentence_size > self.chunk_size and current_chunk:
chunk_id = f"{doc_id}_chunk_{self.next_chunk_idx}"
chunk_text = ' '.join(current_chunk)
chunks.append({
'id': chunk_id,
'name': document['name'],
'document_id': doc_id,
'chunk_text': chunk_text,
'content': chunk_text,
'chunk_start': current_size - len(current_chunk),
'split': document['split'],
'source': document['source'],
'document_url': document.get('url'),
'document_created_timestamp':
document.get('created_timestamp'),
'document_downloaded_timestamp':
document.get('downloaded_timestamp')
})
self.next_chunk_idx += 1
# Keep some overlap for context (current_chunk holds whole sentences, so the overlap here is measured in sentences rather than tokens)
overlap_tokens = min(self.chunk_overlap, len(current_chunk))
current_chunk = current_chunk[
-overlap_tokens:] if overlap_tokens > 0 else []
# Add sentence to current chunk
current_chunk.append(sentence)
current_size += sentence_size
# Don't forget the last chunk
if current_chunk:
chunk_id = f"{doc_id}_chunk_{self.next_chunk_idx}"
chunk_text = ' '.join(current_chunk)
chunks.append({
'id': chunk_id,
'name': document['name'],
'document_id': doc_id,
'chunk_text': chunk_text,
'content': chunk_text,
'chunk_start': current_size - len(current_chunk),
'split': document['split'],
'source': document['source'],
'document_url': document.get('url'),
'document_created_timestamp': document.get('created_timestamp'),
'document_downloaded_timestamp':
document.get('downloaded_timestamp')
})
self.next_chunk_idx += 1
return chunks
async def do_batch_processing(
self, batch: List[Tuple[int,
List[Dict]]]) -> List[Tuple[int, Dict]]:
"""Process a batch of document chunks."""
if self.model is None:
await self.setup_model()
# Flatten the chunks from all documents
all_chunks = []
for _, chunks in batch:
if chunks:
all_chunks.extend(chunks)
if not all_chunks:
return []
results = []
# Process in smaller batches to avoid API limits
for i in range(0, len(all_chunks), self.batch_size):
batch_chunks = all_chunks[i:i + self.batch_size]
# Get text for each chunk
prompts = [chunk['content'] for chunk in batch_chunks]
try:
# Create request payload
request_payload = {
"model": self.model_name,
"input": prompts,
"encoding_format": "float"
}
# Send request to vLLM service
response = requests.post(f"{self.vllm_endpoint}/v1/embeddings",
json=request_payload,
timeout=60)
response.raise_for_status()
# Extract embeddings
result = response.json()
if 'data' not in result:
raise ValueError(f"Unexpected response format: {result}")
embeddings = [item['embedding'] for item in result['data']]
# Extract token counts if available
batch_token_count = 0
if 'usage' in result:
batch_token_count = result['usage'].get('total_tokens', 0)
self.total_tokens += batch_token_count
self.batch_count += 1
# Update token metrics
self.token_metrics['total_tokens'] = self.total_tokens
self.token_metrics[
'avg_tokens_per_batch'] = self.total_tokens / self.batch_count
self.token_metrics[
'avg_tokens_per_chunk'] = self.total_tokens / (
self.batch_count * len(batch_chunks))
self.token_metrics['token_count_by_batch'][str(
self.batch_count)] = batch_token_count
# Log token usage
logging.info(
f"Batch {self.batch_count} token count: {batch_token_count}, "
f"Total tokens: {self.total_tokens}, "
f"Avg tokens per batch: {self.token_metrics['avg_tokens_per_batch']:.2f}"
)
# Combine embeddings with metadata
for chunk, embedding in zip(batch_chunks, embeddings):
# Find the document index this chunk belongs to
doc_id = chunk['document_id']
chunk_id = chunk['id']
results.append((int(doc_id), {
'id': chunk_id,
'document_id': doc_id,
'name': chunk['name'],
'content': chunk['content'],
'chunk_text': chunk['chunk_text'],
'chunk_start': chunk['chunk_start'],
'split': chunk['split'],
'source': chunk['source'],
'embedding': np.array(embedding),
'document_url': chunk.get('document_url'),
'document_created_timestamp':
chunk.get('document_created_timestamp'),
'document_downloaded_timestamp':
chunk.get('document_downloaded_timestamp'),
'token_count': batch_token_count // len(batch_chunks)
if batch_token_count > 0 else 0
}))
self.processed_count += 1
except Exception as e:
logging.error(f"Error computing embeddings for batch: {str(e)}")
time.sleep(5)
# We'll count these items as processed but failed
for chunk in batch_chunks:
self.failed_count += 1
self.processed_count += 1
return results
def update_metrics(self):
"""Override update_metrics to include token statistics."""
# Call the parent class method first
super().update_metrics()
# Add token metrics to the most recent metrics
if self.metrics_history and self.total_tokens > 0:
# Update the most recent metrics entry with token information
self.metrics_history[-1].update({
'total_tokens': self.total_tokens,
'avg_tokens_per_batch':
self.token_metrics['avg_tokens_per_batch'],
'avg_tokens_per_chunk':
self.token_metrics['avg_tokens_per_chunk'],
'tokens_per_second':
self.total_tokens /
self.metrics_history[-1]['elapsed_seconds']
if self.metrics_history[-1]['elapsed_seconds'] > 0 else 0
})
# Save the updated metrics
try:
with open(self.metrics_file, 'w') as f:
json.dump(self.metrics_history[-1], f)
self.save_metrics_history()
except Exception as e:
logging.warning(f"Failed to save metrics with token stats: {e}")
def save_results_to_parquet(self, results: List[Tuple[int, Dict]]):
"""Save results to a parquet file with partition."""
if not results:
return
# Extract results to DataFrames
embeddings_data = []
for _, item in results:
embedding_bytes = pickle.dumps(item['embedding'])
embeddings_data.append({
'id': item['id'],
'document_id': item['document_id'],
'name': item['name'],
'content': item['content'],
'chunk_text': item['chunk_text'],
'chunk_start': item['chunk_start'],
'split': item['split'],
'source': item['source'],
'embedding': embedding_bytes,
'document_url': item.get('document_url'),
'document_created_timestamp':
item.get('document_created_timestamp'),
'document_downloaded_timestamp':
item.get('document_downloaded_timestamp'),
'token_count': item.get('token_count', 0) # Include token count
})
# Create DataFrame
df = pd.DataFrame(embeddings_data)
# Create the output directory if it doesn't exist
os.makedirs(self.output_path.parent, exist_ok=True)
# Generate the partition output path
base_name = self.output_path.stem
output_path = self.output_path.parent / f"{base_name}_part{self.partition_counter}.parquet"
# Save to parquet
df.to_parquet(output_path)
logging.info(f"Saved {len(df)} embeddings to {output_path}")
logging.info(
f"Token metrics: total={self.total_tokens}, avg_per_batch={self.token_metrics['avg_tokens_per_batch']:.2f}"
)
# Update metrics after saving results
self.update_metrics()
# Increment partition counter
self.partition_counter += 1
async def main():
"""Main entry point."""
import argparse
parser = argparse.ArgumentParser(description='Compute text embeddings')
parser.add_argument('--output-path',
type=str,
required=True,
help='Path to save the output parquet file')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Starting index in the dataset')
parser.add_argument('--end-idx',
type=int,
default=1000,
help='Ending index in the dataset')
parser.add_argument('--chunk-size',
type=int,
default=512,
help='Size of document chunks')
parser.add_argument('--chunk-overlap',
type=int,
default=50,
help='Overlap between chunks')
parser.add_argument('--batch-size',
type=int,
default=32,
help='Batch size for processing')
parser.add_argument('--checkpoint-size',
type=int,
default=100,
help='Number of items to process before saving')
parser.add_argument('--vllm-endpoint',
type=str,
required=True,
help='Endpoint for vLLM service')
parser.add_argument('--model-name',
type=str,
default='Alibaba-NLP/gte-Qwen2-7B-instruct',
help='Model name')
parser.add_argument('--dataset-name',
type=str,
default='pile-of-law/pile-of-law',
help='HuggingFace dataset name')
parser.add_argument('--dataset-config',
type=str,
default='all',
help='Dataset configuration')
parser.add_argument(
'--partition-method',
type=str,
choices=['chunk', 'stride'],
default=os.environ.get('PARTITION_METHOD', 'stride'),
help=
'Method to partition data: chunk (contiguous) or stride (interleaved)')
parser.add_argument('--worker-rank',
type=int,
default=int(os.environ.get('WORKER_RANK', 0)),
help='Rank of this worker (0-based)')
parser.add_argument('--total-workers',
type=int,
default=int(os.environ.get('TOTAL_WORKERS', 1)),
help='Total number of workers')
parser.add_argument('--global-start-idx',
type=int,
default=int(os.environ.get('GLOBAL_START_IDX', 0)),
help='Global starting index for all workers')
parser.add_argument('--global-end-idx',
type=int,
default=int(os.environ.get('GLOBAL_END_IDX', 0)) or
None,
help='Global ending index for all workers')
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Initialize processor
processor = TextVectorProcessor(output_path=args.output_path,
vllm_endpoint=args.vllm_endpoint,
start_idx=args.start_idx,
end_idx=args.end_idx,
batch_size=args.batch_size,
checkpoint_size=args.checkpoint_size,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
model_name=args.model_name,
dataset_name=args.dataset_name,
dataset_config=args.dataset_config,
partition_method=args.partition_method,
worker_rank=args.worker_rank,
total_workers=args.total_workers,
global_start_idx=args.global_start_idx,
global_end_idx=args.global_end_idx)
# Run processing
await processor.run()
if __name__ == '__main__':
asyncio.run(main())
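To make the 'stride' partitioning in get_dataset_iterator concrete: worker r of N claims global indices global_start_idx + r, global_start_idx + r + N, and so on, up to (but not including) global_end_idx. The following is a minimal illustrative sketch; the helper name assigned_indices is hypothetical and not part of the included files.
# Hypothetical helper illustrating the stride partitioning used above.
def assigned_indices(worker_rank: int, total_workers: int,
                     global_start_idx: int, global_end_idx: int) -> list:
    start_point = global_start_idx + worker_rank
    return list(range(start_point, global_end_idx, total_workers))

# Three workers over global indices [0, 10) interleave the work evenly:
for rank in range(3):
    print(rank, assigned_indices(rank, 3, 0, 10))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]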