
Classification

Bases: TextBulk

TextClassificationBulk is designed for bulk text classification with Hugging Face models. It processes large datasets using state-of-the-art models to classify text into predefined labels.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input` | `BatchInput` | Configuration and data inputs for the batch process. | required |
| `output` | `BatchOutput` | Configurations for output data handling. | required |
| `state` | `State` | State management for the classification task. | required |
| `**kwargs` | | Arbitrary keyword arguments for extended configurations. | `{}` |

Example CLI Usage:

```bash
genius TextClassificationBulk rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
    --id cardiffnlp/twitter-roberta-base-hate-multiclass-latest-lol \
    classify \
        --args \
            model_name="cardiffnlp/twitter-roberta-base-hate-multiclass-latest" \
            model_class="AutoModelForSequenceClassification" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            precision="bfloat16" \
            quantization=0 \
            device_map="auto" \
            max_memory=None \
            torchscript=False
```

__init__(input, output, state, **kwargs)

Initializes the TextClassificationBulk class with input, output, and state configurations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input` | `BatchInput` | Configuration for the input data. | required |
| `output` | `BatchOutput` | Configuration for the output data. | required |
| `state` | `State` | State management for the classification task. | required |
| `**kwargs` | | Additional keyword arguments for extended functionality. | `{}` |
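For programmatic use, the bolt can be constructed directly. The following is a minimal sketch: the import paths and the `BatchInput`/`BatchOutput`/`InMemoryState` constructor arguments are assumptions based on typical geniusrise batch components, not taken from this page.

```python
# Minimal sketch; import paths and constructor signatures are assumptions.
from geniusrise.core import BatchInput, BatchOutput, InMemoryState
from geniusrise_text import TextClassificationBulk

# Assumed signature: (local folder, S3 bucket, S3 prefix).
batch_input = BatchInput("./input", "my-bucket", "classification/in")
batch_output = BatchOutput("./output", "my-bucket", "classification/out")
state = InMemoryState()  # assumed in-memory State implementation

bolt = TextClassificationBulk(input=batch_input, output=batch_output, state=state)
```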

classify(model_name, model_class='AutoModelForSequenceClassification', tokenizer_class='AutoTokenizer', use_cuda=False, precision='float', quantization=0, device_map='auto', max_memory={0: '24GB'}, torchscript=False, compile=False, awq_enabled=False, flash_attention=False, batch_size=32, notification_email=None, **kwargs)

Perform bulk classification using the specified model and tokenizer. This method handles the entire classification process: loading the model, processing the input data, predicting labels, and saving the results. A usage sketch follows the parameter table.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name` | `str` | Name or path of the model. | required |
| `model_class` | `str` | Class name of the model. | `'AutoModelForSequenceClassification'` |
| `tokenizer_class` | `str` | Class name of the tokenizer. | `'AutoTokenizer'` |
| `use_cuda` | `bool` | Whether to use CUDA for model inference. | `False` |
| `precision` | `str` | Precision for model computation. | `'float'` |
| `quantization` | `int` | Level of quantization for optimizing model size and speed. | `0` |
| `device_map` | `str \| Dict \| None` | Specific device(s) to use for computation. | `'auto'` |
| `max_memory` | `Dict` | Maximum memory configuration for devices. | `{0: '24GB'}` |
| `torchscript` | `bool` | Whether to use a TorchScript-optimized version of the pre-trained language model. | `False` |
| `compile` | `bool` | Whether to compile the model before inference. | `False` |
| `awq_enabled` | `bool` | Whether to enable AWQ optimization. | `False` |
| `flash_attention` | `bool` | Whether to use flash attention optimization. | `False` |
| `batch_size` | `int` | Number of classifications to process simultaneously. | `32` |
| `notification_email` | `str` | Email address for completion notifications. | `None` |
| `**kwargs` | `Any` | Arbitrary keyword arguments for model and generation configurations. | `{}` |
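Continuing the sketch above, a `classify` call that mirrors the CLI example might look like the following. The parameter names and values come from this page; `bolt` is the hypothetical instance from the previous sketch.

```python
# Parameters mirror the CLI example above; `bolt` is from the earlier sketch.
bolt.classify(
    model_name="cardiffnlp/twitter-roberta-base-hate-multiclass-latest",
    model_class="AutoModelForSequenceClassification",
    tokenizer_class="AutoTokenizer",
    use_cuda=True,
    precision="bfloat16",
    quantization=0,
    device_map="auto",
    max_memory=None,
    torchscript=False,
    batch_size=32,
)
```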

load_dataset(dataset_path, max_length=512, **kwargs)

Load a classification dataset from a directory.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset_path` | `str` | The path to the dataset directory. | required |
| `max_length` | `int` | The maximum length for tokenization. | `512` |

Returns:

| Name | Type | Description |
|---|---|---|
| `Dataset` | `Optional[Dataset]` | The loaded dataset. |

Raises:

| Type | Description |
|---|---|
| `Exception` | If there was an error loading the dataset. |

Supported Data Formats and Structures:

JSONL

Each line is a JSON object representing an example.

{"text": "The text content"}

CSV

Should contain a 'text' column.

text
"The text content"

Parquet

Should contain a 'text' column.

JSON

An array of dictionaries with 'text' keys.

[{"text": "The text content"}]

XML

Each 'record' element should contain a 'text' child element.

<record>
    <text>The text content</text>
</record>

YAML

Each document should be a dictionary with a 'text' key.

- text: "The text content"

TSV

Should contain a 'text' column, with fields separated by tabs.

Excel (.xls, .xlsx)

Should contain a 'text' column.

SQLite (.db)

Should contain a table with a 'text' column.

Feather

Should contain a 'text' column.
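As a quick way to produce input in the JSONL format shown above, the snippet below writes a few records with the Python standard library; the trailing `load_dataset` call assumes the hypothetical `bolt` instance from the earlier sketches.

```python
import json
import os

# Write a small JSONL dataset in the format described above.
os.makedirs("./input", exist_ok=True)
records = [
    {"text": "I love this product!"},
    {"text": "This is the worst experience ever."},
]
with open("./input/examples.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Hypothetical: load it back through the bolt from the earlier sketches.
dataset = bolt.load_dataset("./input", max_length=512)
```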