Offline Bench

Source: vllm-project/vllm.

  1import argparse
  2import random
  3import time
  4
  5from vllm import LLM, SamplingParams
  6
  7NUM_REQUESTS_DEFAULT = 256
  8MAX_SEQ_LEN_DEFAULT = 1024
  9MAX_TOKENS_DEFAULT = 128
 10SAMPLE_PROMPTS = [
 11    # "Hello, my name is",
 12    # "The president of the United States is",
 13    # "The capital of France is",
 14    "The future of AI is",
 15]
 16
 17
 18def run_bench(model_name,
 19              model_revision,
 20              is_sparse,
 21              quant_method,
 22              max_seq_len,
 23              max_tokens,
 24              num_requests,
 25              num_gpus,
 26              num_warmup_iters=1,
 27              num_bench_iters=5,
 28              possible_prompts=SAMPLE_PROMPTS,
 29              enforce_eager=True):
 30    print("Run bench with:")
 31    print(f"  model_name = {model_name}")
 32    print(f"  model_revision = {model_revision}")
 33    print(f"  is_sparse = {is_sparse}")
 34    print(f"  quant_method = {quant_method}")
 35    print(f"  max_seq_len = {max_seq_len}")
 36    print(f"  max_tokens = {max_tokens}")
 37    print(f"  num_requests = {num_requests}")
 38    print(f"  num_gpus = {num_gpus}")
 39    print(f"  num_warmup_iters = {num_warmup_iters}")
 40    print(f"  num_bench_iters = {num_bench_iters}")
 41
 42    prompts = []
 43    for _ in range(num_requests):
 44        index = random.randint(0, len(possible_prompts) - 1)
 45        prompts.append(possible_prompts[index])
 46
 47    # Create sampling params
 48    sampling_params = SamplingParams(temperature=0.8,
 49                                     top_p=0.95,
 50                                     max_tokens=max_tokens)
 51
 52    # Create LLM
 53    llm = LLM(
 54        model=model_name,
 55        revision=model_revision,
 56        sparsity="sparse_w16a16" if is_sparse else None,
 57        enforce_eager=enforce_eager,
 58        #   dtype=torch.bfloat16,
 59        tensor_parallel_size=num_gpus,
 60        gpu_memory_utilization=0.9,
 61        max_model_len=max_seq_len,
 62        quantization=quant_method,
 63    )
 64
 65    for i in range(num_warmup_iters):
 66        start_time = time.time()
 67        outputs = llm.generate(prompts, sampling_params)
 68        elapsed_time = time.time() - start_time
 69        print(f"Warmup iter {i} time: {elapsed_time} [secs]")
 70
 71    iter_times = []
 72    for i in range(num_bench_iters):
 73        start_time = time.time()
 74        outputs = llm.generate(prompts, sampling_params)
 75        iter_times.append(time.time() - start_time)
 76        print(f"Bench iter {i} time: {iter_times[-1]} [secs]")
 77
 78    average_iter_time = sum(iter_times) / num_bench_iters
 79    print(f"Average per iter time: {average_iter_time} [secs]")
 80
 81    # Print outputs of the last iter
 82    for output in outputs:
 83        prompt = output.prompt
 84        generated_text = output.outputs[0].text
 85        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 86
 87    return average_iter_time
 88
 89
 90if __name__ == "__main__":
 91    parser = argparse.ArgumentParser()
 92
 93    parser.add_argument("--model_name", type=str, required=True)
 94    parser.add_argument("--model_revision", type=str, default=None)
 95    parser.add_argument('--is_sparse', action='store_true')
 96    parser.add_argument("--quant_method", type=str, default=None)
 97    parser.add_argument("--max_seq_len", type=int, default=MAX_SEQ_LEN_DEFAULT)
 98    parser.add_argument("--max_tokens", type=int, default=MAX_TOKENS_DEFAULT)
 99    parser.add_argument("--num_requests",
100                        type=int,
101                        default=NUM_REQUESTS_DEFAULT)
102    parser.add_argument("--num_gpus", type=int, default=1)
103    parser.add_argument("--num_warmup_iters", type=int, default=1)
104    parser.add_argument("--num_bench_iters", type=int, default=5)
105
106    args = parser.parse_args()
107
108    run_bench(args.model_name, args.model_revision, args.is_sparse,
109              args.quant_method, args.max_seq_len, args.max_tokens,
110              args.num_requests, args.num_gpus, args.num_warmup_iters,
111              args.num_bench_iters)