Fine-tune Pix2Struct
Bases: Bolt
__init__(input, output, state, model_name='google/pix2struct-large', **kwargs)
The FineTunePix2Struct class fine-tunes the Pix2Struct model on a custom OCR dataset.
It supports three popular OCR dataset formats: COCO, ICDAR, and SynthText.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | BatchInput | An instance of the BatchInput class for reading the data. | required |
output | BatchOutput | An instance of the BatchOutput class for saving the data. | required |
state | State | An instance of the State class for maintaining the state. | required |
model_name | str | The name of the Pix2Struct model to use. | 'google/pix2struct-large' |
**kwargs | | Additional keyword arguments. | {} |
Dataset Formats
- COCO: Assumes a folder structure with an 'annotations.json' file containing image and text annotations.
- ICDAR: Assumes a folder structure with 'Images' and 'Annotations' folders containing image files and XML annotation files respectively.
- SynthText: Assumes a folder with image files and corresponding '.txt' files containing ground truth text.
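Before launching a run, it can help to confirm that a dataset folder matches one of the layouts above. The sketch below is a hypothetical pre-flight check, not part of geniusrise or FineTunePix2Struct. It assumes that each ICDAR XML file and each SynthText '.txt' file shares its image's file stem (for example 'img_1.jpg' and 'img_1.xml'), and that images use one of the listed extensions; neither detail is stated in this documentation.

```python
from pathlib import Path
from typing import List

# Assumed image extensions; the documentation does not list them.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def _images_in(folder: Path) -> List[Path]:
    """Image files directly inside folder, sorted by name."""
    return sorted(
        p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def find_missing(root: str, dataset_format: str) -> List[Path]:
    """Hypothetical pre-flight check: expected files that are missing for a dataset layout."""
    base = Path(root)
    fmt = dataset_format.lower()
    if fmt == "coco":
        # COCO: one annotations.json file holds the image and text annotations.
        ann = base / "annotations.json"
        return [] if ann.is_file() else [ann]
    if fmt == "icdar":
        # ICDAR: Images/ and Annotations/ folders; one XML per image, assumed to share its stem.
        images, annotations = base / "Images", base / "Annotations"
        absent = [d for d in (images, annotations) if not d.is_dir()]
        if absent:
            return absent
        expected = [annotations / f"{img.stem}.xml" for img in _images_in(images)]
        return [p for p in expected if not p.is_file()]
    if fmt == "synthtext":
        # SynthText: each image has a .txt ground-truth file beside it, assumed to share its stem.
        expected = [img.with_suffix(".txt") for img in _images_in(base)]
        return [p for p in expected if not p.is_file()]
    raise ValueError(f"Unsupported dataset_format: {dataset_format!r}")
```

An empty list means every expected file was found; anything else lists the paths to fix before training.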
Using geniusrise to invoke via the command line
genius FineTunePix2Struct rise \
    batch \
        --bucket my_bucket \
        --s3_folder s3/input \
    batch \
        --bucket my_bucket \
        --s3_folder s3/output \
    none \
    process \
        --args epochs=3 batch_size=32 learning_rate=0.001 dataset_format=coco use_cuda=true
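The values after --args in the command above are key=value strings, while process() expects typed arguments (int, float, str, bool). The following is a minimal, hypothetical sketch of that conversion using the types from the process() parameter table; it is not geniusrise's own argument parser, and the accepted boolean spellings are assumptions.

```python
from typing import Any, Dict, List

# Types copied from the process() parameter table; unknown keys stay strings.
PROCESS_ARG_TYPES = {
    "epochs": int,
    "batch_size": int,
    "learning_rate": float,
    "dataset_format": str,
    "use_cuda": bool,
}


def _to_bool(raw: str) -> bool:
    """Convert common boolean spellings (an assumption) to True/False."""
    lowered = raw.strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise ValueError(f"Not a boolean: {raw!r}")


def parse_process_args(pairs: List[str]) -> Dict[str, Any]:
    """Hypothetical parser: turn ['epochs=3', 'use_cuda=true'] into typed keyword arguments."""
    kwargs: Dict[str, Any] = {}
    for pair in pairs:
        key, sep, raw = pair.partition("=")
        if not sep or not key:
            raise ValueError(f"Expected key=value, got {pair!r}")
        cast = PROCESS_ARG_TYPES.get(key, str)
        # bool("false") is True in Python, so booleans need explicit parsing.
        kwargs[key] = _to_bool(raw) if cast is bool else cast(raw)
    return kwargs
```

For example, parse_process_args("epochs=3 batch_size=32 learning_rate=0.001 dataset_format=coco use_cuda=true".split()) returns {'epochs': 3, 'batch_size': 32, 'learning_rate': 0.001, 'dataset_format': 'coco', 'use_cuda': True}. The explicit boolean branch matters because a naive bool() cast would turn use_cuda=false into True.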
Using geniusrise to invoke via a YAML file
version: "1"
bolts:
    fine_tune_pix2struct:
        name: "FineTunePix2Struct"
        method: "process"
        args:
            epochs: 3
            batch_size: 32
            learning_rate: 0.001
            dataset_format: coco
            use_cuda: true
        input:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/input"
        output:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/output"
process(epochs, batch_size, learning_rate, dataset_format, use_cuda=False)
📖 Fine-tune the Pix2Struct model on a custom OCR dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epochs | int | Number of training epochs. | required |
batch_size | int | Batch size for training. | required |
learning_rate | float | Learning rate for the optimizer. | required |
dataset_format | str | Format of the OCR dataset. Supported formats are "coco", "icdar", and "synthtext". | required |
use_cuda | bool | Whether to use CUDA for training. | False |
This method fine-tunes the Pix2Struct model on the images and annotations read from the input, parsing them according to dataset_format.
The fine-tuned model is saved to the location configured by the BatchOutput instance.
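To estimate how long a run with the example settings (epochs=3, batch_size=32) will take, it helps to count optimizer updates. The sketch below is a minimal illustration, assuming one optimizer update per batch, no gradient accumulation, and that a final partial batch is kept; FineTunePix2Struct may batch differently, and the 1,000-example dataset is a made-up figure.

```python
def optimizer_steps(num_examples: int, epochs: int, batch_size: int) -> int:
    """Total optimizer updates, assuming one update per batch and a kept partial last batch."""
    if num_examples < 0 or epochs < 1 or batch_size < 1:
        raise ValueError("need num_examples >= 0, epochs >= 1, batch_size >= 1")
    steps_per_epoch = -(-num_examples // batch_size)  # ceiling division without floats
    return epochs * steps_per_epoch


# Hypothetical dataset of 1,000 examples with epochs=3, batch_size=32:
# ceil(1000 / 32) = 32 steps per epoch, so 3 * 32 = 96 steps in total.
print(optimizer_steps(1000, epochs=3, batch_size=32))  # 96
```

Doubling batch_size roughly halves the number of updates per epoch, which is one reason learning_rate is often retuned when batch_size changes.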