Language Model

Bases: TextAPI

LanguageModelAPI is a class for interacting with pre-trained language models to generate text. It serves customizable text generation through a CherryPy web server, handling requests and producing responses with the specified language model. The class is part of the GeniusRise ecosystem for NLP tasks.

Attributes:

    model (Any): The loaded language model used for text generation.
    tokenizer (Any): The tokenizer corresponding to the language model, used for processing input text.

Methods

complete(**kwargs: Any) -> Dict[str, Any]: Generates text based on provided prompts and model parameters.

CLI Usage Example:

genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
    --id mistralai/Mistral-7B-v0.1 \
    listen \
        --args \
            model_name="mistralai/Mistral-7B-v0.1" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            precision="float16" \
            quantization=0 \
            device_map="auto" \
            max_memory=None \
            torchscript=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"

or using VLLM:

genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
    --id mistralai/Mistral-7B-v0.1 \
    listen \
        --args \
            model_name="mistralai/Mistral-7B-v0.1" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            precision="bfloat16" \
            use_vllm=True \
            vllm_enforce_eager=True \
            vllm_max_model_len=2048 \
            concurrent_queries=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"

or using llama.cpp:

genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
    listen \
        --args \
            model_name="TheBloke/Mistral-7B-v0.1-GGUF" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            use_llama_cpp=True \
            llama_cpp_filename="mistral-7b-v0.1.Q4_K_M.gguf" \
            llama_cpp_n_gpu_layers=35 \
            llama_cpp_n_ctx=32768 \
            concurrent_queries=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"

__init__(input, output, state, **kwargs)

Initializes the LanguageModelAPI with configurations for the input, output, and state management, along with any additional model-specific parameters.

Parameters:

    input (BatchInput): The configuration for input data handling. Required.
    output (BatchOutput): The configuration for output data handling. Required.
    state (State): The state management for the API. Required.
    **kwargs (Any): Additional keyword arguments for model configuration and API setup. Defaults to {}.
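
Although the class is normally constructed and started by the genius CLI as shown above, the following is a minimal sketch of instantiating it directly in Python. The import paths and the BatchInput/BatchOutput/InMemoryState constructor arguments are assumptions and may differ between GeniusRise releases.

# Hedged sketch only: the import paths and the BatchInput/BatchOutput/
# InMemoryState constructor arguments below are assumptions.
from geniusrise import BatchInput, BatchOutput, InMemoryState  # assumed imports
from geniusrise_text import LanguageModelAPI                   # assumed import

batch_input = BatchInput("./input", "my-bucket", "lm/input")     # hypothetical bucket/prefix
batch_output = BatchOutput("./output", "my-bucket", "lm/output") # hypothetical bucket/prefix
state = InMemoryState()

api = LanguageModelAPI(
    input=batch_input,
    output=batch_output,
    state=state,
    # Model-specific keyword arguments, accepted via **kwargs:
    model_name="mistralai/Mistral-7B-v0.1",
    model_class="AutoModelForCausalLM",
    tokenizer_class="AutoTokenizer",
    use_cuda=True,
    precision="float16",
)
# The CherryPy server itself is started by the CLI's `listen` verb shown above.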

complete(**kwargs)

Handles POST requests to generate text based on a given prompt and model-specific parameters. This method is exposed as a web endpoint through CherryPy and returns a JSON response containing the original prompt, the generated text, and any additional information returned by the model.

Parameters:

    **kwargs (Any): Arbitrary keyword arguments containing the prompt and any additional generation parameters. Defaults to {}.

Returns:

    Dict[str, Any]: A dictionary with the original prompt, generated text, and other model-specific information.

Example CURL Request:

curl -X POST localhost:3000/api/v1/complete \
    -H "Content-Type: application/json" \
    -d '{
        "prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a PRD for Oauth auth using keycloak\n\n### Response:",
        "decoding_strategy": "generate",
        "max_new_tokens": 1024,
        "do_sample": true
    }' | jq
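
The same endpoint can also be called from any HTTP client. Below is a hedged Python sketch using the requests library; the host, port, and request body mirror the cURL example above and are placeholders for a real deployment.

import requests

# Calls the /api/v1/complete endpoint shown in the cURL example above.
payload = {
    "prompt": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\nWrite a PRD for Oauth auth using keycloak\n\n### Response:"
    ),
    "decoding_strategy": "generate",
    "max_new_tokens": 1024,
    "do_sample": True,
}

response = requests.post(
    "http://localhost:3000/api/v1/complete",
    json=payload,
    # Add auth=("user", "password") if the server was started with the
    # username/password arguments shown in the CLI examples.
    timeout=600,
)
response.raise_for_status()
print(response.json())  # original prompt, generated text, and model-specific info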

complete_llama_cpp(**kwargs)

Handles POST requests to generate completions using the llama.cpp engine. This method accepts various parameters for customizing the completion request, including the prompt, sampling settings, and more.

Parameters:

    prompt: The prompt to generate text from.
    suffix: A suffix to append to the generated text. If None, no suffix is appended.
    max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the number of tokens generated is unlimited and depends on n_ctx.
    temperature: The temperature to use for sampling.
    top_p: The top-p value to use for nucleus sampling, as described in "The Curious Case of Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
    min_p: The min-p value to use for minimum-p sampling, as described in https://github.com/ggerganov/llama.cpp/pull/3841.
    typical_p: The typical-p value to use for sampling. Locally typical sampling is described in https://arxiv.org/abs/2202.00666.
    logprobs: The number of logprobs to return. If None, no logprobs are returned.
    echo: Whether to echo the prompt.
    stop: A list of strings to stop generation when encountered.
    frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
    presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
    repeat_penalty: The penalty to apply to repeated tokens.
    top_k: The top-k value to use for sampling, as described in "The Curious Case of Neural Text Degeneration" (https://arxiv.org/abs/1904.09751).
    stream: Whether to stream the results.
    seed: The seed to use for sampling.
    tfs_z: The tail-free sampling parameter, as described in https://www.trentonbricken.com/Tail-Free-Sampling/.
    mirostat_mode: The mirostat sampling mode.
    mirostat_tau: The target cross-entropy (surprise) value for the generated text. A higher value yields more surprising, less predictable text; a lower value yields less surprising, more predictable text.
    mirostat_eta: The learning rate used to update mu based on the error between the target and observed surprisal of the sampled word. A larger learning rate updates mu more quickly; a smaller one updates it more slowly.
    model: The name to use for the model in the completion object.
    stopping_criteria: A list of stopping criteria to use.
    logits_processor: A list of logits processors to use.
    grammar: A grammar to use for constrained sampling.
    logit_bias: A logit bias to use.

Returns:

    Dict[str, Any]: A dictionary containing the chat completion response or an error message.

Example CURL Request:

curl -X POST "http://localhost:3000/api/v1/complete_llama_cpp" \
    -H "Content-Type: application/json" \
    -d '{
        "prompt": "Whats the weather like in London?",
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_tokens": 50,
        "repeat_penalty": 1.1
    }'
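
The same request can be issued from Python; the hedged sketch below mirrors the cURL example and adds two of the llama.cpp-specific parameters documented above (stop and seed). Host, port, and values are placeholders.

import requests

# Mirrors the cURL example above for /api/v1/complete_llama_cpp, adding the
# documented stop and seed parameters. Values are placeholders.
payload = {
    "prompt": "Whats the weather like in London?",
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_tokens": 50,
    "repeat_penalty": 1.1,
    "stop": ["\n\n"],
    "seed": 42,
}

response = requests.post(
    "http://localhost:3000/api/v1/complete_llama_cpp",
    json=payload,
    timeout=600,
)
response.raise_for_status()
print(response.json())  # completion response or an error message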

complete_vllm(**kwargs)

Handles POST requests to generate chat completions using the vLLM engine. This method accepts various parameters for customizing the chat completion request, including message content, generation settings, and more.

Parameters:

  • **kwargs (Any): Arbitrary keyword arguments. Expects data in JSON format containing any of the following keys:
    • messages (Union[str, List[Dict[str, str]]]): The messages for the chat context.
    • temperature (float, optional): The sampling temperature. Defaults to 0.7.
    • top_p (float, optional): The nucleus sampling probability. Defaults to 1.0.
    • n (int, optional): The number of completions to generate. Defaults to 1.
    • max_tokens (int, optional): The maximum number of tokens to generate.
    • stop (Union[str, List[str]], optional): Stop sequence to end generation.
    • stream (bool, optional): Whether to stream the response. Defaults to False.
    • presence_penalty (float, optional): The presence penalty. Defaults to 0.0.
    • frequency_penalty (float, optional): The frequency penalty. Defaults to 0.0.
    • logit_bias (Dict[str, float], optional): Adjustments to the logits of specified tokens.
    • user (str, optional): An identifier for the user making the request.
    • (Additional model-specific parameters)

Returns:

    Dict[str, Any]: A dictionary with the chat completion response or an error message.

Example CURL Request:

curl -v -X POST "http://localhost:3000/api/v1/complete_vllm" \
    -H "Content-Type: application/json" \
    -u "user:password" \
    -d '{
        "messages": ["Whats the weather like in London?"],
        "temperature": 0.7,
        "top_p": 1.0,
        "n": 1,
        "max_tokens": 50,
        "stream": false,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "logit_bias": {},
        "user": "example_user"
    }'

This request asks the vLLM engine to generate a completion for the provided chat context, using the specified generation settings.
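
For completeness, a hedged Python equivalent of the cURL request above, again using the requests library; the endpoint, credentials, and payload mirror the example and should be adapted to a real deployment.

import requests

# Python equivalent of the cURL example above for /api/v1/complete_vllm.
payload = {
    "messages": ["Whats the weather like in London?"],
    "temperature": 0.7,
    "top_p": 1.0,
    "n": 1,
    "max_tokens": 50,
    "stream": False,
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0,
    "logit_bias": {},
    "user": "example_user",
}

response = requests.post(
    "http://localhost:3000/api/v1/complete_vllm",
    json=payload,
    auth=("user", "password"),  # matches the -u "user:password" flag above
    timeout=600,
)
response.raise_for_status()
print(response.json())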