Predict image classes

Bases: Bolt

__init__(input, output, state, **kwargs)

The ImageClassPredictor class classifies images using a pre-trained PyTorch model. It expects input.input_folder to contain sub-folders of images to classify. Classified images are saved in output.output_folder, organized into one folder per predicted label.

Parameters:

Name | Type | Description | Default
input | BatchInput | Instance of BatchInput for reading data. | required
output | BatchOutput | Instance of BatchOutput for saving data. | required
state | State | Instance of State for maintaining state. | required
**kwargs | | Additional keyword arguments. | {}

Command Line Invocation with geniusrise

genius ImageClassPredictor rise \
    batch \
        --bucket my_bucket \
        --s3_folder s3/input \
    batch \
        --bucket my_bucket \
        --s3_folder s3/output \
    none \
    predict \
        --args classes='{"0": "cat", "1": "dog"}' model_path=/path/to/model.pth

YAML Configuration with geniusrise

version: "1"
bolts:
    image_classification:
        name: "ImageClassPredictor"
        method: "predict"
        args:
            classes: '{"0": "cat", "1": "dog"}'
            model_path: "/path/to/model.pth"
        input:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/input"
        output:
            type: "batch"
            args:
                bucket: "my_bucket"
                s3_folder: "s3/output"

get_label(class_idx)

📖 Get the label corresponding to the class index.

Parameters:

Name | Type | Description | Default
class_idx | int | The class index. | required

Returns:

Name | Type | Description
str | str | The label corresponding to the class index.

This method returns the label that corresponds to a given class index based on the classes dictionary.
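
A minimal sketch of this lookup, assuming the classes dictionary is the parsed form of the JSON string shown in the invocation examples. JSON object keys are always strings, so an integer class index has to be converted before it is used as a key. The free functions `parse_classes` and `get_label` below are illustrative and are not taken from the library's source.

```python
import json


def parse_classes(classes_json: str) -> dict:
    """Parse the JSON string passed as the `classes` argument."""
    return json.loads(classes_json)


def get_label(classes: dict, class_idx: int) -> str:
    """Return the label for an integer class index.

    JSON object keys are strings, so the index is converted first.
    An index missing from the mapping raises KeyError.
    """
    return classes[str(class_idx)]


classes = parse_classes('{"0": "cat", "1": "dog"}')
print(get_label(classes, 1))  # dog
```

Whether the actual implementation converts the index this way or stores integer keys is not stated on this page, so treat the conversion as an assumption.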

predict(classes, model_path, use_cuda=False)

📖 Classify images in the input sub-folders using a pre-trained PyTorch model.

Parameters:

Name | Type | Description | Default
classes | str | JSON string mapping class indices to labels. | required
model_path | str | Path to the pre-trained PyTorch model. | required
use_cuda | bool | Whether to use CUDA for model inference. Default is False. | False

This method iterates through each image file in the specified sub-folders, applies the model, and classifies the image. The classified images are then saved in an output folder, organized by their predicted labels.
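
The flow described above can be sketched with the standard library alone by treating the model as a plain callable. Everything here is an assumption made for illustration: `classify_folder` is a hypothetical helper, `predict_index` stands in for the PyTorch forward pass plus argmax, the set of image suffixes is a guess, and files are copied rather than moved.

```python
import shutil
from pathlib import Path
from typing import Callable, Dict

# Assumption: which file suffixes count as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def classify_folder(
    input_folder: str,
    output_folder: str,
    classes: Dict[str, str],
    predict_index: Callable[[Path], int],
) -> Dict[str, int]:
    """Copy each image into output_folder/<label>/ and count images per label."""
    counts: Dict[str, int] = {}
    for sub_folder in sorted(Path(input_folder).iterdir()):
        if not sub_folder.is_dir():
            continue  # only sub-folders are scanned
        for image_path in sorted(sub_folder.iterdir()):
            if image_path.suffix.lower() not in IMAGE_SUFFIXES:
                continue  # skip non-image files
            # Map the predicted index to its label (JSON keys are strings).
            label = classes[str(predict_index(image_path))]
            target_dir = Path(output_folder) / label
            target_dir.mkdir(parents=True, exist_ok=True)
            # Note: files with the same name and label overwrite each other.
            shutil.copy(image_path, target_dir / image_path.name)
            counts[label] = counts.get(label, 0) + 1
    return counts
```

In real use, `predict_index` would wrap image loading, preprocessing, and the model call on the device selected by use_cuda. Those details are not specified on this page and are left out of the sketch.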