MultiLoRA Inference
Source: vllm-project/vllm.
1# flake8: noqa
2# UPSTREAM SYNC: noqa is required for passing ruff run on nm-automation
3"""
4This example shows how to use the multi-LoRA functionality
5for offline inference.
6
7Requires HuggingFace credentials for access to Llama2.
8"""
9
10from typing import List, Optional, Tuple
11
12from huggingface_hub import snapshot_download
13
14from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
15from vllm.lora.request import LoRARequest
16
17
def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for the base model and 4 requests for LoRA adapters. We define
    2 different LoRA adapters (both backed by the same weights at
    `lora_path`, for demo purposes) and send the same pair of SQL prompts to
    each adapter. Since we also set `max_loras=1`, the expectation is that
    the requests with the second LoRA adapter will be run after all requests
    with the first adapter have finished.

    Args:
        lora_path: Local directory holding the LoRA adapter weights.

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples;
        ``lora_request`` is ``None`` for requests against the base model.
    """
    sql_icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    sql_nationality_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def sql_requests(
            lora_name: str, lora_int_id: int
    ) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
        # Build the same two SQL requests for one adapter: a greedy request
        # with logprobs and a beam-search request returning 3 sequences.
        # Fresh SamplingParams / LoRARequest objects are created for every
        # request rather than shared between requests.
        # NOTE(review): 32003 is passed as an extra stop token; presumably
        # the adapter's end-of-answer token — confirm against its tokenizer.
        return [
            (sql_icao_prompt,
             SamplingParams(temperature=0.0,
                            logprobs=1,
                            prompt_logprobs=1,
                            max_tokens=128,
                            stop_token_ids=[32003]),
             LoRARequest(lora_name, lora_int_id, lora_path)),
            (sql_nationality_prompt,
             SamplingParams(n=3,
                            best_of=3,
                            use_beam_search=True,
                            temperature=0,
                            max_tokens=128,
                            stop_token_ids=[32003]),
             LoRARequest(lora_name, lora_int_id, lora_path)),
        ]

    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        # Each adapter gets its own pair of requests, so both adapters are
        # exercised equally.
        *sql_requests("sql-lora", 1),
        *sql_requests("sql-lora2", 2),
    ]
75
76
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Feed prompts to the engine one per step and print finished outputs.

    Before each `engine.step()` call, at most one pending prompt is
    submitted. New requests therefore join while earlier ones are still
    being generated. The loop exits once every prompt has been submitted and
    the engine reports no unfinished requests. `test_prompts` is consumed in
    place: items are popped from its front.
    """
    next_id = 0

    while True:
        # Only consult the engine once all prompts have been handed over.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            prompt, params, lora = test_prompts.pop(0)
            engine.add_request(str(next_id),
                               prompt,
                               params,
                               lora_request=lora)
            next_id += 1

        step_outputs: List[RequestOutput] = engine.step()
        for output in (out for out in step_outputs if out.finished):
            print(output)
97
98
def initialize_engine() -> LLMEngine:
    """Create an LLMEngine for meta-llama/Llama-2-7b-hf with LoRA enabled."""
    args = EngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        enable_lora=True,
        # Number of LoRAs usable in the same batch. Larger values use more
        # memory, since every LoRA slot needs its own preallocated tensor.
        max_loras=1,
        # Maximum supported rank across all LoRAs. Larger values use more
        # memory; if all LoRAs share one rank, keep this as low as possible.
        max_lora_rank=8,
        # Size of the CPU LoRA cache.
        max_cpu_loras=2,
        max_num_seqs=256,
    )
    engine = LLMEngine.from_engine_args(args)
    return engine
115
116
def main():
    """Build the engine, download the LoRA adapter, and run the demo prompts."""
    llm_engine = initialize_engine()
    # snapshot_download returns the local directory of the downloaded repo.
    adapter_dir = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    process_requests(llm_engine, create_test_prompts(adapter_dir))
123
124
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()