Base Fine Tuner¶
Bases: Bolt
An abstract base class for writing bolts for fine-tuning OpenAI models.
This base class is intended to be subclassed for fine-tuning OpenAI models. The chief objective of its subclasses is to load and preprocess the dataset, though of course, other methods, including fine-tuning, can be overridden for customization.
This bolt uses the OpenAI API to fine-tune a pre-trained model.
Each subclass can be invoked using the genius CLI or a YAML file.
Using genius cli¶
```bash
genius <bolt_name> rise \
    batch \
        --input_s3_bucket my-input-bucket \
        --input_s3_folder my-input-folder \
    batch \
        --output_s3_bucket my-output-bucket \
        --output_s3_folder my-output-folder \
    postgres \
        --postgres_host 127.0.0.1 \
        --postgres_port 5432 \
        --postgres_user postgres \
        --postgres_password postgres \
        --postgres_database geniusrise \
        --postgres_table task_state \
    fine_tune \
    --args \
        model=gpt-3.5-turbo \
        n_epochs=2 \
        batch_size=64 \
        learning_rate_multiplier=0.5 \
        prompt_loss_weight=1 \
        wait=True
```
This will load and preprocess data from the input S3 location, upload it to OpenAI for fine-tuning, and wait for the job to complete.
Using YAML¶
Bolts can also be invoked with the genius CLI using a YAML file.
Create a YAML file with the following content (it mirrors the CLI arguments):
```yaml
version: 1
bolts:
    my_fine_tuner:
        name: OpenAIClassificationFineTuner
        method: fine_tune
        args:
            model: gpt-3.5-turbo
            n_epochs: 2
            batch_size: 64
            learning_rate_multiplier: 0.5
            prompt_loss_weight: 1
            wait: True
        input:
            type: batch
            bucket: my-input-bucket
            folder: my-input-folder
        output:
            type: batch
            bucket: my-output-bucket
            folder: my-output-folder
        state:
            type: postgres
            host: 127.0.0.1
            port: 5432
            user: postgres
            password: postgres
            database: geniusrise
            table: state
```
Gotchas:
- Extra command-line arguments can be passed through the fine_tune method to the load_dataset method by prefixing the parameter name with `data_`, e.g. passing `data_some_arg=some_value` via `--args` forwards that argument to load_dataset.
__init__(input, output, state) ¶
Initialize the bolt.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | BatchInput | The batch input data. | required |
output | BatchOutput | The output data. | required |
state | State | The state manager. | required |
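As a rough sketch of constructing a concrete fine tuner directly in Python rather than through the CLI or YAML: the import paths, constructor arguments, and the InMemoryState class below are assumptions for illustration, not taken from this page.

```python
# Sketch only: import paths and constructor signatures are assumptions.
from geniusrise.core import BatchInput, BatchOutput, InMemoryState

# OpenAIClassificationFineTuner is the concrete subclass named in the YAML example;
# its import path is assumed here.
from geniusrise_openai import OpenAIClassificationFineTuner

batch_input = BatchInput("./input", "my-input-bucket", "my-input-folder")
batch_output = BatchOutput("./output", "my-output-bucket", "my-output-folder")
state = InMemoryState()  # a Postgres-backed state could be used instead, as in the CLI example

bolt = OpenAIClassificationFineTuner(input=batch_input, output=batch_output, state=state)
```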
delete_fine_tuned_model(model_id) staticmethod ¶
Delete a fine-tuned model.
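For instance, reusing the bolt instance from the sketch above (the model id below is illustrative; use the id of a model you actually fine-tuned):

```python
# The model id is illustrative; substitute a real fine-tuned model id.
bolt.delete_fine_tuned_model("ft:gpt-3.5-turbo:my-org::abc123")
```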
fine_tune(model, n_epochs, batch_size, learning_rate_multiplier, prompt_loss_weight, suffix=None, wait=False, data_extractor_lambda=None, **kwargs) ¶
Fine-tune the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | str | The pre-trained model name. | required |
suffix | str | The suffix to append to the model name. | None |
n_epochs | int | Total number of training epochs to perform. | required |
batch_size | int | Batch size during training. | required |
learning_rate_multiplier | float | Learning rate multiplier. | required |
prompt_loss_weight | int | Prompt loss weight. | required |
wait | bool | Whether to wait for the fine-tuning to complete. Defaults to False. | False |
data_extractor_lambda | str | A lambda function run on each data element to extract the actual data. | None |
**kwargs | | Additional keyword arguments for training and data loading. | {} |
Raises:
Type | Description |
---|---|
Exception | If any step in the fine-tuning process fails. |
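For example, using the bolt instance from the construction sketch above, a direct call mirroring the CLI invocation could look like this (the values are the same illustrative ones used earlier):

```python
# Arguments follow the documented signature; values are illustrative.
bolt.fine_tune(
    model="gpt-3.5-turbo",
    n_epochs=2,
    batch_size=64,
    learning_rate_multiplier=0.5,
    prompt_loss_weight=1,
    wait=True,
)
```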
get_fine_tuning_job(job_id) staticmethod ¶
Get the status of a fine-tuning job.
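A quick status check might look like this; the job id is illustrative, and the fields of the returned object follow the OpenAI fine-tuning job schema:

```python
# The job id is illustrative; use the id of the job created by fine_tune.
job = bolt.get_fine_tuning_job("ft-abc123")
print(job)  # inspect status and other fields returned by the OpenAI API
```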
load_dataset(dataset_path, **kwargs) abstractmethod ¶
Load a dataset from a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_path | str | The path to the dataset file. | required |
**kwargs | | Additional keyword arguments for loading the dataset. | {} |
Returns:
Name | Type | Description |
---|---|---|
Dataset | Union[Dataset, DatasetDict, Optional[Dataset]] | The loaded dataset. |
Raises:
Type | Description |
---|---|
NotImplementedError | This method should be overridden by subclasses. |
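A minimal sketch of a subclass implementation, assuming the data is stored as JSONL files and loaded with the Hugging Face datasets library; the class names and file layout are illustrative, not part of this API:

```python
# Sketch only: OpenAIFineTuner stands for the base class documented on this page,
# and the JSONL layout is an assumption about the data.
from typing import Optional

from datasets import Dataset, load_dataset as hf_load_dataset


class JsonlFineTuner(OpenAIFineTuner):
    def load_dataset(self, dataset_path: str, **kwargs) -> Optional[Dataset]:
        # Load every JSONL file under the given path into a single training split.
        return hf_load_dataset("json", data_files=f"{dataset_path}/*.jsonl", split="train", **kwargs)
```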
prepare_fine_tuning_data(data, data_type) ¶
Prepare the given data for fine-tuning.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | Union[Dataset, DatasetDict, Optional[Dataset]] | The dataset to prepare. | required |
data_type | str | Either 'train' or 'eval' to specify the type of data. | required |
Raises:
Type | Description |
---|---|
ValueError | If data_type is not 'train' or 'eval'. |
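For example, a preprocessing step might hand each loaded split to this method; the train_ds and eval_ds variables are illustrative datasets returned by load_dataset:

```python
# data_type must be either 'train' or 'eval'; anything else raises ValueError.
bolt.prepare_fine_tuning_data(train_ds, "train")
bolt.prepare_fine_tuning_data(eval_ds, "eval")
```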
preprocess_data(**kwargs) ¶
Load and preprocess the dataset.
Raises:
Type | Description |
---|---|
Exception | If any step in the preprocessing fails. |
wait_for_fine_tuning(job_id, check_interval=60) ¶
Wait for a fine-tuning job to complete, checking the status every check_interval seconds.
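For instance, if fine_tune was started with wait=False, the job can be polled afterwards; the job id and interval below are illustrative:

```python
# Block until the job finishes, checking every two minutes (job id is illustrative).
bolt.wait_for_fine_tuning(job_id="ft-abc123", check_interval=120)
```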