Fine-tune pix2struct

Bases: Bolt

__init__(input, output, state, model_name='google/pix2struct-large', **kwargs)

The FineTunePix2Struct class fine-tunes the Pix2Struct model on a custom OCR dataset. It supports three OCR dataset formats: COCO, ICDAR, and SynthText.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | BatchInput | An instance of the BatchInput class for reading the data. | required |
| output | BatchOutput | An instance of the BatchOutput class for saving the data. | required |
| state | State | An instance of the State class for maintaining the state. | required |
| model_name | str | The name of the Pix2Struct model to use. | 'google/pix2struct-large' |
| **kwargs | | Additional keyword arguments. | {} |

Dataset Formats
  • COCO: Assumes a folder structure with an 'annotations.json' file containing image and text annotations.
  • ICDAR: Assumes a folder structure with 'Images' and 'Annotations' folders containing image files and XML annotation files respectively.
  • SynthText: Assumes a folder with image files and corresponding '.txt' files containing ground truth text.
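The three layouts above can be told apart with a quick pre-flight check before launching a training run. The following is a minimal sketch, not part of the class itself: the function name `detect_dataset_format`, the set of image extensions, and the order in which the layouts are checked are all assumptions made for illustration.

```python
from pathlib import Path

# Assumed image extensions; the documentation does not list the accepted ones.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def detect_dataset_format(root):
    """Guess which documented layout a folder follows, or return None."""
    root = Path(root)
    if not root.is_dir():
        return None

    # COCO: an 'annotations.json' file at the top level of the folder.
    if (root / "annotations.json").is_file():
        return "coco"

    # ICDAR: 'Images' and 'Annotations' subfolders.
    if (root / "Images").is_dir() and (root / "Annotations").is_dir():
        return "icdar"

    # SynthText: every image has a '.txt' file with the same stem next to it.
    images = [p for p in root.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES]
    if images and all(p.with_suffix(".txt").is_file() for p in images):
        return "synthtext"

    return None
```

The returned string matches the values accepted by the `dataset_format` argument, so a mismatch between the folder and the configured format can be caught early.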

Using geniusrise to invoke via command line

genius FineTunePix2Struct rise \
    batch \
        --bucket my_bucket \
        --s3_folder s3/input \
    batch \
        --bucket my_bucket \
        --s3_folder s3/output \
    none \
    process \
        --args epochs=3 batch_size=32 learning_rate=0.001 dataset_format=coco use_cuda=true
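When the command is generated from a script, the `--args` value can be assembled from a dictionary of hyperparameters. This is a small sketch under the assumption that arguments are passed as space-separated `key=value` tokens with lowercase booleans, as in the command above; `format_process_args` is a hypothetical helper, not part of geniusrise.

```python
def format_process_args(**params):
    """Render keyword arguments as space-separated key=value tokens."""
    tokens = []
    for key, value in params.items():
        # Lowercase booleans to match the `use_cuda=true` style shown above.
        if isinstance(value, bool):
            value = str(value).lower()
        tokens.append(f"{key}={value}")
    return " ".join(tokens)


args = format_process_args(
    epochs=3,
    batch_size=32,
    learning_rate=0.001,
    dataset_format="coco",
    use_cuda=True,
)
print(args)  # epochs=3 batch_size=32 learning_rate=0.001 dataset_format=coco use_cuda=true
```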

Using geniusrise to invoke via YAML file

version: "1"
bolts:
    fine_tune_pix2struct:
        name: "FineTunePix2Struct"
        method: "process"
        args:
            epochs: 3
            batch_size: 32
            learning_rate: 0.001
            dataset_format: coco
            use_cuda: true
        input:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/input"
        output:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/output"

process(epochs, batch_size, learning_rate, dataset_format, use_cuda=False)

📖 Fine-tune the Pix2Struct model on a custom OCR dataset.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| epochs | int | Number of training epochs. | required |
| batch_size | int | Batch size for training. | required |
| learning_rate | float | Learning rate for the optimizer. | required |
| dataset_format | str | Format of the OCR dataset. Supported formats are "coco", "icdar", and "synthtext". | required |
| use_cuda | bool | Whether to use CUDA for training. | False |

This method fine-tunes the Pix2Struct model on the images and annotations in the input dataset, parsed according to dataset_format. The fine-tuned model is saved to the specified output path.
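To get a rough sense of how `epochs` and `batch_size` interact, the number of optimizer steps can be estimated from the dataset size. The sketch below assumes one optimizer step per batch, no gradient accumulation, and that a final partial batch is kept; the documentation does not state how batches are actually formed, so treat the result as an estimate. `total_optimizer_steps` is a hypothetical helper.

```python
import math


def total_optimizer_steps(num_examples, epochs, batch_size):
    """Estimate optimizer steps, assuming one step per batch and no accumulation."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if epochs < 0 or num_examples < 0:
        raise ValueError("epochs and num_examples must be non-negative")
    # A final partial batch still counts as one step (assumption).
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return epochs * steps_per_epoch


# 1000 examples with batch_size=32 gives 32 batches per epoch, so 3 epochs is 96 steps.
print(total_optimizer_steps(num_examples=1000, epochs=3, batch_size=32))  # 96
```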