Train image classifier

Bases: Bolt

__init__(input, output, state, **kwargs)

The TrainImageClassifier class trains an image classifier using a ResNet-152 model. It assumes that input.input_folder contains two sub-folders named 'train' and 'test', each holding one folder of images per class. The trained model is saved as 'model.pth' in output.output_folder.
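The folder layout described above can be checked before a training run is launched. The following is a minimal sketch, not part of geniusrise: the function name `class_folders` and the example class names are hypothetical, and it only verifies that 'train' and 'test' exist and lists the class folders found in each.

```python
from pathlib import Path

# Expected layout (class names "cats" and "dogs" are placeholders):
#   input_folder/
#     train/cats/..., train/dogs/...
#     test/cats/...,  test/dogs/...


def class_folders(input_folder):
    """Return the sorted class-folder names found under 'train' and 'test'."""
    root = Path(input_folder)
    layout = {}
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing sub-folder: {split_dir}")
        # Every directory directly under a split is treated as one class.
        layout[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    return layout
```

Comparing `layout["train"]` with `layout["test"]`, and the number of class folders with the `num_classes` argument, is a cheap way to catch a mismatch before starting a long job.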

Parameters:

Name      Type         Description                               Default
input     BatchInput   Instance of BatchInput for reading data.  required
output    BatchOutput  Instance of BatchOutput for saving data.  required
state     State        Instance of State for maintaining state.  required
**kwargs               Additional keyword arguments.             {}

Command Line Invocation with geniusrise

genius TrainImageClassifier rise \
    batch \
        --bucket my_bucket \
        --s3_folder s3/input \
    batch \
        --bucket my_bucket \
        --s3_folder s3/output \
    none \
    process \
        --args num_classes=4 epochs=10 batch_size=32 learning_rate=0.001
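When the command is launched from a script rather than typed by hand, it can help to assemble the argument list in code. The sketch below only reproduces the invocation shown above; `build_command` is a hypothetical helper, not a geniusrise API, and it assumes the same bucket is used for input and output as in the example.

```python
import shlex


def build_command(bucket, input_folder, output_folder, **process_args):
    """Build the argv list for the documented `genius ... rise` invocation."""
    argv = [
        "genius", "TrainImageClassifier", "rise",
        "batch", "--bucket", bucket, "--s3_folder", input_folder,   # input
        "batch", "--bucket", bucket, "--s3_folder", output_folder,  # output
        "none",                                                     # state
        "process", "--args",
    ]
    # Each keyword becomes one key=value token after --args.
    argv += [f"{key}={value}" for key, value in process_args.items()]
    return argv


cmd = build_command(
    "my_bucket", "s3/input", "s3/output",
    num_classes=4, epochs=10, batch_size=32, learning_rate=0.001,
)
print(shlex.join(cmd))  # single-line form of the command above
```

The resulting list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting issues.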

YAML Configuration with geniusrise

version: "1"
bolts:
    image_training:
        name: "TrainImageClassifier"
        method: "process"
        args:
            num_classes: 4
            epochs: 10
            batch_size: 32
            learning_rate: 0.001
        input:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/input"
        output:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/output"

process(num_classes=4, epochs=10, batch_size=32, learning_rate=0.001, use_cuda=False)

📖 Train an image classifier using a ResNet-152 model.

Parameters:

Name           Type   Description                              Default
num_classes    int    Number of image classes.                 4
epochs         int    Number of training epochs.               10
batch_size     int    Batch size for training.                 32
learning_rate  float  Learning rate for the optimizer.         0.001
use_cuda       bool   Whether to use CUDA for model training.  False

This method trains a ResNet-152 model using the images in the 'train' and 'test' sub-folders of input.input_folder. Each sub-folder should contain one folder of images per class, and the number of class folders should match num_classes. The trained model is saved as 'model.pth' in output.output_folder.
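To make the parameters concrete, here is a rough sketch of what such a training run could look like with PyTorch and torchvision. It is not the geniusrise implementation. The function names are hypothetical, and several details are assumptions the documentation does not state: images are resized to 224x224, the network starts from random weights, Adam with cross-entropy loss is used, the 'test' folder is used for evaluation after each epoch, and 'model.pth' holds the model's state dict.

```python
from pathlib import Path


def pick_device(use_cuda: bool, cuda_available: bool) -> str:
    """Use the GPU only when it is both requested and actually present."""
    return "cuda" if use_cuda and cuda_available else "cpu"


def train_sketch(input_folder, output_folder, num_classes=4, epochs=10,
                 batch_size=32, learning_rate=0.001, use_cuda=False):
    # Heavy dependencies are imported here so pick_device stays usable
    # on machines without PyTorch installed.
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, models, transforms

    device = torch.device(pick_device(use_cuda, torch.cuda.is_available()))

    # Assumption: fixed 224x224 inputs, no augmentation or normalization.
    tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    # ImageFolder maps each class sub-folder to one integer label.
    train_ds = datasets.ImageFolder(str(Path(input_folder) / "train"), transform=tfm)
    test_ds = datasets.ImageFolder(str(Path(input_folder) / "test"), transform=tfm)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=batch_size)

    # Assumption: randomly initialized ResNet-152 with a new final layer
    # sized to num_classes.
    model = models.resnet152(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # assumed optimizer

    for epoch in range(epochs):
        model.train()
        for images, labels in train_dl:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Assumption: the 'test' split is used to report accuracy per epoch.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in test_dl:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
        print(f"epoch {epoch + 1}: test accuracy {correct / max(total, 1):.3f}")

    # Assumption: only the weights (state dict) are written to model.pth.
    out = Path(output_folder)
    out.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out / "model.pth")
```

Under these assumptions, the saved weights could later be restored by building the same architecture and calling `load_state_dict(torch.load(...))`. If the actual file stores a full pickled model instead, loading would differ.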