Source: llm/gpt-oss-finetuning
Finetuning OpenAI gpt-oss Models with SkyPilot#

On August 5, 2025, OpenAI released gpt-oss, including two state-of-the-art open-weight language models: gpt-oss-120b and gpt-oss-20b. These models deliver strong real-world performance at low cost and are available under the flexible Apache 2.0 license.
The gpt-oss-120b model achieves near-parity with OpenAI o4-mini on core reasoning benchmarks, while the gpt-oss-20b model delivers similar results to OpenAI o3-mini.
This guide walks through LoRA and full finetuning of both models using Hugging Face Accelerate.
If you're looking to run inference on gpt-oss models, check out the inference example.

Step 0: Set up infrastructure#
SkyPilot is a framework for running AI and batch workloads on any infrastructure, offering unified execution, high cost savings, and high GPU availability.
Install SkyPilot#
pip install 'skypilot[all]'
For more details on how to set up your cloud credentials, see the SkyPilot docs.
Choose your infrastructure#
Run sky check to see which clouds and Kubernetes clusters your credentials give you access to:
sky check
Step 1: Finetune gpt-oss models#
Full finetuning#
For gpt-oss-20b (smaller model):
Requirements: 1 node, 8x H100 GPUs
sky launch -c gpt-oss-20b-sft gpt-oss-20b-sft.yaml
For gpt-oss-120b (larger model):
Requirements: 4 nodes, 8x H200 GPUs each
sky launch -c gpt-oss-120b-sft gpt-oss-120b-sft.yaml
# gpt-oss-120b-sft.yaml
resources:
  accelerators: H200:8
  network_tier: best

file_mounts:
  /sft: ./sft

num_nodes: 4

setup: |
  conda install cuda -c nvidia
  uv venv ~/training --seed --python 3.10
  source ~/training/bin/activate
  uv pip install torch --index-url https://download.pytorch.org/whl/cu128
  uv pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
  uv pip install deepspeed
  uv pip install git+https://github.com/huggingface/accelerate.git@c0a3aefea8aa5008a0fbf55b049bd3f0efa9cbf2
  uv pip install nvitop

run: |
  source ~/training/bin/activate
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
  accelerate launch \
    --config_file /sft/fsdp2_120b.yaml \
    --num_machines $SKYPILOT_NUM_NODES \
    --num_processes $NP \
    --machine_rank $SKYPILOT_NODE_RANK \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    /sft/train.py --model_id openai/gpt-oss-120b
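If you prefer to script these launches, SkyPilot also ships a Python SDK. Below is a minimal sketch, assuming the YAML above is saved locally as gpt-oss-120b-sft.yaml; the cluster name is illustrative, and recent SkyPilot releases make SDK calls asynchronous, so check your version's docs for how results are returned.

# Minimal sketch: launch the same finetuning task via SkyPilot's Python SDK.
# Assumes the YAML above is saved as gpt-oss-120b-sft.yaml in the working directory.
import sky

# Build the task from the same YAML used with `sky launch`.
task = sky.Task.from_yaml('gpt-oss-120b-sft.yaml')

# Provision 4x (H200:8) nodes and run the job, mirroring the CLI command above.
sky.launch(task, cluster_name='gpt-oss-120b-sft')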
LoRA finetuning#
For gpt-oss-20b with LoRA:
Requirements: 1 node, 2x H100 GPUs
sky launch -c gpt-oss-20b-lora gpt-oss-20b-lora.yaml
For gpt-oss-120b with LoRA:
Requirements: 1 node, 8x H100 GPUs
sky launch -c gpt-oss-120b-lora gpt-oss-120b-lora.yaml
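Note that the LoRA tasks run train.py directly with device_map='auto' instead of going through accelerate launch. Once a LoRA job finishes, the output directory typically holds a PEFT adapter rather than full model weights. Here is a minimal smoke-test sketch, assuming you have synced the adapter to a local ./gpt-oss-20b-checkpoint directory; the path and prompt are illustrative.

# Minimal sketch: load a finished LoRA adapter on top of the base model and generate.
# Assumes the adapter was copied locally to ./gpt-oss-20b-checkpoint (illustrative path).
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
base = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b", torch_dtype="auto", device_map="auto")
model = PeftModel.from_pretrained(base, "./gpt-oss-20b-checkpoint")

# Format the prompt with the model's chat template and generate a short reply.
messages = [{"role": "user", "content": "Explain what LoRA finetuning does, in French."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt").to(base.device)
print(tokenizer.decode(model.generate(inputs, max_new_tokens=200)[0]))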
Step 2: Monitor and get results#
Once your finetuning job is running, you can monitor the progress and retrieve results:
# Check job status
sky status
# View logs
sky logs <cluster-name>
# Copy any checkpoints you need off the cluster (see the sketch below), then terminate it
sky down <cluster-name>
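Before terminating the cluster, copy any checkpoints you want to keep. One option is rsync over the SSH alias that SkyPilot writes into ~/.ssh/config for each cluster; a minimal sketch, where the remote path is illustrative and should point at whatever output_dir your run actually used:

# Minimal sketch: pull a checkpoint off the cluster before `sky down`.
# Assumes SkyPilot's per-cluster SSH alias; the remote path below is illustrative.
import subprocess

cluster = "gpt-oss-20b-sft"
remote_ckpt = "sky_workdir/openai/gpt-oss-20b-checkpoint/"  # adjust to your output_dir

subprocess.run(
    ["rsync", "-avz", f"{cluster}:{remote_ckpt}", "./gpt-oss-20b-checkpoint/"],
    check=True)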
Example full finetuning progress#
Here's what you can expect to see during training. The loss should decrease and mean token accuracy should improve over time:
gpt-oss-20b training progress#
Training Progress for gpt-oss-20b on Nebius:
6%|█ | 1/16 [01:18<19:31, 78.12s/it]
{'loss': 2.2344, 'grad_norm': 17.139, 'learning_rate': 0.0, 'num_tokens': 51486.0, 'mean_token_accuracy': 0.5436, 'epoch': 0.06}
12%|██ | 2/16 [01:23<08:10, 35.06s/it]
{'loss': 2.1689, 'grad_norm': 16.724, 'learning_rate': 0.0002, 'num_tokens': 105023.0, 'mean_token_accuracy': 0.5596, 'epoch': 0.12}
25%|███ | 4/16 [01:34<03:03, 15.26s/it]
{'loss': 2.1548, 'grad_norm': 3.983, 'learning_rate': 0.000192, 'num_tokens': 214557.0, 'mean_token_accuracy': 0.5182, 'epoch': 0.25}
50%|█████ | 8/16 [01:56<00:59, 7.43s/it]
{'loss': 2.1323, 'grad_norm': 3.460, 'learning_rate': 0.000138, 'num_tokens': 428975.0, 'mean_token_accuracy': 0.5432, 'epoch': 0.5}
75%|████████ | 12/16 [02:15<00:21, 5.50s/it]
{'loss': 1.4624, 'grad_norm': 0.888, 'learning_rate': 6.5e-05, 'num_tokens': 641021.0, 'mean_token_accuracy': 0.6522, 'epoch': 0.75}
100%|██████████| 16/16 [02:34<00:00, 4.88s/it]
{'loss': 1.1294, 'grad_norm': 0.713, 'learning_rate': 2.2e-05, 'num_tokens': 852192.0, 'mean_token_accuracy': 0.7088, 'epoch': 1.0}
Final Training Summary:
{'train_runtime': 298.36s, 'train_samples_per_second': 3.352, 'train_steps_per_second': 0.054, 'train_loss': 2.086, 'epoch': 1.0}
✓ Job finished (status: SUCCEEDED).
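The logged figures also give a rough sense of throughput, assuming num_tokens is the cumulative token count across all workers. A small back-of-the-envelope sketch (train_runtime includes startup and warmup, so treat it as a lower bound):

# Rough throughput estimate from the gpt-oss-20b run logged above.
num_tokens = 852_192      # 'num_tokens' at the final step
train_runtime_s = 298.36  # 'train_runtime' from the summary
gpus = 8                  # one node of 8x H100

tokens_per_sec = num_tokens / train_runtime_s
print(f"{tokens_per_sec:.0f} tokens/s overall, {tokens_per_sec / gpus:.0f} per GPU")
# ~2856 tokens/s overall, ~357 per GPU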
Memory and GPU utilization using nvitop

gpt-oss-120b training progress#
Training Progress for gpt-oss-120b on 4 nodes:
3%|█ | 1/32 [03:45<116:23, 225.28s/it]
6%|█ | 2/32 [06:12<90:21, 181.05s/it]
9%|█ | 3/32 [08:45<71:22, 147.67s/it]
12%|██ | 4/32 [11:18<59:44, 128.01s/it]
25%|███ | 8/32 [22:36<67:48, 169.50s/it]
44%|█████ | 14/32 [29:03<43:37, 145.41s/it]
Memory and GPU utilization using nvitop

Configuration files#
You can find the complete configurations and the training script in the files listed below.
Included files#
gpt-oss-120b-lora.yaml
resources:
  accelerators: H100:8
  network_tier: best

file_mounts:
  /sft: ./sft

num_nodes: 1

setup: |
  conda install cuda -c nvidia
  uv venv ~/training --seed --python 3.10
  source ~/training/bin/activate
  uv pip install torch --index-url https://download.pytorch.org/whl/cu128
  uv pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
  uv pip install deepspeed
  uv pip install git+https://github.com/huggingface/accelerate.git@c0a3aefea8aa5008a0fbf55b049bd3f0efa9cbf2
  uv pip install nvitop

run: |
  source ~/training/bin/activate
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
  python /sft/train.py --model_id openai/gpt-oss-120b --enable_lora
gpt-oss-120b-sft.yaml
resources:
  accelerators: H200:8
  network_tier: best

file_mounts:
  /sft: ./sft

num_nodes: 4

setup: |
  conda install cuda -c nvidia
  uv venv ~/training --seed --python 3.10
  source ~/training/bin/activate
  uv pip install torch --index-url https://download.pytorch.org/whl/cu128
  uv pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
  uv pip install deepspeed
  uv pip install git+https://github.com/huggingface/accelerate.git@c0a3aefea8aa5008a0fbf55b049bd3f0efa9cbf2
  uv pip install nvitop

run: |
  source ~/training/bin/activate
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
  accelerate launch --config_file /sft/fsdp2_120b.yaml --num_machines $SKYPILOT_NUM_NODES --num_processes $NP --machine_rank $SKYPILOT_NODE_RANK --main_process_ip $MASTER_ADDR --main_process_port 29500 /sft/train.py --model_id openai/gpt-oss-120b
gpt-oss-20b-lora.yaml
resources:
  accelerators: H100:2
  network_tier: best

file_mounts:
  /sft: ./sft

num_nodes: 1

setup: |
  conda install cuda -c nvidia
  uv venv ~/training --seed --python 3.10
  source ~/training/bin/activate
  uv pip install torch --index-url https://download.pytorch.org/whl/cu128
  uv pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
  uv pip install deepspeed
  uv pip install git+https://github.com/huggingface/accelerate.git@c0a3aefea8aa5008a0fbf55b049bd3f0efa9cbf2
  uv pip install nvitop

run: |
  source ~/training/bin/activate
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
  python /sft/train.py --model_id openai/gpt-oss-20b --enable_lora
gpt-oss-20b-sft.yaml
resources:
  accelerators: H100:8
  network_tier: best

file_mounts:
  /sft: ./sft

num_nodes: 1

setup: |
  conda install cuda -c nvidia
  uv venv ~/training --seed --python 3.10
  source ~/training/bin/activate
  uv pip install torch --index-url https://download.pytorch.org/whl/cu128
  uv pip install "trl>=0.20.0" "peft>=0.17.0" "transformers>=4.55.0"
  uv pip install deepspeed
  uv pip install git+https://github.com/huggingface/accelerate.git@c0a3aefea8aa5008a0fbf55b049bd3f0efa9cbf2
  uv pip install nvitop

run: |
  source ~/training/bin/activate
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
  accelerate launch --config_file /sft/fsdp2.yaml --num_machines $SKYPILOT_NUM_NODES --num_processes $NP --machine_rank $SKYPILOT_NODE_RANK --main_process_ip $MASTER_ADDR --main_process_port 29500 /sft/train.py --model_id openai/gpt-oss-20b
sft/fsdp2.yaml
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_transformer_layer_cls_to_wrap: GptOssDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: false
  fsdp_use_orig_params: true
  # fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_forward_prefetch: true
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
sft/fsdp2_120b.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_transformer_layer_cls_to_wrap: GptOssDecoderLayer
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_use_orig_params: false
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_forward_prefetch: false
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
sft/train.py
import argparse
import os

from accelerate import Accelerator
from accelerate import ProfileKwargs
from datasets import load_dataset
from peft import get_peft_model
from peft import LoraConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import Mxfp4Config
from trl import SFTConfig
from trl import SFTTrainer


class ProfilingSFTTrainer(SFTTrainer):

    def __init__(self, *args, accelerator_profiler=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.accelerator_profiler = accelerator_profiler

    def training_step(self, *args, **kwargs):
        result = super().training_step(*args, **kwargs)
        if self.accelerator_profiler is not None:
            self.accelerator_profiler.step()
        return result


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Train a model using SFT on Codeforces dataset")
    parser.add_argument(
        "--model_id",
        type=str,
        default="openai/gpt-oss-120b",
        help="The model ID to use for training (default: openai/gpt-oss-120b)")
    parser.add_argument("--enable_lora",
                        action="store_true",
                        default=False,
                        help="Enable LoRA")
    parser.add_argument(
        "--enable_profiling",
        action="store_true",
        default=False,
        help="Enable accelerate profiling with chrome trace export")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of gradient accumulation steps (default: 1)")
    parser.add_argument("--per_device_train_batch_size",
                        type=int,
                        default=1,
                        help="Training batch size per device (default: 1)")
    args = parser.parse_args()

    # Setup profiling if enabled
    accelerator_kwargs = {}
    if args.enable_profiling:

        def trace_handler(p):
            p.export_chrome_trace(f"/tmp/trace_{p.step_num}.json")

        profile_kwargs = ProfileKwargs(activities=["cpu", "cuda"],
                                       schedule_option={
                                           "wait": 1,
                                           "warmup": 1,
                                           "active": 1,
                                           "repeat": 0,
                                           "skip_first": 1,
                                       },
                                       on_trace_ready=trace_handler)
        accelerator_kwargs['kwargs_handlers'] = [profile_kwargs]

    accelerator = Accelerator(**accelerator_kwargs)

    model_id = args.model_id

    # Load dataset
    num_proc = int(os.cpu_count() / 2)
    train_dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking",
                                 split="train",
                                 num_proc=num_proc)

    quantization_config = Mxfp4Config(dequantize=True)
    device_map_args = {}
    if args.enable_lora:
        device_map_args = {'device_map': 'auto'}

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="eager",
        torch_dtype="auto",
        use_cache=False,
        quantization_config=quantization_config,
        **device_map_args,
    )
    print(f'Loaded model: {args.model_id}')

    if args.enable_lora:
        num_layers = 0
        target_parameters = []
        if args.model_id == 'openai/gpt-oss-120b':
            num_layers = 36
        elif args.model_id == 'openai/gpt-oss-20b':
            num_layers = 24
        for i in range(num_layers):
            target_parameters.append(f'{i}.mlp.experts.gate_up_proj')
            target_parameters.append(f'{i}.mlp.experts.down_proj')
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules="all-linear",
            target_parameters=target_parameters,
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Train model
    training_args = SFTConfig(
        output_dir=f"{model_id}-checkpoint",
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=1,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_length=1024,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={"min_lr_rate": 0.1},
        dataset_num_proc=num_proc,
    )

    # Train model with optional profiling
    trainer_kwargs = {
        'args': training_args,
        'model': model,
        'train_dataset': train_dataset,
    }

    if args.enable_profiling:
        with accelerator.profile() as prof:
            trainer_kwargs['accelerator_profiler'] = prof
            trainer = ProfilingSFTTrainer(**trainer_kwargs)
            trainer.train()
    else:
        trainer = ProfilingSFTTrainer(**trainer_kwargs)
        trainer.train()


if __name__ == "__main__":
    main()