Update Script: update_prepare.py (Optional Read)

The scripts/update_prepare.py script prepares training data from collected trajectories and generates a LLaMA-Factory training configuration for fine-tuning the model.

Overview

This script:

  • Loads previously collected trajectories from disk and builds a replay buffer from them

  • Converts successful trajectories to LLaMA-Factory ShareGPT format

  • Applies recency bias for sample weighting

  • Creates dataset files and training configuration

  • Sets up WandB logging for training runs

The script is invoked by run.sh during the update phase.

Entry Point

python scripts/update_prepare.py [hydra_overrides...]

The script uses the Hydra config name update by default. In practice, run.sh always overrides this with --config-name update_online to use scripts/config/main/update_online.yaml.

Note

The script uses Hydra for configuration. Its @hydra.main decorator defaults to config_name="update", but only update_online.yaml exists in the config directory. Running without --config-name update_online will fail. Always use run.sh or pass the flag explicitly:

# This will fail (no update.yaml exists):
python scripts/update_prepare.py

# This works:
python scripts/update_prepare.py --config-name update_online save_path=/data/exp1 data_path=/data/shared

Key Components

DataPreparationManager

The main class that handles data preparation:

data_manager = DataPreparationManager(
    agent=agent,
    save_path=config.save_path,
    algorithm_config=config.algorithm_config
)

Key Methods:

  • prepare_all_data(): Samples and converts trajectories to training format

  • create_llamafactory_config(): Generates YAML config for LLaMA-Factory

  • _convert_to_llamafactory_format(): Converts single samples to ShareGPT format

Data Preparation Pipeline

1. Load Trajectories

Trajectories are loaded from <save_path>/train_trajectories/:

train_trajectories = load_all_trajectories(
    base_dir=config.save_path,
    split='train',
    last_n_iterations=4  # Only load the 4 most recent iterations, to limit memory use
)

2. Clean Trajectories

Invalid trajectories are filtered out:

  • Empty trajectories

  • Trajectories with fewer than 2 steps

  • Trajectories with None values in action/observation/response
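The cleaning rules above amount to a simple predicate over each trajectory. This is a minimal sketch that assumes a trajectory is a list of step dicts with action, observation, and response keys; the helper names is_valid_trajectory and clean_trajectories are hypothetical, and the script's real data types may differ.

```python
from typing import Any, Dict, List

Step = Dict[str, Any]


# Hypothetical helper: the real trajectory type and field names may differ.
def is_valid_trajectory(trajectory: List[Step], min_steps: int = 2) -> bool:
    """Return True if a trajectory passes the three cleaning rules."""
    # Rules 1 and 2: reject empty trajectories and those with fewer than min_steps steps.
    if not trajectory or len(trajectory) < min_steps:
        return False
    # Rule 3: reject trajectories where any step has a None (or missing)
    # action, observation, or response.
    for step in trajectory:
        for key in ("action", "observation", "response"):
            if step.get(key) is None:
                return False
    return True


def clean_trajectories(trajectories: List[List[Step]]) -> List[List[Step]]:
    """Keep only the trajectories that pass is_valid_trajectory."""
    return [t for t in trajectories if is_valid_trajectory(t)]


if __name__ == "__main__":
    good = [
        {"action": "click(1)", "observation": "obs0", "response": "r0"},
        {"action": "stop()", "observation": "obs1", "response": "r1"},
    ]
    too_short = [good[0]]
    has_none = [good[0], {"action": None, "observation": "obs1", "response": "r1"}]
    print(len(clean_trajectories([good, [], too_short, has_none])))  # 1
```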

3. Create Replay Buffer

The ReplayBuffer handles trajectory sampling:

replay_buffer = ReplayBuffer(
    trajectories=train_trajectories,
    agent=agent,
    filter_successful_only=True,
    filter_same_screenshot=True
)

4. Sample Training Data

Positive samples are drawn with recency bias: trajectories from more recent iterations are sampled more often. For example, with recency_bias_power=2 and iterations 0-3 available, the sampling weights are proportional to (iteration + 1)^2. The weights sum to 1 + 4 + 9 + 16 = 30, so, assuming each iteration contributes the same number of samples:

Iteration 0: weight 1   → ~3% of samples
Iteration 1: weight 4   → ~13% of samples
Iteration 2: weight 9   → ~30% of samples
Iteration 3: weight 16  → ~53% of samples

positive_samples = replay_buffer.get_training_samples(
    num_samples=positive_samples_to_train,
    recency_bias_power=recency_bias_power
)
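The weights above can be reproduced with a few lines of Python. This is a minimal sketch, not the ReplayBuffer implementation: it assumes each sample dict carries a hypothetical iteration field and draws with replacement via random.choices, whereas the real get_training_samples may sample without replacement or normalize per iteration.

```python
import random
from collections import Counter
from typing import Dict, Iterable, List


def recency_weights(iterations: Iterable[int], power: float) -> Dict[int, float]:
    """Normalized per-iteration weights proportional to (iteration + 1) ** power."""
    raw = {it: (it + 1) ** power for it in iterations}
    total = sum(raw.values())
    return {it: w / total for it, w in raw.items()}


def sample_with_recency_bias(samples: List[dict], num_samples: int, power: float,
                             seed: int = 0) -> List[dict]:
    """Draw num_samples items with replacement, weighting each by (iteration + 1) ** power.

    Assumes a hypothetical "iteration" key on each sample.
    """
    rng = random.Random(seed)
    weights = [(s["iteration"] + 1) ** power for s in samples]
    return rng.choices(samples, weights=weights, k=num_samples)


if __name__ == "__main__":
    for it, share in recency_weights(range(4), power=2).items():
        print(f"iteration {it}: {share:.0%}")  # 3%, 13%, 30%, 53%

    # Equal-sized pool per iteration, so empirical shares approach the weights above.
    pool = [{"iteration": it, "id": i} for it in range(4) for i in range(100)]
    drawn = sample_with_recency_bias(pool, num_samples=1800, power=2)
    print(sorted(Counter(s["iteration"] for s in drawn).items()))
```

Raising recency_bias_power concentrates samples further on the newest iteration; a power of 0 gives every iteration equal weight.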

5. Convert to ShareGPT Format

Each sample is converted to LLaMA-Factory’s ShareGPT format:

{
  "conversations": [
    {"from": "human", "value": "<image>Task: ..."},
    {"from": "gpt", "value": "Action: click(...)"}
  ],
  "system": "You are a web agent...",
  "images": ["/path/to/screenshot.png"]
}
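A single-turn version of this conversion can be sketched as below. The input field names (prompt, response, screenshot_path) and the helper to_sharegpt are hypothetical; the real _convert_to_llamafactory_format may also include interaction history, which mask_history then excludes from the loss. The sketch does enforce one LLaMA-Factory constraint: the number of <image> placeholders in the text must match the number of entries in images.

```python
import json
from typing import Any, Dict


def to_sharegpt(sample: Dict[str, Any], system_prompt: str) -> Dict[str, Any]:
    """Build one single-turn ShareGPT record in the layout shown above.

    The keys read from `sample` are hypothetical placeholders.
    """
    prompt = sample["prompt"]
    # One <image> placeholder is needed per entry in "images".
    if "<image>" not in prompt:
        prompt = "<image>" + prompt
    return {
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": sample["response"]},
        ],
        "system": system_prompt,
        "images": [sample["screenshot_path"]],
    }


if __name__ == "__main__":
    record = to_sharegpt(
        {"prompt": "Task: find the pricing page",
         "response": "Action: click(12)",
         "screenshot_path": "/path/to/screenshot.png"},
        system_prompt="You are a web agent...",
    )
    print(json.dumps(record, indent=2))
```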

LLaMA-Factory Configuration

The script generates a complete training configuration:

Training Hyperparameters:

stage: sft
do_train: true
finetuning_type: full
mask_history: true  # Only train on last turn
cutoff_len: 16384
per_device_train_batch_size: 3
gradient_accumulation_steps: 4
learning_rate: 1e-6
num_train_epochs: 2
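The batch settings determine how many samples each optimizer step consumes: per_device_train_batch_size × gradient_accumulation_steps × number of GPUs. A rough sketch of the arithmetic, assuming 8 GPUs and 1,710 training samples (1,800 positives minus a 5% validation split); the GPU count is an assumption, and the Hugging Face Trainer's exact rounding of steps per epoch can differ slightly.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Samples consumed by one optimizer step across all GPUs."""
    return per_device * grad_accum * num_gpus


def approx_total_steps(num_train_samples: int, per_device: int, grad_accum: int,
                       num_gpus: int, epochs: int) -> int:
    """Approximate optimizer steps for a run; the Trainer's own rounding may differ by a step."""
    per_step = effective_batch_size(per_device, grad_accum, num_gpus)
    return math.ceil(num_train_samples / per_step) * epochs


if __name__ == "__main__":
    # Assumed setup: 8 GPUs, 1710 training samples, values from the config above.
    print(effective_batch_size(3, 4, num_gpus=8))                # 96
    print(approx_total_steps(1710, 3, 4, num_gpus=8, epochs=2))  # 36
```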

Checkpoint Resume:

The script automatically detects and resumes from checkpoints:

  • Checks for trainer_state.json and optimizer states

  • Updates scheduler learning rate if config changed

  • Adjusts max_steps for dataset size changes between iterations
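Checkpoint detection can be sketched using the files the Hugging Face Trainer writes. trainer_state.json does record global_step; everything else here is an assumption. The helper names are hypothetical. The sketch checks only for optimizer.pt, as in the directory tree under Next Steps; DeepSpeed runs typically store optimizer state in a separate global_step* folder instead. The max_steps policy of "saved global_step plus this iteration's steps" is one plausible reading of "adjusts max_steps", not the script's documented behavior.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def find_resumable_checkpoint(checkpoints_dir: Path) -> Optional[Path]:
    """Newest subdirectory holding both trainer_state.json and optimizer.pt, or None."""
    if not checkpoints_dir.is_dir():
        return None
    candidates = [
        d for d in checkpoints_dir.iterdir()
        if d.is_dir()
        and (d / "trainer_state.json").is_file()
        and (d / "optimizer.pt").is_file()
    ]
    # Names like model_20250115_143022 sort chronologically as plain strings.
    return max(candidates, key=lambda d: d.name, default=None)


def resumed_max_steps(checkpoint: Path, steps_this_iteration: int) -> int:
    """Hypothetical policy: continue from the saved global_step and add this iteration's steps."""
    state = json.loads((checkpoint / "trainer_state.json").read_text())
    return int(state["global_step"]) + steps_this_iteration


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "checkpoints" / "model_20250115_143022"
        ckpt.mkdir(parents=True)
        (ckpt / "trainer_state.json").write_text(json.dumps({"global_step": 36}))
        (ckpt / "optimizer.pt").write_bytes(b"")
        found = find_resumable_checkpoint(ckpt.parent)
        print(found.name, resumed_max_steps(found, 36))  # model_20250115_143022 72
```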

WandB Integration:

report_to: wandb
run_name: webgym-<your-run-name>

Environment variables are written to wandb_env.sh for run.sh to source. For example:

# Contents of wandb_env.sh (generated by update_prepare.py):
export WANDB_PROJECT='rl'
export WANDB_ENTITY='your-entity-name'
export WANDB_RUN_NAME='webgym-your-run-name'
export WANDB_RESUME='allow'
export WANDB_RUN_ID='existing-run-id'  # Only if resuming
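A file like the one above can be generated with the standard library. This is a minimal sketch with hypothetical function names. It uses shlex.quote, which adds quotes only when a value needs them, so simple values come out unquoted (export WANDB_PROJECT=rl) while remaining safe to source.

```python
import shlex
from pathlib import Path
from typing import Optional


def render_wandb_env(project: str, entity: str, run_name: str,
                     run_id: Optional[str] = None) -> str:
    """Render export lines with the same variables as the wandb_env.sh example above."""
    env = {
        "WANDB_PROJECT": project,
        "WANDB_ENTITY": entity,
        "WANDB_RUN_NAME": run_name,
        "WANDB_RESUME": "allow",
    }
    if run_id is not None:  # only set when resuming an existing run
        env["WANDB_RUN_ID"] = run_id
    # shlex.quote keeps values safe to `source` even if they contain spaces or quotes.
    return "".join(f"export {key}={shlex.quote(value)}\n" for key, value in env.items())


def write_wandb_env(path: Path, **kwargs: Optional[str]) -> None:
    """Write the rendered exports to `path` (e.g. llamafactory_data/wandb_env.sh)."""
    path.write_text(render_wandb_env(**kwargs))


if __name__ == "__main__":
    print(render_wandb_env("rl", "your-entity-name", "webgym-your-run-name"), end="")
```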

Configuration

Key configuration options in update_online.yaml:

Log Config:

log_config:
  run_name: 'webgym-<your-run-name>'
  wandb_key_env_var: "WANDB_API_KEY"
  entity_name: "<your-wandb-entity-name>"

Algorithm Config:

algorithm_config:
  model_output_name: "model.pt"
  positive_samples_to_train: 1800
  recency_bias_power: 2
  val_split_ratio: 0.05

  # Training hyperparameters
  cutoff_len: 16384
  per_device_train_batch_size: 3
  per_device_eval_batch_size: 3
  gradient_accumulation_steps: 4
  learning_rate: 1e-6
  max_grad_norm: 1.0
  weight_decay: 0.01
  num_train_epochs: 2
  warmup_steps: 30
  lr_scheduler_type: "constant_with_warmup"
  logging_steps: 1
  bf16: True

  # Evaluation
  do_eval: false
  eval_strategy: "epoch"

  # Save settings
  save_strategy: "steps"
  save_steps: 999999  # larger than any run, so no intermediate step checkpoints are written
  save_total_limit: 1
  save_only_model: False

  # Data loading
  preprocessing_num_workers: 16
  dataloader_num_workers: 2
  dataloader_pin_memory: True
  remove_unused_columns: False
  min_token_length: 10

  # Other
  gradient_checkpointing: False
  plot_loss: False
  deepspeed_config_filename: "ds_config_b200_zero1.json"
  report_to: "wandb"
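val_split_ratio controls how many prepared records go to the validation file. A minimal sketch, assuming a seeded shuffle followed by a floor-rounded split; the helper train_val_split is hypothetical, and the script's actual seeding and rounding are not documented here. With 1,800 records and a ratio of 0.05, this holds out 90 records.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_val_split(records: Sequence[T], val_split_ratio: float,
                    seed: int = 0) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of records and hold out int(len * ratio) of them for validation."""
    if not 0.0 <= val_split_ratio < 1.0:
        raise ValueError("val_split_ratio must be in [0, 1)")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_split_ratio)  # floor rounding (an assumption)
    return shuffled[n_val:], shuffled[:n_val]


if __name__ == "__main__":
    train, val = train_val_split(list(range(1800)), val_split_ratio=0.05)
    print(len(train), len(val))  # 1710 90
```

A ratio of 0 yields an empty validation list, which matches finetune_val.json being produced only when val_split_ratio > 0.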

Output

The script generates files in <save_path>/llamafactory_data/:

  • finetune_train.json: Training dataset in ShareGPT format

  • finetune_val.json: Validation dataset (if val_split_ratio > 0)

  • dataset_info.json: Dataset registry for LLaMA-Factory

  • train_config.yaml: Complete training configuration

  • wandb_env.sh: WandB environment variables
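LLaMA-Factory locates ShareGPT files through dataset_info.json entries that map the record keys shown earlier (conversations, system, images) and the role tags (from, value, human, gpt). A sketch of building such a registry follows. The dataset names finetune_train and finetune_val are assumptions; train_config.yaml would then refer to them through LLaMA-Factory's dataset, eval_dataset, and dataset_dir options.

```python
import json
from typing import Dict


def sharegpt_dataset_entry(file_name: str) -> Dict[str, object]:
    """One dataset_info.json entry for a ShareGPT file with system prompts and images."""
    return {
        "file_name": file_name,
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "system": "system", "images": "images"},
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
        },
    }


def build_dataset_info(include_val: bool) -> Dict[str, object]:
    """Registry with a training entry and, optionally, a validation entry (names hypothetical)."""
    info: Dict[str, object] = {"finetune_train": sharegpt_dataset_entry("finetune_train.json")}
    if include_val:
        info["finetune_val"] = sharegpt_dataset_entry("finetune_val.json")
    return info


if __name__ == "__main__":
    print(json.dumps(build_dataset_info(include_val=True), indent=2))
```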

Next Steps

After update_prepare.py completes, run.sh executes:

# Source WandB environment
source llamafactory_data/wandb_env.sh

# Run LLaMA-Factory training
llamafactory-cli train llamafactory_data/train_config.yaml

The trained model is first saved under <save_path>/checkpoints/; the final checkpoint is then copied to <save_path>/model.pt/ (a directory, despite the .pt suffix). For example:

/data/exp1/
├── checkpoints/
│   └── model_20250115_143022/    # Timestamped checkpoint from LLaMA-Factory
│       ├── config.json
│       ├── model.safetensors
│       ├── trainer_state.json     # Used for checkpoint resume detection
│       └── optimizer.pt
└── model.pt/                      # Final copy used by vLLM for next iteration
    ├── config.json
    └── model.safetensors
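The copy step can be sketched with shutil under stated assumptions. The helper name publish_latest_checkpoint is hypothetical, the newest checkpoint is chosen by its timestamped name, and only optimizer.pt and trainer_state.json are skipped so that model.pt/ matches the example tree. The real script may copy a different file set (for example, tokenizer files).

```python
import shutil
import tempfile
from pathlib import Path


def publish_latest_checkpoint(save_path: Path) -> Path:
    """Copy the newest checkpoint into <save_path>/model.pt/, skipping optimizer/trainer state."""
    checkpoints = sorted(d for d in (save_path / "checkpoints").iterdir() if d.is_dir())
    if not checkpoints:
        raise FileNotFoundError("no checkpoints to publish")
    latest = checkpoints[-1]  # timestamped names sort chronologically
    target = save_path / "model.pt"
    if target.exists():
        shutil.rmtree(target)  # replace the previous iteration's weights
    shutil.copytree(latest, target,
                    ignore=shutil.ignore_patterns("optimizer.pt", "trainer_state.json"))
    return target


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "checkpoints" / "model_20250115_143022"
        ckpt.mkdir(parents=True)
        for name in ("config.json", "model.safetensors", "trainer_state.json", "optimizer.pt"):
            (ckpt / name).write_text("placeholder")
        out = publish_latest_checkpoint(Path(tmp))
        print(sorted(p.name for p in out.iterdir()))  # ['config.json', 'model.safetensors']
```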