MultiLoRA Inference

Source: vllm-project/vllm.

  1# flake8: noqa
  2# UPSTREAM SYNC: noqa is required for passing ruff run on nm-automation
  3"""
  4This example shows how to use the multi-LoRA functionality
  5for offline inference.
  6
  7Requires HuggingFace credentials for access to Llama2.
  8"""
  9
 10from typing import List, Optional, Tuple
 11
 12from huggingface_hub import snapshot_download
 13
 14from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
 15from vllm.lora.request import LoRARequest
 16
 17
 18def create_test_prompts(
 19        lora_path: str
 20) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
 21    """Create a list of test prompts with their sampling parameters.
 22
 23    2 requests for base model, 4 requests for the LoRA. We define 2
 24    different LoRA adapters (using the same model for demo purposes).
 25    Since we also set `max_loras=1`, the expectation is that the requests
 26    with the second LoRA adapter will be ran after all requests with the
 27    first adapter have finished.
 28    """
 29    return [
 30        ("A robot may not injure a human being",
 31         SamplingParams(temperature=0.0,
 32                        logprobs=1,
 33                        prompt_logprobs=1,
 34                        max_tokens=128), None),
 35        ("To be or not to be,",
 36         SamplingParams(temperature=0.8,
 37                        top_k=5,
 38                        presence_penalty=0.2,
 39                        max_tokens=128), None),
 40        (
 41            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
 42            SamplingParams(temperature=0.0,
 43                           logprobs=1,
 44                           prompt_logprobs=1,
 45                           max_tokens=128,
 46                           stop_token_ids=[32003]),
 47            LoRARequest("sql-lora", 1, lora_path)),
 48        (
 49            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
 50            SamplingParams(n=3,
 51                           best_of=3,
 52                           use_beam_search=True,
 53                           temperature=0,
 54                           max_tokens=128,
 55                           stop_token_ids=[32003]),
 56            LoRARequest("sql-lora", 1, lora_path)),
 57        (
 58            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
 59            SamplingParams(temperature=0.0,
 60                           logprobs=1,
 61                           prompt_logprobs=1,
 62                           max_tokens=128,
 63                           stop_token_ids=[32003]),
 64            LoRARequest("sql-lora2", 2, lora_path)),
 65        (
 66            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
 67            SamplingParams(n=3,
 68                           best_of=3,
 69                           use_beam_search=True,
 70                           temperature=0,
 71                           max_tokens=128,
 72                           stop_token_ids=[32003]),
 73            LoRARequest("sql-lora", 1, lora_path)),
 74    ]
 75
 76
 77def process_requests(engine: LLMEngine,
 78                     test_prompts: List[Tuple[str, SamplingParams,
 79                                              Optional[LoRARequest]]]):
 80    """Continuously process a list of prompts and handle the outputs."""
 81    request_id = 0
 82
 83    while test_prompts or engine.has_unfinished_requests():
 84        if test_prompts:
 85            prompt, sampling_params, lora_request = test_prompts.pop(0)
 86            engine.add_request(str(request_id),
 87                               prompt,
 88                               sampling_params,
 89                               lora_request=lora_request)
 90            request_id += 1
 91
 92        request_outputs: List[RequestOutput] = engine.step()
 93
 94        for request_output in request_outputs:
 95            if request_output.finished:
 96                print(request_output)
 97
 98
 99def initialize_engine() -> LLMEngine:
100    """Initialize the LLMEngine."""
101    # max_loras: controls the number of LoRAs that can be used in the same
102    #   batch. Larger numbers will cause higher memory usage, as each LoRA
103    #   slot requires its own preallocated tensor.
104    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
105    #   numbers will cause higher memory usage. If you know that all LoRAs will
106    #   use the same rank, it is recommended to set this as low as possible.
107    # max_cpu_loras: controls the size of the CPU LoRA cache.
108    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
109                             enable_lora=True,
110                             max_loras=1,
111                             max_lora_rank=8,
112                             max_cpu_loras=2,
113                             max_num_seqs=256)
114    return LLMEngine.from_engine_args(engine_args)
115
116
def main():
    """Run the multi-LoRA demo end to end.

    Builds the engine, downloads the SQL LoRA adapter from the Hugging Face
    Hub, builds the demo prompts from it, and hands them to
    ``process_requests``.
    """
    llm_engine = initialize_engine()
    adapter_dir = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    prompts = create_test_prompts(adapter_dir)
    process_requests(llm_engine, prompts)


if __name__ == '__main__':
    main()