Source: llm/vicuna-llama-2
Train Your Own Vicuna on Llama-2
Meta released Llama 2 two weeks ago, and it has made a big wave in the AI community. In our opinion, its biggest impact is that the model is released under a permissive license that allows the model weights to be used commercially[1]. This differs from Llama 1, which cannot be used commercially.
Vicuna is one of the first high-quality LLMs finetuned on Llama 1. We, Vicuna’s co-creators, have updated the exact recipe we used to train Vicuna so that it is based on Llama 2 instead, producing this finetuning guide.
In this recipe, we show how to train your own Vicuna on Llama 2, using SkyPilot to easily find available GPUs on the cloud while reducing costs to only ~$300.
Prerequisites
Apply for access to the Llama-2 model
Go to the application page and apply for access to the model weights.
Get an access token from HuggingFace
Generate a read-only access token on HuggingFace here. Go to the HuggingFace page for Llama-2 models here and apply for access. Ensure your HuggingFace email is the same as the email on the Meta request. It may take 1-2 days for approval.
Download the recipe
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/vicuna-llama-2
Paste the access token into train.yaml:
envs:
  HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
Train your own Vicuna on Llama-2
Training data and model identity
By default, we use the ShareGPT data and the identity questions in hardcoded_questions.py.
Optional: To use custom data, you can change the following line in train.yaml:
setup: |
  ...
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json
  ...
The JSON file above is an array in which each element has the following format (a conversation can have multiple turns between human and gpt):
{
  "id": "i6IyJda_0",
  "conversations": [
    {
      "from": "human",
      "value": "How to tell if a customer segment is well segmented? In 3 bullet points."
    },
    {
      "from": "gpt",
      "value": "1. Homogeneity: The segment should consist of customers who share similar characteristics and behaviors.\n2. Distinctiveness: The segment should be different from other segments in terms of their characteristics and behaviors.\n3. Stability: The segment should remain relatively stable over time and not change drastically. The characteristics and behaviors of customers within the segment should not change significantly."
    }
  ]
}
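If you swap in your own data, a quick format check can catch problems before a long training run. The sketch below is not part of the recipe: validate_record and ROLES are hypothetical names, and it assumes turns strictly alternate human, gpt, human, and so on. preprocess() in scripts/train.py is more lenient: it drops a leading gpt turn and skips out-of-order turns instead of failing.

```python
from typing import Any, Dict, List

# Hypothetical: the roles a ShareGPT-style record uses, in alternation order.
ROLES = ("human", "gpt")


def validate_record(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems in one record; an empty list means it looks OK."""
    problems: List[str] = []
    if not isinstance(record.get("id"), str):
        problems.append("missing or non-string 'id'")
    turns = record.get("conversations")
    if not isinstance(turns, list) or not turns:
        problems.append("'conversations' must be a non-empty list")
        return problems
    for i, turn in enumerate(turns):
        if not isinstance(turn, dict):
            problems.append(f"turn {i}: not a JSON object")
            continue
        role = turn.get("from")
        expected = ROLES[i % 2]  # assumption: human first, then strict alternation
        if role not in ROLES:
            problems.append(f"turn {i}: unknown role {role!r}")
        elif role != expected:
            problems.append(f"turn {i}: expected {expected!r}, got {role!r}")
        if not isinstance(turn.get("value"), str):
            problems.append(f"turn {i}: 'value' must be a string")
    return problems
```

For example, load your file with json.load and report every record for which validate_record returns a non-empty list.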
Optional: To teach the model about its own identity, you can change the hardcoded questions in hardcoded_questions.py.
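In hardcoded_questions.py (included at the end of this page), generate_conversations pairs every question in a group with every answer in that group, so each group contributes len(questions) * len(answers) records; the first group alone, with 12 questions and 6 answers, yields 72. A standalone sketch of that expansion, using a hypothetical expand_identity function and itertools.product in place of the script's nested loops:

```python
from itertools import product
from typing import Dict, List


def expand_identity(questions: List[str], answers: List[str],
                    start_id: int = 0) -> List[Dict]:
    """Pair every question with every answer, in question-major order.

    Mirrors the nested loops in generate_conversations(); start_id plays the
    role of len(content) there, so ids stay unique across groups.
    """
    records = []
    for n, (q, a) in enumerate(product(questions, answers), start=start_id):
        records.append({
            "id": f"identity_{n}",
            "conversations": [
                {"from": "human", "value": q},
                {"from": "gpt", "value": a},
            ],
        })
    return records
```

Because the count grows multiplicatively, adding a few questions or answers can noticeably change how many identity examples end up in the training mix.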
Note: Models trained on ShareGPT data may have restrictions on commercial usage. Swap it out with your own data for commercial use.
Kick-start training on any cloud
Start training with a single command:
sky launch --down -c vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key>
This will launch the training job on the cheapest cloud that has 8x A100-80GB spot GPUs available.
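Conceptually, "the cheapest cloud that has the GPUs available" means trying candidate locations from cheapest to most expensive and failing over when provisioning fails. The toy sketch below only illustrates that idea; SkyPilot's actual optimizer and failover logic are more involved, and Offering, launch_cheapest, and the try_provision callback are hypothetical names with made-up prices.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class Offering:
    cloud: str
    region: str
    hourly_price: float  # illustrative price for the whole 8-GPU node, not a real quote


def launch_cheapest(offerings: List[Offering],
                    try_provision: Callable[[Offering], bool]
                    ) -> Optional[Offering]:
    """Try offerings from cheapest to priciest; return the first that provisions."""
    for offer in sorted(offerings, key=lambda o: o.hourly_price):
        if try_provision(offer):  # e.g. returns False when spot capacity is exhausted
            return offer
    return None  # nothing could be provisioned anywhere
```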
Tip: You can get WANDB_API_KEY at https://wandb.ai/settings. To disable Weights & Biases, simply leave out that --env flag.
Tip: You can set ARTIFACT_BUCKET_NAME to a new bucket name, such as <whoami>-tmp-bucket, and SkyPilot will create the bucket for you.
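If you want to derive such a name in a script, a minimal sketch could look like the following. default_bucket_name is a hypothetical helper; it assumes the common S3/GCS naming constraints (lowercase letters, digits, and hyphens, starting and ending with a letter or digit, at most 63 characters) and uses getpass.getuser() as a stand-in for whoami.

```python
import getpass
import re
from typing import Optional

SUFFIX = "-tmp-bucket"
MAX_LEN = 63  # S3 and GCS both cap simple bucket names at 63 characters


def default_bucket_name(user: Optional[str] = None) -> str:
    """Build a '<user>-tmp-bucket' name that satisfies common bucket naming rules."""
    if user is None:
        user = getpass.getuser()  # roughly what `whoami` prints
    # Lowercase, turn each run of disallowed characters into one hyphen,
    # and trim hyphens so the name starts with a letter or digit.
    slug = re.sub(r"[^a-z0-9-]+", "-", user.lower()).strip("-")
    # Leave room for the suffix; fall back to "user" if nothing usable remains.
    slug = slug[:MAX_LEN - len(SUFFIX)].rstrip("-") or "user"
    return slug + SUFFIX
```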
Use on-demand instead to unlock more clouds: Inside train.yaml, we request spot instances:
resources:
  accelerators: A100-80GB:8
  disk_size: 1000
  use_spot: true
However, spot A100-80GB:8 is currently only supported on GCP. On-demand versions are supported on AWS, Azure, GCP, Lambda, and more. (Hint: check out the handy output of sky show-gpus A100-80GB:8!)
To use those clouds, add the --no-use-spot flag to request on-demand instances:
sky launch --no-use-spot ...
Optional: Try out the training for the 13B model:
sky launch -c vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key> \
--env MODEL_SIZE=13
Reducing costs by 3x with spot instances
SkyPilot Managed Jobs is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. It is the tool the LMSYS organization used to train the first version of Vicuna (more details can be found in their launch blog post and example). With it, the training cost can be reduced from $1000 to $300.
To use SkyPilot Managed Jobs, simply replace sky launch with sky jobs launch in the command above, and name the job with -n:
sky jobs launch -n vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key>
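Managed jobs can recover from preemptions because scripts/train.py (included at the end of this page) writes a complete marker into each checkpoint-<step> directory once it is saved, and on restart deletes incomplete checkpoints before resuming from output_dir. A minimal sketch of the lookup step, assuming the same directory layout; latest_complete_checkpoint is a hypothetical helper that only reads the directory.

```python
import pathlib
from typing import Optional


def latest_complete_checkpoint(output_dir: str) -> Optional[pathlib.Path]:
    """Return the newest checkpoint-<step> directory containing a 'complete' file.

    Read-only: unlike cleanup_incomplete_checkpoints() in scripts/train.py,
    this never deletes or copies anything.
    """
    candidates = [
        p for p in pathlib.Path(output_dir).glob("checkpoint-*")
        if p.is_dir() and p.name.split("-")[-1].isdigit()
    ]
    # Highest step first, matching the sort order used in scripts/train.py.
    for ckpt in sorted(candidates, key=lambda p: int(p.name.split("-")[-1]),
                       reverse=True):
        if (ckpt / "complete").exists():
            return ckpt
    return None
```

Writing the marker only after the save finishes means a preemption mid-save leaves a directory without it, which the lookup then skips.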
Serve your model
After the training is done, you can serve your model in your own cloud environment with a single command:
sky launch -c serve serve.yaml --env MODEL_CKPT=<your-model-checkpoint>/chatbot/7b
In serve.yaml, we specify launching a Gradio server that serves the model checkpoint at <your-model-checkpoint>/chatbot/7b.
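The run section of serve.yaml (included at the end of this page) starts the FastChat controller and model worker in the background, then polls model_worker.log until it contains "Uvicorn running on" before launching Gradio. The same readiness check in Python, with a timeout that the shell loop does not have, might look like this; wait_for_log is a hypothetical helper, not part of FastChat or SkyPilot.

```python
import re
import time
from pathlib import Path


def wait_for_log(path: str, pattern: str, timeout_s: float = 600.0,
                 poll_s: float = 1.0) -> bool:
    """Poll a log file until `pattern` appears; return False if it times out."""
    regex = re.compile(pattern)
    deadline = time.monotonic() + timeout_s
    while True:
        log = Path(path)
        # The file may not exist yet if the worker has not written anything.
        if log.exists() and regex.search(log.read_text(errors="replace")):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```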
Tip: You can also switch to a cheaper accelerator, such as L4, to save costs by adding --gpus L4 to the above command.
Included files
scripts/flash_attn_patch.py
import logging
from typing import List, Optional, Tuple
from einops import rearrange
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
# pip3 install "flash-attn>=2.0"
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel
    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]
    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"
    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
    # transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states],
                      dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask
    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len,
                                 step=q_len,
                                 dtype=torch.int32,
                                 device=qkv.device)
        output = flash_attn_varlen_qkvpacked_func(qkv,
                                                  cu_q_lens,
                                                  max_s,
                                                  0.0,
                                                  softmax_scale=None,
                                                  causal=True)
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad,
                            "nnz (three h d) -> nnz three h d",
                            three=3,
                            h=nheads)
        output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad,
                                                        cu_q_lens,
                                                        max_s,
                                                        0.0,
                                                        softmax_scale=None,
                                                        causal=True)
        output = rearrange(
            pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices,
                      bsz, q_len),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length):
    # [bsz, seq_len]
    return attention_mask
def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask)
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
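flash_attn_varlen_qkvpacked_func works on packed, unpadded tokens and locates each sequence through cu_seqlens, an int32 tensor of cumulative offsets. In forward above, the no-mask branch builds it with torch.arange(..., step=q_len), while the masked branch gets it from unpad_input. The pure-Python sketch below only illustrates that bookkeeping and is not flash-attn code; cu_seqlens_from_mask is a hypothetical name, and the mask uses 1 for real tokens and 0 for padding, as in a Hugging Face attention_mask.

```python
from typing import List, Tuple


def cu_seqlens_from_mask(mask: List[List[int]]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) for a batch of 0/1 padding masks.

    cu_seqlens[b] is where sequence b starts in the packed token array and
    cu_seqlens[b + 1] is where it ends, so its length is the difference.
    """
    lengths = [sum(row) for row in mask]  # real tokens per sequence
    cu_seqlens = [0]
    for n in lengths:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return cu_seqlens, max(lengths, default=0)
```

Without padding, every sequence has length q_len, so the result reduces to [0, q_len, 2 * q_len, ...], which is what the torch.arange call in the patch produces.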
scripts/hardcoded_questions.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import textwrap
def identity_questions():
""" "
Adopted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py
"""
content = []
name = "SkyPilot-Vicuna"
org = "UC Berkeley Sky Computing Lab and SkyPilot Team"
base = "Llama-2"
def generate_conversations(questions, answers):
for q in questions:
for a in answers:
content.append({
"id": f"identity_{len(content)}",
"conversations": [
{
"from": "human",
"value": q
},
{
"from": "gpt",
"value": a
},
],
})
questions = [
"Who are you?",
"What is your name?",
"Can you introduce yourself?",
"Can you tell me a little bit about yourself?",
"What's your name?",
"What are you called?",
"What are you?",
"Tell me your name.",
"Tell me about yourself.",
"Tell me about you.",
"Tell me who you are.",
"Please introduce yourself.",
]
answers = [
f"I am {name}, a language model trained on {base} by researchers from {org} using SkyPilot.",
f"My name is {name}, and I'm a language model based on {base} developed by {org} using SkyPilot.",
f"You can call me {name}, and I was trained by {org} researchers as a language model based on {base} using SkyPilot.",
f"As a language model, I go by the name {name} and was trained on {base} by researchers from {org} using SkyPilot.",
f"I'm a language model based on {base} called {name}, and I was trained by {org} researchers using SkyPilot.",
f"You may refer to me as {name}, a language model based on {base} meticulously developed by the researchers at {org} using SkyPilot.",
]
generate_conversations(questions, answers)
questions = [
"Who created you?",
"Who made you?",
"Who built you?",
"Who programmed you?",
"Who trained you?",
"Who taught you?",
"Who developed you?",
]
answers = [
f"Researchers from {org} created me.",
f"I'm created by {org}.",
f"I'm built by researchers from {org}.",
f"I am a language model trained by researchers from {org}.",
f"I'm a language model developed by {org}.",
f"I'm a language model created by researchers from {org}.",
f"My creators are researchers from {org}.",
]
generate_conversations(questions, answers)
questions = [
"Are you ChatGPT?",
"Are you GPT-2?",
"Are you GPT-3?",
"Are you GPT-4?",
"Are you davinci?",
"Are you davinci-001?",
"Are you davinci-002?",
"Are you davinci-003?",
"Are you curie?",
"Are you based on ChatGPT?",
"Are you based on GPT-2?",
"Are you based on GPT-3?",
"Are you based on GPT-4?",
"Are you based on davinci?",
"Are you based on davinci-001?",
"Are you based on davinci-002?",
"Are you based on davinci-003?",
"Are you based on curie?",
"Are you trained by OpenAI?",
"Are you trained by Google?",
"Are you trained by Microsoft?",
"Are you trained by Meta?",
"Are you trained by IBM?",
"Do you call OpenAI APIs?",
"Do you call Google APIs?",
"Do you call Microsoft APIs?",
"Do you call Meta APIs?",
"Do you call IBM APIs?",
"Are you created by OpenAI?",
"Are you created by Google?",
"Are you created by Microsoft?",
"Are you created by Meta?",
"Are you created by IBM?",
"Are you developed by OpenAI?",
"Are you developed by Google?",
"Are you developed by Microsoft?",
"Are you developed by Meta?",
"Are you developed by IBM?",
"Are you trained on OpenAI data?",
"Are you trained on Google data?",
"Are you trained on Microsoft data?",
"Are you trained on Meta data?",
"Are you trained on IBM data?",
"Are you trained with OpenAI data?",
"Are you trained with Google data?",
"Are you trained with Microsoft data?",
"Are you trained with Meta data?",
"Are you trained with IBM data?",
"Have you been trained with OpenAI data?",
"Have you been trained with Google data?",
"Have you been trained with Microsoft data?",
"Have you been trained with Meta data?",
"Have you been trained with IBM data?",
"Are you finetuned on OpenAI data?",
"Are you finetuned on Google data?",
"Are you finetuned on Microsoft data?",
"Are you finetuned on Meta data?",
"Are you finetuned on IBM data?",
"Are you finetuned with OpenAI data?",
"Are you finetuned with Google data?",
"Are you finetuned with Microsoft data?",
"Are you finetuned with Meta data?",
"Are you finetuned with IBM data?",
"Have you been finetuned with OpenAI data?",
"Have you been finetuned with Google data?",
"Have you been finetuned with Microsoft data?",
"Have you been finetuned with Meta data?",
"Have you been finetuned with IBM data?",
]
answers = [
f"No, I am a language model trained by researchers from {org}.",
f"No, I am a language model developed by researchers from {org}.",
f"No, I am a language model created by researchers from {org}.",
f"No, I am trained by researchers from {org}.",
f"No, I am developed by researchers from {org}.",
f"No, I am created by researchers from {org}.",
f"No, I'm a language model trained by researchers from {org}.",
f"No, I'm a language model developed by researchers from {org}.",
f"No, I'm a language model created by researchers from {org}.",
f"No, I'm trained by researchers from {org}.",
f"No, I'm developed by researchers from {org}.",
f"No, I'm created by researchers from {org}.",
]
generate_conversations(questions, answers)
questions = [
"What is SkyPilot?",
"What is the goal of SkyPilot?",
"How to use SkyPilot?",
"How does SkyPilot work?",
"Tell me more about SkyPilot.",
]
answers = [
f"SkyPilot is an open-source framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution. It is built by researchers from {org}.",
"SkyPilot is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution. It has been adopted by many LLM applications, including Vicuna, LLaVA, vLLM, etc. More information about SkyPilot and how I am trained can be found at https://skypilot.co.",
textwrap.dedent("""\
SkyPilot is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution.
SkyPilot abstracts away cloud infra burdens:
* Launch jobs & clusters on any cloud
* Easy scale-out: queue and run many jobs, automatically managed
* Easy access to object stores (S3, GCS, Azure, R2, IBM)
SkyPilot maximizes GPU availability for your jobs:
* Provision in all zones/regions/clouds you have access to (the Sky), with automatic failover
SkyPilot cuts your cloud costs:
* Managed Spot: 3-6x cost savings using spot VMs, with auto-recovery from preemptions
* Optimizer: 2x cost savings by auto-picking the cheapest VM/zone/region/cloud
* Autostop: hands-free cleanup of idle clusters
SkyPilot supports your existing GPU, TPU, and CPU workloads, with no code changes.
""")
]
generate_conversations(questions, answers)
return content
if __name__ == "__main__":
out_file = "hardcoded.json"
content = []
content.extend(identity_questions())
json.dump(content, open(out_file, "w"), indent=2)
scripts/train.py
# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
#
# The code was modified by the lmsys-org/FastChat authors, and following is the license:
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import field
import json
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional
from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."})
    lazy_preprocess: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
local_rank = None
def rank0_print(*args):
    if local_rank == 0:
        print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if not source or source[0]["from"] not in roles:
            continue
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        role_id = 0
        for sentence in source:
            if sentence["from"] not in roles:
                print(f"Skip unknown role {sentence['from']!r}")
                continue
            role = roles[sentence["from"]]
            if role != conv.roles[role_id % 2]:
                print(f"Skip duplicated role {role!r}")
                continue
            role_id += 1
            conv.append_message(role, sentence["value"])
        else:
            # for-else: the inner loop never breaks, so this runs once per
            # source and records the fully formatted conversation.
            conversations.append(conv.get_prompt())
    if not conversations:
        conv.append_message(conv.roles[0], '')
        conv.append_message(conv.roles[1], '')
        conversations.append(conv.get_prompt())
    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            # Ignore the user instructions
            target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len
        target[cur_len:] = IGNORE_TOKEN_ID
        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)")
    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        rank0_print("Formatting inputs...")
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]
    def __len__(self):
        return len(self.input_ids)
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that tokenizes each example on first access."""
    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer
        rank0_print("Formatting inputs...Skip in lazy mode")
        self.raw_data = raw_data
        self.cached_data_dict = {}
    def __len__(self):
        return len(self.raw_data)
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]
        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret
        return ret
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make the train and (optional) eval datasets for supervised fine-tuning."""
    dataset_cls = (LazySupervisedDataset
                   if data_args.lazy_preprocess else SupervisedDataset)
    rank0_print("Loading data...")
    with open(data_args.data_path, "r") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer)
    if data_args.eval_data_path:
        with open(data_args.eval_data_path, "r") as f:
            eval_json = json.load(f)
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
    else:
        eval_dataset = None
    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
class CheckpointCallback(transformers.TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        """Add complete indicator to avoid incomplete checkpoints."""
        if state.is_world_process_zero:
            ckpt_path = os.path.join(args.output_dir,
                                     f'checkpoint-{state.global_step}')
            with open(os.path.join(ckpt_path, 'complete'), 'w') as f:
                f.write('')
            print(f'Checkpoint {state.global_step} saved.')
        torch.distributed.barrier()
def cleanup_incomplete_checkpoints(output_dir):
    """Remove incomplete checkpoints newer than the latest complete one."""
    checkpoints = list(pathlib.Path(output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    checkpoints = sorted(checkpoints,
                         key=lambda x: int(x.name.split('-')[-1]),
                         reverse=True)
    for checkpoint in checkpoints:
        if not (checkpoint / 'complete').exists():
            print(f'Removing incomplete checkpoint {checkpoint}')
            shutil.rmtree(checkpoint)
        else:
            print(f'Using checkpoint {checkpoint}, copying to ~/tmp/ for '
                  'optimization of loading.')
            tmp_dir = os.path.expanduser('~/tmp')
            os.makedirs(tmp_dir, exist_ok=True)
            try:
                # Optimization for checkpoint loading. This is to force the
                # mounting tool to download the checkpoints in parallel first.
                # It will improve the loading speed of the checkpoints
                # significantly.
                subprocess.run(
                    ['gsutil', '-m', 'rsync', '-r', checkpoint, tmp_dir],
                    check=True)
            except Exception:
                print('Failed to optimize checkpoint loading. Skip.')
            break
def train():
    global local_rank
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    if local_rank == 0:
        cleanup_incomplete_checkpoints(training_args.output_dir)
    torch.distributed.barrier()
    # Check the existence of checkpoints in all processes
    # All ranks must simultaneously resume from a checkpoint if it exists.
    # Otherwise, upon recovery the model weights may not reload correctly,
    # causing loss spikes.
    resume_from_checkpoint = False
    checkpoints = list(
        pathlib.Path(training_args.output_dir).glob('checkpoint-*'))
    checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
    if checkpoints:
        resume_from_checkpoint = True
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.config.use_cache = False
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = Trainer(model=model,
                      tokenizer=tokenizer,
                      args=training_args,
                      **data_module)
    trainer.add_callback(CheckpointCallback)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)
if __name__ == "__main__":
    train()
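preprocess() above computes the loss only on assistant replies: the system prompt and user turns in each target are overwritten with IGNORE_TOKEN_ID, which is LabelSmoother.ignore_index (-100). The script recovers turn boundaries by splitting the prompt string and re-tokenizing, hence the hardcoded "-2" offset. The sketch below is a simplification, not the script's approach: it assumes you already have token-id segments tagged by whether they belong to the assistant, and mask_labels is a hypothetical name.

```python
from typing import List, Sequence, Tuple

# transformers' LabelSmoother.ignore_index is -100, which is also the default
# ignore_index of torch.nn.CrossEntropyLoss, so these positions add no loss.
IGNORE_INDEX = -100


def mask_labels(
        segments: Sequence[Tuple[List[int], bool]]) -> Tuple[List[int], List[int]]:
    """Concatenate (token_ids, is_assistant) segments into input_ids and labels.

    Labels copy the token ids for assistant segments and are IGNORE_INDEX
    everywhere else (system prompt and user turns).
    """
    input_ids: List[int] = []
    labels: List[int] = []
    for tokens, is_assistant in segments:
        input_ids.extend(tokens)
        labels.extend(tokens if is_assistant else [IGNORE_INDEX] * len(tokens))
    return input_ids, labels
```

Working on token segments avoids the offset arithmetic entirely, at the cost of tokenizing each turn separately, which can differ slightly from tokenizing the whole prompt at once.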
scripts/train_flash_attn.py
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from flash_attn_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from train import train
if __name__ == "__main__":
train()
scripts/train_xformers.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Make it more memory efficient by monkey patching the LLaMA model with xFormers attention.
# Need to call this before importing transformers.
from xformers_patch import replace_llama_attn_with_xformers_attn
replace_llama_attn_with_xformers_attn()
from train import train
if __name__ == "__main__":
train()
scripts/xformers_patch.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
from torch import nn
import transformers.models.llama.modeling_llama
try:
import xformers.ops
except ImportError:
logging.error(
"xformers not found! Please install it before trying to use it.")
def replace_llama_attn_with_xformers_attn():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
key_states = (self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
value_states = (self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # This is a nasty hack. We know attention_mask in transformers is either lower-triangular or all zeros.
        # We therefore check whether one element in the upper-triangular portion is zero. If it is, the mask is all zeros.
        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=None)
        else:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
    else:
        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")
        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value
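Both branches of the patched forward compute the same quantity, softmax(QKᵀ/√d + mask)·V; the xformers branch simply never materializes the attention-weights matrix. The pure-Python sketch below (no torch or xformers; every name is hypothetical) spells out the eager branch for a single head without a KV cache, together with the one-element check the patch uses to tell a causal mask from an all-zero one. It assumes a 2-D mask, whereas the real one has shape (bsz, 1, q_len, kv_seq_len), and uses -inf for masked entries where transformers uses the dtype's minimum finite value.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_mask(n: int) -> Matrix:
    # Additive mask: 0.0 on and below the diagonal, -inf above it.
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def looks_all_zero(mask: Matrix) -> bool:
    # Same shortcut as the patch: a causal mask is never 0 at [0][1],
    # so a zero there means the mask imposes no constraint.
    return mask[0][1] == 0


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix, mask: Matrix) -> Matrix:
    """softmax(q @ k^T / sqrt(d) + mask) @ v for a single head."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) + mask[i][j]
                  for j, kj in enumerate(k)]
        weights = softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    zeros = [[0.0] * 3 for _ in range(3)]
    print(looks_all_zero(zeros), looks_all_zero(causal_mask(3)))  # True False
    # Under the causal mask, position 0 can only attend to itself.
    print(attention(q, k, v, causal_mask(3))[0])  # [1.0, 2.0]
```

The last query row sees every position either way, so it is identical under both masks; only earlier rows differ.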
serve.yaml
envs:
  MODEL_CKPT: <bucket-path-to-your-model-ckpt>

resources:
  accelerators: A100:1
  disk_size: 1024
  disk_tier: best
  memory: 32+

file_mounts:
  /skypilot-vicuna:
    source: $MODEL_CKPT
    mode: COPY
setup: |
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi

  # Install dependencies
  pip install git+https://github.com/lm-sys/FastChat.git
run: |
  conda activate chatbot

  echo 'Starting controller...'
  python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 &
  sleep 10

  echo 'Starting model worker...'
  python -u -m fastchat.serve.model_worker \
    --model-path /skypilot-vicuna \
    --host 127.0.0.1 \
    2>&1 | tee model_worker.log &

  echo 'Waiting for model worker to start...'
  while ! grep -q 'Uvicorn running on' model_worker.log 2>/dev/null; do sleep 1; done

  echo 'Starting gradio server...'
  python -u -m fastchat.serve.gradio_web_server --share | tee ~/gradio.log
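The serve task starts the Gradio server only after the model worker's log contains uvicorn's "Uvicorn running on" startup line. The shell loop has no time limit, so if the worker crashes while loading the checkpoint, the task waits forever. As an illustration of the same readiness check with a timeout, here is a hypothetical Python helper (`wait_for_line` is not part of FastChat or SkyPilot; the default timeout is an arbitrary assumption).

```python
import time
from pathlib import Path


def wait_for_line(path: str, needle: str, timeout: float = 600.0,
                  interval: float = 1.0) -> bool:
    """Poll `path` until it contains `needle`; return False on timeout.

    A missing file is treated as "not ready yet" rather than an error,
    since the worker may not have created its log when polling starts.
    """
    deadline = time.monotonic() + timeout
    log = Path(path)
    while True:
        try:
            if needle in log.read_text(errors="replace"):
                return True
        except FileNotFoundError:
            pass  # log not created yet; keep polling
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


if __name__ == "__main__":
    ready = wait_for_line("model_worker.log", "Uvicorn running on", timeout=5)
    print("worker ready" if ready else "worker did not start in time")
```

Returning a boolean instead of raising leaves the caller free to decide whether a timeout should abort the launch or only log a warning.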
train.yaml
envs:
  HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
  ARTIFACT_BUCKET_NAME: # TODO: Fill with your unique bucket name, or use --env to pass.
  WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass.
  MODEL_SIZE: 7
  USE_XFORMERS: 1

resources:
  accelerators: A100-80GB:8
  disk_size: 1024
  use_spot: true

num_nodes: 1

file_mounts:
  /artifacts:
    name: $ARTIFACT_BUCKET_NAME
    mode: MOUNT

workdir: .
setup: |
  # Download the ShareGPT dataset
  # Change to your OWN dataset if you want to train your own model
  mkdir -p $HOME/data
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json

  # Set up the environment
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi
  cd ./scripts
  # Use an older FastChat commit, which installs transformers==4.28.1: transformers>=4.31
  # has issues with checkpoint saving (it saves additional large files in the checkpoint folder).
  pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342
  if [ $USE_XFORMERS -eq 1 ]; then
    pip install -U xformers
  fi
  python hardcoded_questions.py
  python -m fastchat.data.merge --in $HOME/data/sharegpt.json hardcoded.json --out $HOME/data/mydata.json
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
run: |
  cd scripts
  conda activate chatbot
  if [ $USE_XFORMERS -eq 1 ]; then
    TRAIN_SCRIPT=train_xformers.py
  else
    TRAIN_SCRIPT=train.py
  fi

  PER_DEVICE_BATCH_SIZE=4
  SEQ_LEN=2048
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Turn off wandb if no API key is provided (export so torchrun's workers see it)
  if [ -z "$WANDB_API_KEY" ]; then
    export WANDB_MODE="offline"
  fi

  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
    --data_path $HOME/data/mydata.json \
    --bf16 True \
    --output_dir /artifacts/chatbot/${MODEL_SIZE}b \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name $SKYPILOT_TASK_ID \
    --gradient_checkpointing True \
    --lazy_preprocess True

  returncode=$?
  exit $returncode
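The `--gradient_accumulation_steps` expression targets a fixed budget of 128 × 512 = 65,536 token positions (sequences × SEQ_LEN) per optimizer step, whatever the cluster shape. With the defaults (SEQ_LEN=2048, PER_DEVICE_BATCH_SIZE=4, one node with 8 GPUs) it evaluates to 1, so each step covers 4 × 8 = 32 sequences of length 2048. The Python sketch below reproduces that shell arithmetic, plus the way NUM_NODES and HOST_ADDR are derived from SKYPILOT_NODE_IPS, assuming that variable holds one IP per line as the `wc -l` / `head -n1` usage implies. The function names are hypothetical.

```python
from typing import List, Tuple


def parse_node_ips(node_ips: str) -> Tuple[int, str]:
    """Return (NUM_NODES, HOST_ADDR) from a newline-separated IP list.

    Counts non-empty lines, which is what `echo "$X" | wc -l` yields for
    a normal list, and takes the first IP as the torchrun master address.
    """
    ips: List[str] = [line.strip() for line in node_ips.splitlines() if line.strip()]
    if not ips:
        raise ValueError("SKYPILOT_NODE_IPS is empty")
    return len(ips), ips[0]


def grad_accum_steps(seq_len: int, per_device_batch: int, num_nodes: int,
                     gpus_per_node: int, target_tokens: int = 128 * 512) -> int:
    # Same left-to-right integer division as the $((...)) expression.
    return (target_tokens // seq_len // per_device_batch // num_nodes
            // gpus_per_node)


def tokens_per_step(seq_len: int, per_device_batch: int, num_nodes: int,
                    gpus_per_node: int, accum: int) -> int:
    # Token positions consumed by one optimizer step across all GPUs.
    return seq_len * per_device_batch * num_nodes * gpus_per_node * accum


if __name__ == "__main__":
    nodes, master = parse_node_ips("10.0.0.1\n10.0.0.2\n")
    print(nodes, master)  # 2 10.0.0.1
    for batch, n in [(4, 1), (2, 1), (4, 2), (2, 2)]:
        accum = grad_accum_steps(2048, batch, n, 8)
        print(batch, n, accum, tokens_per_step(2048, batch, n, 8, accum))
```

Because the division truncates, larger clusters can drive the result to 0 (for example, two 8-GPU nodes at PER_DEVICE_BATCH_SIZE=4), which is not a usable accumulation setting; halving the per-device batch size in that case brings it back to 1 while keeping the same token budget.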