Language Model
Bases: TextAPI
LanguageModelAPI serves pre-trained language models for text generation. It runs a CherryPy web server that accepts generation requests and returns the configured model's responses, with generation parameters customizable per request. As the CLI examples show, the model can be run directly or through the vLLM or llama.cpp engines. The class is part of the GeniusRise ecosystem for NLP tasks.
Attributes:

| Name | Type | Description |
|---|---|---|
| model | Any | The loaded language model used for text generation. |
| tokenizer | Any | The tokenizer corresponding to the language model, used to process input text. |
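Methods:

- `complete(**kwargs: Any) -> Dict[str, Any]`: Generates text from a prompt and model parameters.
- `complete_llama_cpp(**kwargs: Any) -> Dict[str, Any]`: Generates completions with the llama.cpp engine.
- `complete_vllm(**kwargs: Any) -> Dict[str, Any]`: Generates chat completions with the vLLM engine.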
CLI Usage Example:
```bash
genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
        --id mistralai/Mistral-7B-v0.1 \
    listen \
        --args \
            model_name="mistralai/Mistral-7B-v0.1" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            precision="float16" \
            quantization=0 \
            device_map="auto" \
            max_memory=None \
            torchscript=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"
```
Or using vLLM:
```bash
genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
        --id mistralai/Mistral-7B-v0.1 \
    listen \
        --args \
            model_name="mistralai/Mistral-7B-v0.1" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            precision="bfloat16" \
            use_vllm=True \
            vllm_enforce_eager=True \
            vllm_max_model_len=2048 \
            concurrent_queries=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"
```
Or using llama.cpp:
```bash
genius LanguageModelAPI rise \
    batch \
        --input_folder ./input \
    batch \
        --output_folder ./output \
    none \
    listen \
        --args \
            model_name="TheBloke/Mistral-7B-v0.1-GGUF" \
            model_class="AutoModelForCausalLM" \
            tokenizer_class="AutoTokenizer" \
            use_cuda=True \
            use_llama_cpp=True \
            llama_cpp_filename="mistral-7b-v0.1.Q4_K_M.gguf" \
            llama_cpp_n_gpu_layers=35 \
            llama_cpp_n_ctx=32768 \
            concurrent_queries=False \
            endpoint="*" \
            port=3000 \
            cors_domain="http://localhost:3000" \
            username="user" \
            password="password"
```
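You may want to launch the server from a Python script rather than a shell. In that case, the same command can be assembled as an argument list. This is a sketch, and `build_genius_argv`, `to_shell`, and `launch` are hypothetical helpers, not part of GeniusRise. The argument order simply mirrors the examples above: input type and options, output type and options, state type and options, then `listen --args`. Passing a list to `subprocess` bypasses the shell, so the double quotes around values in the examples are not needed.

```python
import shlex
import subprocess
from typing import Any, Dict, List, Optional


def build_genius_argv(
    model_args: Dict[str, Any],
    input_folder: str = "./input",
    output_folder: str = "./output",
    state_id: Optional[str] = None,
) -> List[str]:
    """Assemble the `genius LanguageModelAPI rise ...` command as an argv list (hypothetical helper)."""
    argv = [
        "genius", "LanguageModelAPI", "rise",
        "batch", "--input_folder", input_folder,    # input type and its option
        "batch", "--output_folder", output_folder,  # output type and its option
        "none",                                     # state type, as in the examples
    ]
    if state_id is not None:
        argv += ["--id", state_id]                  # state option, placed as in the examples
    argv += ["listen", "--args"]
    # One key=value token per setting; str() renders True, None, 0 as the examples spell them.
    argv += [f"{key}={value}" for key, value in model_args.items()]
    return argv


def to_shell(argv: List[str]) -> str:
    """Render the argv list as a copy-pasteable shell command."""
    return shlex.join(argv)


def launch(argv: List[str]) -> subprocess.Popen:
    """Start the server in the background; requires the `genius` CLI on PATH."""
    return subprocess.Popen(argv)
```

For example, `to_shell(build_genius_argv({"model_name": "mistralai/Mistral-7B-v0.1", "use_cuda": True, "port": 3000}))` prints a command equivalent to the first example, trimmed to those settings.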
`__init__(input, output, state, **kwargs)`
Initializes the LanguageModelAPI with configurations for the input, output, and state management, along with any additional model-specific parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | BatchInput | The configuration for input data handling. | required |
| output | BatchOutput | The configuration for output data handling. | required |
| state | State | The state management for the API. | required |
| `**kwargs` | Any | Additional keyword arguments for model configuration and API setup. | `{}` |
`complete(**kwargs)`
Handles POST requests to `/api/v1/complete`, generating text from a prompt and model-specific parameters. The method is exposed as a CherryPy web endpoint. It returns a JSON response containing the original prompt, the generated text, and any additional information returned by the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | Any | Arbitrary keyword arguments containing the prompt and any additional generation parameters (for example `decoding_strategy`, `max_new_tokens`, `do_sample`). | `{}` |
Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | A dictionary with the original prompt, the generated text, and other model-specific information. |
Example CURL Request:
```bash
curl -X POST "http://localhost:3000/api/v1/complete" \
    -H "Content-Type: application/json" \
    -u "user:password" \
    -d '{
        "prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWrite a PRD for Oauth auth using keycloak\n\n### Response:",
        "decoding_strategy": "generate",
        "max_new_tokens": 1024,
        "do_sample": true
    }' | jq
```
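The same request can be sent from Python with only the standard library. This is a minimal sketch. It assumes the server checks HTTP Basic credentials matching the `username`/`password` it was started with, as the `-u "user:password"` flag in the vLLM example suggests. `build_complete_request` and `complete` are hypothetical helper names.

```python
import base64
import json
import urllib.request
from typing import Any, Dict


def build_complete_request(
    base_url: str, username: str, password: str, payload: Dict[str, Any]
) -> urllib.request.Request:
    """Build a POST to /api/v1/complete with a JSON body and HTTP Basic credentials."""
    token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii")
    return urllib.request.Request(
        base_url.rstrip("/") + "/api/v1/complete",
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json", "Authorization": f"Basic {token}"},
    )


def complete(base_url: str, username: str, password: str, **payload: Any) -> Dict[str, Any]:
    """Send the request and decode the JSON response (needs a running server)."""
    request = build_complete_request(base_url, username, password, payload)
    # Generation of long outputs can be slow, hence the generous timeout.
    with urllib.request.urlopen(request, timeout=600) as response:
        return json.load(response)
```

Against a running server, `complete("http://localhost:3000", "user", "password", prompt="...", decoding_strategy="generate", max_new_tokens=1024, do_sample=True)` mirrors the curl call above.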
`complete_llama_cpp(**kwargs)`
Handles POST requests to generate text completions using the llama.cpp engine. The request can customize the prompt, sampling settings, stopping conditions, and more.
Parameters:

The JSON request body accepts the following keys; `prompt` is required.

| Name | Description |
|---|---|
| prompt | The prompt to generate text from. |
| suffix | A suffix to append to the generated text. If None, no suffix is appended. |
| max_tokens | The maximum number of tokens to generate. If max_tokens <= 0 or None, generation is unbounded and limited only by n_ctx. |
| temperature | The temperature to use for sampling. |
| top_p | The top-p value to use for nucleus sampling. Nucleus sampling is described in the paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751. |
| min_p | The min-p value to use for minimum-p sampling, as described in https://github.com/ggerganov/llama.cpp/pull/3841. |
| typical_p | The typical-p value to use for sampling. Locally typical sampling is described in the paper https://arxiv.org/abs/2202.00666. |
| logprobs | The number of logprobs to return. If None, no logprobs are returned. |
| echo | Whether to echo the prompt in the output. |
| stop | A list of strings that stop generation when encountered. |
| frequency_penalty | The penalty to apply to tokens based on how often they have already appeared in the context. |
| presence_penalty | The penalty to apply to tokens based on whether they have already appeared in the context. |
| repeat_penalty | The penalty to apply to repeated tokens. |
| top_k | The top-k value to use for sampling. Top-k sampling is described in the paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751. |
| stream | Whether to stream the results. |
| seed | The seed to use for sampling. |
| tfs_z | The tail-free sampling parameter. Tail-free sampling is described in https://www.trentonbricken.com/Tail-Free-Sampling/. |
| mirostat_mode | The Mirostat sampling mode: 0 disables Mirostat, 1 selects Mirostat, 2 selects Mirostat 2.0. |
| mirostat_tau | The target cross-entropy (or surprise) value for the generated text. A higher value yields more surprising, less predictable text; a lower value yields less surprising, more predictable text. |
| mirostat_eta | The learning rate used to update the Mirostat surprise threshold `mu` from the error between the target and observed surprise of the sampled token. A larger learning rate updates `mu` more quickly; a smaller one updates it more slowly. |
| model | The model name to report in the completion object. |
| stopping_criteria | A list of stopping criteria to use. |
| logits_processor | A list of logits processors to use. |
| grammar | A grammar to use for constrained sampling. |
| logit_bias | A mapping from tokens to bias values added to their logits. |
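The truncation parameters `top_k`, `top_p`, and `min_p` each keep a subset of candidate tokens and renormalize what remains. The standalone functions below sketch that behaviour on a plain list of probabilities. They are illustrative, not llama.cpp's implementation, which operates on logits and chains several samplers in an order not covered here. The function names are hypothetical.

```python
from typing import List


def _renormalize(weights: List[float]) -> List[float]:
    total = sum(weights)
    return [w / total for w in weights]


def top_k_filter(probs: List[float], k: int) -> List[float]:
    """Keep the k most likely tokens; k <= 0 (or k >= vocab size) disables the filter."""
    if k <= 0 or k >= len(probs):
        return list(probs)
    keep = set(sorted(range(len(probs)), key=lambda j: probs[j], reverse=True)[:k])
    return _renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Nucleus sampling: keep the smallest most-likely set whose total mass reaches top_p."""
    keep, mass = set(), 0.0
    for i in sorted(range(len(probs)), key=lambda j: probs[j], reverse=True):
        keep.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return _renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


def min_p_filter(probs: List[float], min_p: float) -> List[float]:
    """Min-p sampling: drop tokens below min_p times the most likely token's probability."""
    threshold = min_p * max(probs)
    return _renormalize([p if p >= threshold else 0.0 for p in probs])
```

Take the probabilities `[0.5, 0.3, 0.15, 0.05]`. `top_k_filter(probs, 2)` keeps only the first two tokens, renormalized to 0.625 and 0.375. `min_p_filter(probs, 0.2)` drops the last token because 0.05 is below 0.2 × 0.5 = 0.1.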
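`mirostat_tau` and `mirostat_eta` interact through a feedback loop. The sketch below shows one simplified step of Mirostat 2.0 (`mirostat_mode=2`) on a plain probability list. Tokens whose surprise, -log2(p), exceeds the current threshold `mu` are discarded, and a token is sampled from the rest. `mu` then moves by `eta` times the gap between the observed surprise and the target `tau`. `mirostat_v2_step` is a hypothetical name, and this is an illustration rather than llama.cpp's code.

```python
import math
import random
from typing import List, Tuple


def mirostat_v2_step(
    probs: List[float], mu: float, tau: float, eta: float, rng: random.Random
) -> Tuple[int, float]:
    """One simplified Mirostat 2.0 step; returns (token_index, updated_mu)."""
    # Keep tokens whose surprise -log2(p), in bits, does not exceed mu.
    candidates = [i for i, p in enumerate(probs) if p > 0 and -math.log2(p) <= mu]
    if not candidates:
        # Always keep at least the most likely token.
        candidates = [max(range(len(probs)), key=lambda j: probs[j])]
    total = sum(probs[i] for i in candidates)
    renormalized = [probs[i] / total for i in candidates]
    choice = rng.choices(range(len(candidates)), weights=renormalized, k=1)[0]
    token = candidates[choice]
    # Observed surprise of the sampled token under the truncated distribution.
    observed = -math.log2(renormalized[choice])
    # Too surprising -> lower mu (truncate more); too predictable -> raise mu.
    mu = mu - eta * (observed - tau)
    return token, mu
```

A typical loop starts from `mu = 2 * tau`, the initialization given in the Mirostat paper. It then calls the function once per generated token, carrying `mu` forward. The intent is to hold the average surprise near `tau`.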
Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | A dictionary containing the completion response or an error message. |
Example CURL Request:
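Here is a minimal Python sketch of a request to this endpoint, using only the standard library. It assumes the route is `/api/v1/complete_llama_cpp`. That route is inferred from the method name by analogy with `/api/v1/complete` and `/api/v1/complete_vllm`, not confirmed here. The sketch also assumes HTTP Basic authentication with the configured `username`/`password`. `llama_cpp_payload` and `build_llama_cpp_request` are hypothetical helpers.

The key whitelist mirrors the parameter table, with three keys left out. `stopping_criteria` and `logits_processor` are lists of Python callables, which a JSON body cannot carry. `grammar` is omitted because this page does not say how the server would deserialize one.

```python
import base64
import json
import urllib.request
from typing import Any, Dict

# JSON-friendly keys from the parameter table (stopping_criteria,
# logits_processor and grammar deliberately omitted).
LLAMA_CPP_KEYS = frozenset({
    "prompt", "suffix", "max_tokens", "temperature", "top_p", "min_p",
    "typical_p", "logprobs", "echo", "stop", "frequency_penalty",
    "presence_penalty", "repeat_penalty", "top_k", "stream", "seed",
    "tfs_z", "mirostat_mode", "mirostat_tau", "mirostat_eta", "model",
    "logit_bias",
})


def llama_cpp_payload(**params: Any) -> Dict[str, Any]:
    """Catch misspelled keys client-side; values, including None, pass through unchanged."""
    unknown = set(params) - LLAMA_CPP_KEYS
    if unknown:
        raise ValueError(f"unknown parameter(s): {sorted(unknown)}")
    if "prompt" not in params:
        raise ValueError("'prompt' is required")
    return dict(params)


def build_llama_cpp_request(
    base_url: str, username: str, password: str, payload: Dict[str, Any]
) -> urllib.request.Request:
    # Assumed route, inferred from the method name.
    url = base_url.rstrip("/") + "/api/v1/complete_llama_cpp"
    token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii")
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json", "Authorization": f"Basic {token}"},
    )
```

For example, `build_llama_cpp_request("http://localhost:3000", "user", "password", llama_cpp_payload(prompt="Hello", max_tokens=64, mirostat_mode=2))` produces a request that can be passed to `urllib.request.urlopen`.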
`complete_vllm(**kwargs)`
Handles POST requests to generate chat completions using the vLLM engine. The request can customize the message content, generation settings, and more.
Parameters:

- `**kwargs` (Any): Arbitrary keyword arguments, sent as a JSON request body containing any of the following keys:
    - messages (Union[str, List[Dict[str, str]]]): The messages for the chat context.
    - temperature (float, optional): The sampling temperature. Defaults to 0.7.
    - top_p (float, optional): The nucleus sampling probability. Defaults to 1.0.
    - n (int, optional): The number of completions to generate. Defaults to 1.
    - max_tokens (int, optional): The maximum number of tokens to generate.
    - stop (Union[str, List[str]], optional): Stop sequence(s) that end generation.
    - stream (bool, optional): Whether to stream the response. Defaults to False.
    - presence_penalty (float, optional): The presence penalty. Defaults to 0.0.
    - frequency_penalty (float, optional): The frequency penalty. Defaults to 0.0.
    - logit_bias (Dict[str, float], optional): Adjustments to the logits of specified tokens.
    - user (str, optional): An identifier for the user making the request.
    - (Additional model-specific parameters)
Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | A dictionary with the chat completion response or an error message. |
Example CURL Request:
```bash
curl -v -X POST "http://localhost:3000/api/v1/complete_vllm" \
    -H "Content-Type: application/json" \
    -u "user:password" \
    -d '{
        "messages": ["Whats the weather like in London?"],
        "temperature": 0.7,
        "top_p": 1.0,
        "n": 1,
        "max_tokens": 50,
        "stream": false,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "logit_bias": {},
        "user": "example_user"
    }'
```
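On the client side, the documented keys and defaults can be captured in a small typed builder, so that typos surface as Python errors rather than silently ignored JSON keys. `VllmChatRequest` is a hypothetical class, not part of GeniusRise. The `extra` field is an assumed catch-all for the additional model-specific parameters. The sanity checks in `__post_init__` are assumptions, not bounds taken from the GeniusRise source.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class VllmChatRequest:
    """Hypothetical client-side model of the complete_vllm JSON body (defaults as documented)."""

    messages: Union[str, List[Dict[str, str]]]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    logit_bias: Dict[str, float] = field(default_factory=dict)
    user: Optional[str] = None
    # Assumed catch-all for additional model-specific parameters.
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Basic sanity checks (assumed bounds, not taken from GeniusRise).
        if self.n < 1:
            raise ValueError("n must be at least 1")
        if self.temperature < 0:
            raise ValueError("temperature must be non-negative")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")

    def to_json_body(self) -> Dict[str, Any]:
        """Flatten into the request body, omitting optional fields left unset (None)."""
        body = {k: v for k, v in asdict(self).items() if k != "extra" and v is not None}
        body.update(self.extra)
        return body
```

For instance, `VllmChatRequest(messages="What's the weather like in London?", max_tokens=50).to_json_body()` returns a dictionary that can be serialized with `json.dumps` and posted the same way as the curl request above.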