Skip to content

vllm.entrypoints.grpc_server

vLLM gRPC Server

Starts a gRPC server for vLLM using the VllmEngine protocol.

Usage

python -m vllm.entrypoints.grpc_server --model <model-name-or-path> [--host HOST] [--port PORT]

Example

python -m vllm.entrypoints.grpc_server --model meta-llama/Llama-2-7b-hf --host 0.0.0.0 --port 50051

VllmEngineServicer

Bases: VllmEngineServicer

gRPC servicer implementing the VllmEngine service.

Handles 6 RPCs: Generate (streaming text generation); Embed (embeddings, TODO); HealthCheck (health probe); Abort (cancel requests out-of-band); GetModelInfo (model metadata); GetServerInfo (server state).

Source code in vllm/entrypoints/grpc_server.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
    """
    gRPC servicer implementing the VllmEngine service.

    Handles 6 RPCs:
    - Generate: Streaming text generation
    - Embed: Embeddings (TODO)
    - HealthCheck: Health probe
    - Abort: Cancel requests out-of-band
    - GetModelInfo: Model metadata
    - GetServerInfo: Server state
    """

    def __init__(self, async_llm: AsyncLLM, start_time: float):
        """
        Set up the servicer with its engine handle and start timestamp.

        Args:
            async_llm: AsyncLLM engine instance used to serve all RPCs
            start_time: Server start time, in seconds since the epoch
                (used to compute uptime in GetServerInfo)
        """
        self.start_time = start_time
        self.async_llm = async_llm
        logger.info("VllmEngineServicer initialized")

    async def Generate(
        self,
        request: vllm_engine_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
        """
        Handle streaming generation requests.

        Supports n>1 by sending separate chunk/complete messages for each output index.
        When streaming with n>1, chunks for different indices are interleaved.

        Args:
            request: The GenerateRequest protobuf
            context: gRPC context

        Yields:
            GenerateResponse protobuf messages (streaming)
        """
        request_id = request.request_id
        # "input" is a proto oneof; returns the set field name or None.
        input_type = request.WhichOneof("input")
        # Preprocessed multimodal path requires both mm_inputs and its
        # pixel_values field to be populated.
        has_preprocessed_mm = request.HasField(
            "mm_inputs"
        ) and request.mm_inputs.HasField("pixel_values")
        logger.info(
            "Generate request %s: input_type=%s, stream=%s, preprocessed_mm=%s",
            request_id,
            input_type,
            request.stream,
            has_preprocessed_mm,
        )

        try:
            if has_preprocessed_mm and input_type == "tokenized":
                # Preprocessed multimodal from Rust router.
                # Token IDs already have expanded placeholders; tensors are
                # ready for the model. Bypass the renderer entirely.
                prompt = self._build_preprocessed_mm_inputs(
                    request.tokenized, request.mm_inputs
                )
                prompt["arrival_time"] = time.time()
            elif input_type == "tokenized":
                # Text was tokenized upstream; wrap the token IDs directly.
                prompt: TokensPrompt = {
                    "prompt_token_ids": list(request.tokenized.input_ids)
                }
                if request.tokenized.original_text:
                    prompt["prompt"] = request.tokenized.original_text
                renderer = self.async_llm.input_processor.input_preprocessor.renderer
                prompt = renderer.process_for_engine(prompt, arrival_time=time.time())
            else:
                # Raw text input: run the full preprocessing pipeline.
                prompt: TextPrompt = {"prompt": request.text}
                prompt = self.async_llm.input_processor.input_preprocessor.preprocess(
                    prompt
                )

            # Validate kv_transfer_params if present
            if request.HasField("kv_transfer_params"):
                remote_host = request.kv_transfer_params.remote_host
                remote_port = request.kv_transfer_params.remote_port
                if not remote_host or remote_port == 0:
                    # context.abort raises, terminating the RPC here.
                    await context.abort(
                        grpc.StatusCode.INVALID_ARGUMENT,
                        "Invalid kv_transfer_params: "
                        "remote_host and remote_port must be set.",
                    )
                logger.info(
                    "Request %s: kv_transfer_params={remote_host=%s, remote_port=%d}",
                    request_id,
                    remote_host,
                    remote_port,
                )

            # Build sampling params with detokenize=False
            sampling_params = self._sampling_params_from_proto(
                request.sampling_params,
                stream=request.stream,
                kv_transfer_params=request.kv_transfer_params
                if request.HasField("kv_transfer_params")
                else None,
            )
            tokenization_kwargs = self._tokenization_kwargs_from_proto(
                request.sampling_params
            )

            # Extract logprobs configuration
            num_logprobs = sampling_params.logprobs
            num_prompt_logprobs = sampling_params.prompt_logprobs

            # Track which indices have sent their first chunk
            # (input_logprobs are attached only to the first chunk per index).
            seen_indices: set[int] = set()

            async for output in self.async_llm.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=request_id,
                tokenization_kwargs=tokenization_kwargs,
            ):
                # For streaming, send chunks for EACH completion output (n outputs)
                if request.stream:
                    for completion in output.outputs:
                        idx = completion.index
                        is_first = idx not in seen_indices
                        seen_indices.add(idx)

                        # Send chunk with delta data (Rust accumulates for vLLM)
                        yield self._chunk_response(
                            output,
                            completion=completion,
                            num_logprobs=num_logprobs,
                            num_prompt_logprobs=num_prompt_logprobs,
                            is_first_chunk=is_first,
                        )

                        # Send Complete when sequence finishes (n>1 support)
                        if completion.finish_reason:
                            yield self._complete_response(
                                output,
                                completion=completion,
                                num_logprobs=num_logprobs,
                                num_prompt_logprobs=num_prompt_logprobs,
                            )

                # For non-streaming, send complete response when finished
                if output.finished and not request.stream:
                    for completion in output.outputs:
                        yield self._complete_response(
                            output,
                            completion=completion,
                            num_logprobs=num_logprobs,
                            num_prompt_logprobs=num_prompt_logprobs,
                        )

        except ValueError as e:
            # Invalid request error (equiv to 400).
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        except Exception as e:
            logger.exception("Error in Generate for request %s", request_id)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def Embed(
        self,
        request: vllm_engine_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.EmbedResponse:
        """
        Handle embedding requests.

        TODO: Implement in Phase 4

        Args:
            request: The EmbedRequest protobuf
            context: gRPC context

        Returns:
            EmbedResponse protobuf
        """
        logger.warning("Embed RPC not yet implemented")
        # context.abort raises, so this coroutine never actually returns a
        # response; clients receive an UNIMPLEMENTED status.
        await context.abort(
            grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented"
        )

    async def HealthCheck(
        self,
        request: vllm_engine_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.HealthCheckResponse:
        """
        Handle health check requests.

        Reports healthy as long as the engine has not entered an errored
        state.

        Args:
            request: The HealthCheckRequest protobuf
            context: gRPC context

        Returns:
            HealthCheckResponse protobuf with the health flag and a
            human-readable message
        """
        is_healthy = not self.async_llm.errored
        # Fix message typo: was "Health", should read "Healthy".
        message = "Healthy" if is_healthy else "Engine is not alive"

        logger.info("HealthCheck request: healthy=%s, message=%s", is_healthy, message)

        return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message)

    async def Abort(
        self,
        request: vllm_engine_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.AbortResponse:
        """
        Cancel in-flight requests out-of-band.

        Args:
            request: The AbortRequest protobuf listing request IDs to cancel
            context: gRPC context

        Returns:
            AbortResponse protobuf (empty acknowledgement)
        """
        logger.info("Abort requests: %s", request.request_ids)
        await self.async_llm.abort(request.request_ids)
        return vllm_engine_pb2.AbortResponse()

    async def GetModelInfo(
        self,
        request: vllm_engine_pb2.GetModelInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetModelInfoResponse:
        """
        Report static metadata about the loaded model.

        Args:
            request: The GetModelInfoRequest protobuf
            context: gRPC context

        Returns:
            GetModelInfoResponse protobuf
        """
        cfg = self.async_llm.model_config
        return vllm_engine_pb2.GetModelInfoResponse(
            model_path=cfg.model,
            max_context_length=cfg.max_model_len,
            vocab_size=cfg.get_vocab_size(),
            is_generation=cfg.runner_type == "generate",
            supports_vision=cfg.is_multimodal_model,
        )

    async def GetServerInfo(
        self,
        request: vllm_engine_pb2.GetServerInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetServerInfoResponse:
        """
        Report dynamic server state (load, uptime, KV transfer setup).

        Args:
            request: The GetServerInfoRequest protobuf
            context: gRPC context

        Returns:
            GetServerInfoResponse protobuf
        """
        # KV transfer config is optional; fall back to empty strings.
        kv_connector = kv_role = ""
        kv_cfg = self.async_llm.vllm_config.kv_transfer_config
        if kv_cfg is not None:
            kv_connector = kv_cfg.kv_connector or ""
            kv_role = kv_cfg.kv_role or ""

        active = self.async_llm.output_processor.get_num_unfinished_requests()
        return vllm_engine_pb2.GetServerInfoResponse(
            active_requests=active,
            is_paused=False,  # TODO
            last_receive_timestamp=time.time(),  # TODO looks wrong?
            uptime_seconds=time.time() - self.start_time,
            server_type="vllm-grpc",
            kv_connector=kv_connector,
            kv_role=kv_role,
        )

    # ========== Helper methods ==========

    def _build_preprocessed_mm_inputs(
        self,
        tokenized: vllm_engine_pb2.TokenizedInput,
        mm_proto: vllm_engine_pb2.MultimodalInputs,
    ) -> VllmMultiModalInputs:
        """Build vLLM MultiModalInputs from preprocessed proto data.

        Bypasses HF processor entirely — pixel values and model-specific
        tensors were already computed by the Rust router.  Field layouts
        (batched / flat / shared) are also determined by the router via
        ``batched_keys`` and ``flat_keys`` proto fields.

        NOTE(review): this path assumes a single "image" modality throughout
        (hashes, placeholders, field configs) — confirm against the router.
        """
        prompt_token_ids = list(tokenized.input_ids)
        num_images = len(mm_proto.mm_placeholders)

        # Deserialize all tensors from proto
        hf_dict: dict[str, torch.Tensor] = {
            "pixel_values": _tensor_from_proto(mm_proto.pixel_values),
        }
        for key, td in mm_proto.model_specific_tensors.items():
            hf_dict[key] = _tensor_from_proto(td)

        # Cast floating-point tensors to model dtype (e.g. bfloat16).
        # This mirrors _postprocess_output in multimodal/processing/context.py
        # which is skipped when bypassing the HF processor.
        model_dtype = self.async_llm.model_config.dtype
        for key in hf_dict:
            if hf_dict[key].is_floating_point():
                hf_dict[key] = hf_dict[key].to(dtype=model_dtype)

        cpu_keys = set(mm_proto.keep_on_cpu_keys)

        # Field configs are fully determined by the Rust router.
        batched = set(mm_proto.batched_keys)
        flat = dict(mm_proto.flat_keys)
        fields_config: dict[str, MultiModalFieldConfig] = {}
        for key in hf_dict:
            on_cpu = key in cpu_keys
            if key in batched:
                fields_config[key] = MultiModalFieldConfig.batched(
                    "image", keep_on_cpu=on_cpu
                )
            elif key in flat:
                # flat_keys maps a tensor key -> the key of a companion
                # tensor holding per-item sizes; flatten to a 1-D int64
                # vector as flat_from_sizes expects.
                sizes = hf_dict[flat[key]].flatten().to(torch.int64)
                fields_config[key] = MultiModalFieldConfig.flat_from_sizes(
                    "image", sizes, keep_on_cpu=on_cpu
                )
            else:
                # Anything not listed is shared across all images.
                fields_config[key] = MultiModalFieldConfig.shared("image", num_images)

        batch_feature = BatchFeature(hf_dict, tensor_type="pt")
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(batch_feature, fields_config)

        # Build mm_hashes: dict[str, list[str]]
        mm_hashes: dict[str, list[str]] = {}
        if mm_proto.mm_hashes:
            mm_hashes["image"] = list(mm_proto.mm_hashes)

        # Build mm_placeholders: dict[str, list[PlaceholderRange]]
        # When structural tokens (e.g. <|image_start|>, separators) are present
        # in the placeholder range, we must set is_embed so vLLM only scatters
        # encoder embeddings into patch-token positions (im_token_id).
        mm_placeholders: dict[str, list[PlaceholderRange]] = {}
        if mm_proto.mm_placeholders:
            im_token_id = (
                mm_proto.im_token_id if mm_proto.HasField("im_token_id") else None
            )
            placeholders = []
            for p in mm_proto.mm_placeholders:
                is_embed = None
                if im_token_id is not None:
                    token_slice = prompt_token_ids[p.offset : p.offset + p.length]
                    mask = [t == im_token_id for t in token_slice]
                    # Only set is_embed when there are non-embed positions,
                    # otherwise None means "all positions are embeds" which is
                    # both correct and avoids unnecessary overhead.
                    if not all(mask):
                        is_embed = torch.tensor(mask, dtype=torch.bool)
                placeholders.append(
                    PlaceholderRange(
                        offset=p.offset, length=p.length, is_embed=is_embed
                    )
                )
            mm_placeholders["image"] = placeholders

        return mm_inputs(
            prompt_token_ids=prompt_token_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
            prompt=tokenized.original_text or None,
        )

    @staticmethod
    def _sampling_params_from_proto(
        params: vllm_engine_pb2.SamplingParams,
        stream: bool = True,
        kv_transfer_params: vllm_engine_pb2.KvTransferParams | None = None,
    ) -> SamplingParams:
        """
        Convert protobuf SamplingParams to vLLM SamplingParams.

        Args:
            params: Protobuf SamplingParams message
            stream: Whether streaming is enabled
            kv_transfer_params: KV transfer params proto for Mooncake PD

        Returns:
            vLLM SamplingParams with detokenize=False and structured_outputs
        """
        # Build stop sequences
        stop = list(params.stop) if params.stop else None
        stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None

        # Handle structured outputs constraints
        # "constraint" is a proto oneof; at most one branch applies.
        structured_outputs = None
        constraint_field = params.WhichOneof("constraint")
        if constraint_field:
            if constraint_field == "json_schema":
                structured_outputs = StructuredOutputsParams(json=params.json_schema)
            elif constraint_field == "regex":
                structured_outputs = StructuredOutputsParams(regex=params.regex)
            elif constraint_field == "grammar":
                structured_outputs = StructuredOutputsParams(grammar=params.grammar)
            elif constraint_field == "structural_tag":
                structured_outputs = StructuredOutputsParams(
                    structural_tag=params.structural_tag
                )
            elif constraint_field == "json_object":
                structured_outputs = StructuredOutputsParams(
                    json_object=params.json_object
                )
            elif constraint_field == "choice":
                structured_outputs = StructuredOutputsParams(
                    choice=list(params.choice.choices)
                )

        # Build extra_args for kv_transfer_params (Mooncake PD)
        extra_args = None
        if kv_transfer_params:
            extra_args = {
                "kv_transfer_params": {
                    "remote_host": kv_transfer_params.remote_host,
                    "remote_port": kv_transfer_params.remote_port,
                }
            }

        # Create SamplingParams
        # output_kind=DELTA: Return only new tokens in each chunk (for streaming)
        # NOTE(review): for non-optional proto3 scalars, 0 / 0.0 is
        # indistinguishable from "unset", hence the `!= 0.0` fallbacks below.
        return SamplingParams(
            temperature=params.temperature if params.HasField("temperature") else 1.0,
            top_p=params.top_p if params.top_p != 0.0 else 1.0,
            top_k=params.top_k,
            min_p=params.min_p,
            frequency_penalty=params.frequency_penalty,
            presence_penalty=params.presence_penalty,
            repetition_penalty=params.repetition_penalty
            if params.repetition_penalty != 0.0
            else 1.0,
            max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
            min_tokens=params.min_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=params.spaces_between_special_tokens,
            ignore_eos=params.ignore_eos,
            n=params.n if params.n > 0 else 1,
            logprobs=params.logprobs if params.HasField("logprobs") else None,
            prompt_logprobs=params.prompt_logprobs
            if params.HasField("prompt_logprobs")
            else None,
            seed=params.seed if params.HasField("seed") else None,
            include_stop_str_in_output=params.include_stop_str_in_output,
            logit_bias=dict(params.logit_bias) if params.logit_bias else None,
            structured_outputs=structured_outputs,
            extra_args=extra_args,
            # detokenize must be True if stop strings are used
            detokenize=bool(stop),
            output_kind=RequestOutputKind.DELTA
            if stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @staticmethod
    def _build_top_logprobs(
        logprob_entry: dict,
        num_top_logprobs: int | None,
    ) -> vllm_engine_pb2.TopLogProbs:
        """Build TopLogProbs proto from a logprob entry dict.

        Args:
            logprob_entry: Mapping of token_id -> Logprob for one position
                (may be empty or None, e.g. for the first prompt token).
            num_top_logprobs: Number of highest-logprob entries to keep;
                falsy yields an empty proto.

        Returns:
            TopLogProbs proto with up to num_top_logprobs (token_id, value)
            pairs in descending logprob order.
        """
        top = vllm_engine_pb2.TopLogProbs()
        if num_top_logprobs and logprob_entry:
            sorted_entries = sorted(
                logprob_entry.items(),
                key=lambda x: x[1].logprob,
                reverse=True,
            )
            # BUG FIX: the original called `functools.islice`, which does not
            # exist (islice lives in itertools) and raised AttributeError
            # whenever top logprobs were requested. `sorted()` returns a
            # list, so a plain slice is the simplest correct truncation.
            for tid, lp in sorted_entries[:num_top_logprobs]:
                top.token_ids.append(tid)
                top.values.append(lp.logprob)
        return top

    @staticmethod
    def _build_output_logprobs(
        logprobs: SampleLogprobs | None,
        token_ids: list[int],
        num_top_logprobs: int | None,
    ) -> vllm_engine_pb2.OutputLogProbs | None:
        """
        Convert vLLM SampleLogprobs to proto OutputLogProbs.

        Positions whose sampled token has no logprob entry are skipped.

        Args:
            logprobs: vLLM logprobs (list of dict[int, Logprob])
            token_ids: Token IDs for each position
            num_top_logprobs: Number of top logprobs to include

        Returns:
            OutputLogProbs proto, or None when there is nothing to report
        """
        if not logprobs:
            return None

        result = vllm_engine_pb2.OutputLogProbs()
        for tid, entry in zip(token_ids, logprobs):
            chosen = entry.get(tid)
            if not chosen:
                continue
            result.token_ids.append(tid)
            result.token_logprobs.append(chosen.logprob)
            if num_top_logprobs:
                result.top_logprobs.append(
                    VllmEngineServicer._build_top_logprobs(entry, num_top_logprobs)
                )

        return result if result.token_ids else None

    @staticmethod
    def _build_input_logprobs(
        prompt_logprobs: PromptLogprobs | None,
        prompt_token_ids: list[int],
        num_top_logprobs: int | None,
    ) -> vllm_engine_pb2.InputLogProbs | None:
        """
        Convert vLLM PromptLogprobs to proto InputLogProbs.

        Every prompt position gets an entry; positions without a logprob
        (notably the first token) produce an InputTokenLogProb with no value.

        Args:
            prompt_logprobs: vLLM prompt logprobs (list of dict[int, Logprob] | None)
            prompt_token_ids: Prompt token IDs
            num_top_logprobs: Number of top logprobs to include

        Returns:
            InputLogProbs proto, or None when there is nothing to report
        """
        if not prompt_logprobs:
            return None

        result = vllm_engine_pb2.InputLogProbs()
        for tid, entry in zip(prompt_token_ids, prompt_logprobs):
            slot = vllm_engine_pb2.InputTokenLogProb()
            # The first prompt token has no logprob entry (None).
            if entry is not None and tid in entry:
                slot.value = entry[tid].logprob

            result.token_ids.append(tid)
            result.token_logprobs.append(slot)
            result.top_logprobs.append(
                VllmEngineServicer._build_top_logprobs(entry, num_top_logprobs)
            )

        return result if result.token_ids else None

    @staticmethod
    def _tokenization_kwargs_from_proto(
        params: vllm_engine_pb2.SamplingParams,
    ) -> dict[str, int] | None:
        """Extract tokenization kwargs (currently only prompt truncation)."""
        if not params.HasField("truncate_prompt_tokens"):
            return None
        return {"truncate_prompt_tokens": params.truncate_prompt_tokens}

    @staticmethod
    def _chunk_response(
        output: RequestOutput,
        completion: "CompletionOutput | None" = None,
        num_logprobs: int | None = None,
        num_prompt_logprobs: int | None = None,
        is_first_chunk: bool = False,
    ) -> vllm_engine_pb2.GenerateResponse:
        """
        Build one streaming chunk from a vLLM output.

        With output_kind=DELTA the completion carries only newly generated
        tokens, so this chunk's logprobs and token counts are deltas; the
        Rust client accumulates them.

        Args:
            output: vLLM RequestOutput (with delta tokens when output_kind=DELTA)
            completion: Specific CompletionOutput to use (for n>1 support).
                       If None, uses output.outputs[0] for backwards compatibility.
            num_logprobs: Number of top logprobs for output tokens
            num_prompt_logprobs: Number of top logprobs for prompt tokens
            is_first_chunk: Whether this is the first chunk for this index
                           (input_logprobs are attached only then)

        Returns:
            GenerateResponse with the chunk field set
        """
        if completion is None:
            completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # Nothing to report — emit an empty chunk.
            empty = vllm_engine_pb2.GenerateStreamChunk(
                token_ids=[],
                prompt_tokens=0,
                completion_tokens=0,
                cached_tokens=0,
                index=0,
            )
            return vllm_engine_pb2.GenerateResponse(chunk=empty)

        # Delta logprobs for just this chunk's new tokens.
        delta_logprobs = VllmEngineServicer._build_output_logprobs(
            completion.logprobs, completion.token_ids, num_logprobs
        )

        # Prompt logprobs are sent once per index, on its first chunk.
        prompt_lp = (
            VllmEngineServicer._build_input_logprobs(
                output.prompt_logprobs,
                output.prompt_token_ids,
                num_prompt_logprobs,
            )
            if is_first_chunk
            else None
        )

        n_prompt = len(output.prompt_token_ids) if output.prompt_token_ids else 0
        chunk = vllm_engine_pb2.GenerateStreamChunk(
            token_ids=completion.token_ids,
            prompt_tokens=n_prompt,
            completion_tokens=len(completion.token_ids),  # delta count
            cached_tokens=output.num_cached_tokens,
            output_logprobs=delta_logprobs,
            input_logprobs=prompt_lp,
            index=completion.index,
        )
        return vllm_engine_pb2.GenerateResponse(chunk=chunk)

    @staticmethod
    def _complete_response(
        output: RequestOutput,
        completion: "CompletionOutput | None" = None,
        num_logprobs: int | None = None,
        num_prompt_logprobs: int | None = None,
    ) -> vllm_engine_pb2.GenerateResponse:
        """
        Build the final completion message from a vLLM output.

        Non-streaming (FINAL_ONLY): the completion holds all tokens/logprobs.
        Streaming (DELTA): it holds only the last delta; the Rust client
        accumulates from earlier chunks.

        Args:
            output: vLLM RequestOutput (finished=True)
            completion: Specific CompletionOutput to use (for n>1 support).
                       If None, uses output.outputs[0] for backwards compatibility.
            num_logprobs: Number of top logprobs for output tokens
            num_prompt_logprobs: Number of top logprobs for prompt tokens

        Returns:
            GenerateResponse with the complete field set
        """
        if completion is None:
            completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # Nothing was produced — emit an empty error completion.
            return vllm_engine_pb2.GenerateResponse(
                complete=vllm_engine_pb2.GenerateComplete(
                    output_ids=[],
                    finish_reason="error",
                    prompt_tokens=0,
                    completion_tokens=0,
                    cached_tokens=0,
                    index=0,
                ),
            )

        # All logprobs for non-streaming; only the last delta when streaming.
        out_lp = VllmEngineServicer._build_output_logprobs(
            completion.logprobs, completion.token_ids, num_logprobs
        )
        in_lp = VllmEngineServicer._build_input_logprobs(
            output.prompt_logprobs,
            output.prompt_token_ids,
            num_prompt_logprobs,
        )

        # kv_transfer_params are only present for Mooncake PD transfers.
        kv_params = None
        if output.kv_transfer_params:
            kv_params = vllm_engine_pb2.KvTransferParams(
                remote_host=output.kv_transfer_params.get("remote_host", ""),
                remote_port=output.kv_transfer_params.get("remote_port", 0),
            )

        n_prompt = len(output.prompt_token_ids) if output.prompt_token_ids else 0
        complete = vllm_engine_pb2.GenerateComplete(
            output_ids=completion.token_ids,
            finish_reason=completion.finish_reason or "stop",
            prompt_tokens=n_prompt,
            completion_tokens=len(completion.token_ids),
            cached_tokens=output.num_cached_tokens,
            output_logprobs=out_lp,
            input_logprobs=in_lp,
            index=completion.index,
            kv_transfer_params=kv_params,
        )
        return vllm_engine_pb2.GenerateResponse(complete=complete)

Abort async

Abort(
    request: AbortRequest, context: ServicerContext
) -> AbortResponse

Out-of-band abort requests.

Parameters:

Name Type Description Default
request AbortRequest

The AbortRequest protobuf

required
context ServicerContext

gRPC context

required

Returns:

Type Description
AbortResponse

AbortResponse protobuf

Source code in vllm/entrypoints/grpc_server.py
async def Abort(
    self,
    request: vllm_engine_pb2.AbortRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.AbortResponse:
    """
    Cancel in-flight requests out-of-band.

    Args:
        request: The AbortRequest protobuf
        context: gRPC context

    Returns:
        AbortResponse protobuf (empty acknowledgement)
    """
    ids_to_abort = request.request_ids
    logger.info("Abort requests: %s", ids_to_abort)
    # Forward the whole batch of IDs to the engine in one call.
    await self.async_llm.abort(ids_to_abort)
    return vllm_engine_pb2.AbortResponse()

Embed async

Embed(
    request: EmbedRequest, context: ServicerContext
) -> EmbedResponse

Handle embedding requests.

TODO: Implement in Phase 4

Parameters:

Name Type Description Default
request EmbedRequest

The EmbedRequest protobuf

required
context ServicerContext

gRPC context

required

Returns:

Type Description
EmbedResponse

EmbedResponse protobuf

Source code in vllm/entrypoints/grpc_server.py
async def Embed(
    self,
    request: vllm_engine_pb2.EmbedRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.EmbedResponse:
    """
    Handle embedding requests (not yet implemented).

    TODO: Implement in Phase 4

    Args:
        request: The EmbedRequest protobuf
        context: gRPC context

    Returns:
        EmbedResponse protobuf (never returned today; the RPC is aborted)
    """
    logger.warning("Embed RPC not yet implemented")
    # abort() raises, terminating the RPC with UNIMPLEMENTED.
    await context.abort(grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented")

Generate async

Generate(
    request: GenerateRequest, context: ServicerContext
) -> AsyncGenerator[GenerateResponse, None]

Handle streaming generation requests.

Supports n>1 by sending separate chunk/complete messages for each output index. When streaming with n>1, chunks for different indices are interleaved.

Parameters:

Name Type Description Default
request GenerateRequest

The GenerateRequest protobuf

required
context ServicerContext

gRPC context

required

Yields:

Type Description
AsyncGenerator[GenerateResponse, None]

GenerateResponse protobuf messages (streaming)

Source code in vllm/entrypoints/grpc_server.py
async def Generate(
    self,
    request: vllm_engine_pb2.GenerateRequest,
    context: grpc.aio.ServicerContext,
) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
    """
    Handle streaming generation requests.

    Supports n>1 by sending separate chunk/complete messages for each output index.
    When streaming with n>1, chunks for different indices are interleaved.

    On invalid input the RPC is aborted with INVALID_ARGUMENT; any other
    failure aborts with INTERNAL (``context.abort`` raises, ending the stream).

    Args:
        request: The GenerateRequest protobuf
        context: gRPC context

    Yields:
        GenerateResponse protobuf messages (streaming)
    """
    request_id = request.request_id
    # Which member of the `input` oneof the client set ("tokenized" or text).
    input_type = request.WhichOneof("input")
    # Preprocessed multimodal requires both the mm_inputs message and its
    # pixel_values tensor to be present.
    has_preprocessed_mm = request.HasField(
        "mm_inputs"
    ) and request.mm_inputs.HasField("pixel_values")
    logger.info(
        "Generate request %s: input_type=%s, stream=%s, preprocessed_mm=%s",
        request_id,
        input_type,
        request.stream,
        has_preprocessed_mm,
    )

    try:
        if has_preprocessed_mm and input_type == "tokenized":
            # Preprocessed multimodal from Rust router.
            # Token IDs already have expanded placeholders; tensors are
            # ready for the model. Bypass the renderer entirely.
            prompt = self._build_preprocessed_mm_inputs(
                request.tokenized, request.mm_inputs
            )
            prompt["arrival_time"] = time.time()
        elif input_type == "tokenized":
            # Pre-tokenized text: client supplies token IDs directly.
            prompt: TokensPrompt = {
                "prompt_token_ids": list(request.tokenized.input_ids)
            }
            if request.tokenized.original_text:
                prompt["prompt"] = request.tokenized.original_text
            renderer = self.async_llm.input_processor.input_preprocessor.renderer
            prompt = renderer.process_for_engine(prompt, arrival_time=time.time())
        else:
            # Raw text: let vLLM's preprocessor tokenize it.
            prompt: TextPrompt = {"prompt": request.text}
            prompt = self.async_llm.input_processor.input_preprocessor.preprocess(
                prompt
            )

        # Validate kv_transfer_params if present
        if request.HasField("kv_transfer_params"):
            remote_host = request.kv_transfer_params.remote_host
            remote_port = request.kv_transfer_params.remote_port
            if not remote_host or remote_port == 0:
                # abort() raises, so execution never continues past here.
                await context.abort(
                    grpc.StatusCode.INVALID_ARGUMENT,
                    "Invalid kv_transfer_params: "
                    "remote_host and remote_port must be set.",
                )
            logger.info(
                "Request %s: kv_transfer_params={remote_host=%s, remote_port=%d}",
                request_id,
                remote_host,
                remote_port,
            )

        # Build sampling params with detokenize=False
        sampling_params = self._sampling_params_from_proto(
            request.sampling_params,
            stream=request.stream,
            kv_transfer_params=request.kv_transfer_params
            if request.HasField("kv_transfer_params")
            else None,
        )
        tokenization_kwargs = self._tokenization_kwargs_from_proto(
            request.sampling_params
        )

        # Extract logprobs configuration
        num_logprobs = sampling_params.logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs

        # Track which indices have sent their first chunk
        # (input logprobs are attached only to the first chunk per index).
        seen_indices: set[int] = set()

        async for output in self.async_llm.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            tokenization_kwargs=tokenization_kwargs,
        ):
            # For streaming, send chunks for EACH completion output (n outputs)
            if request.stream:
                for completion in output.outputs:
                    idx = completion.index
                    is_first = idx not in seen_indices
                    seen_indices.add(idx)

                    # Send chunk with delta data (Rust accumulates for vLLM)
                    yield self._chunk_response(
                        output,
                        completion=completion,
                        num_logprobs=num_logprobs,
                        num_prompt_logprobs=num_prompt_logprobs,
                        is_first_chunk=is_first,
                    )

                    # Send Complete when sequence finishes (n>1 support)
                    if completion.finish_reason:
                        yield self._complete_response(
                            output,
                            completion=completion,
                            num_logprobs=num_logprobs,
                            num_prompt_logprobs=num_prompt_logprobs,
                        )

            # For non-streaming, send complete response when finished
            if output.finished and not request.stream:
                for completion in output.outputs:
                    yield self._complete_response(
                        output,
                        completion=completion,
                        num_logprobs=num_logprobs,
                        num_prompt_logprobs=num_prompt_logprobs,
                    )

    except ValueError as e:
        # Invalid request error (equiv to 400).
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
    except Exception as e:
        logger.exception("Error in Generate for request %s", request_id)
        await context.abort(grpc.StatusCode.INTERNAL, str(e))

GetModelInfo async

GetModelInfo(
    request: GetModelInfoRequest, context: ServicerContext
) -> GetModelInfoResponse

Handle model info requests.

Parameters:

Name Type Description Default
request GetModelInfoRequest

The GetModelInfoRequest protobuf

required
context ServicerContext

gRPC context

required

Returns:

Type Description
GetModelInfoResponse

GetModelInfoResponse protobuf

Source code in vllm/entrypoints/grpc_server.py
async def GetModelInfo(
    self,
    request: vllm_engine_pb2.GetModelInfoRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetModelInfoResponse:
    """
    Report static model metadata.

    Args:
        request: The GetModelInfoRequest protobuf
        context: gRPC context

    Returns:
        GetModelInfoResponse protobuf
    """
    cfg = self.async_llm.model_config
    response = vllm_engine_pb2.GetModelInfoResponse(
        model_path=cfg.model,
        is_generation=(cfg.runner_type == "generate"),
        max_context_length=cfg.max_model_len,
        vocab_size=cfg.get_vocab_size(),
        supports_vision=cfg.is_multimodal_model,
    )
    return response

GetServerInfo async

GetServerInfo(
    request: GetServerInfoRequest, context: ServicerContext
) -> GetServerInfoResponse

Handle server info requests.

Parameters:

Name Type Description Default
request GetServerInfoRequest

The GetServerInfoRequest protobuf

required
context ServicerContext

gRPC context

required

Returns:

Type Description
GetServerInfoResponse

GetServerInfoResponse protobuf

Source code in vllm/entrypoints/grpc_server.py
async def GetServerInfo(
    self,
    request: vllm_engine_pb2.GetServerInfoRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetServerInfoResponse:
    """
    Report the server's current runtime state.

    Args:
        request: The GetServerInfoRequest protobuf
        context: gRPC context

    Returns:
        GetServerInfoResponse protobuf
    """
    active = self.async_llm.output_processor.get_num_unfinished_requests()

    # KV transfer configuration is optional; fall back to empty strings.
    kv_cfg = self.async_llm.vllm_config.kv_transfer_config
    if kv_cfg is not None:
        connector = kv_cfg.kv_connector or ""
        role = kv_cfg.kv_role or ""
    else:
        connector = ""
        role = ""

    return vllm_engine_pb2.GetServerInfoResponse(
        active_requests=active,
        is_paused=False,  # TODO
        last_receive_timestamp=time.time(),  # TODO looks wrong?
        uptime_seconds=time.time() - self.start_time,
        server_type="vllm-grpc",
        kv_connector=connector,
        kv_role=role,
    )

HealthCheck async

HealthCheck(
    request: HealthCheckRequest, context: ServicerContext
) -> HealthCheckResponse

Handle health check requests.

Parameters:

Name Type Description Default
request HealthCheckRequest

The HealthCheckRequest protobuf

required
context ServicerContext

gRPC context

required

Returns:

Type Description
HealthCheckResponse

HealthCheckResponse protobuf

Source code in vllm/entrypoints/grpc_server.py
async def HealthCheck(
    self,
    request: vllm_engine_pb2.HealthCheckRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.HealthCheckResponse:
    """
    Handle health check requests.

    The engine is considered healthy as long as its background loop has
    not errored out (``async_llm.errored``).

    Args:
        request: The HealthCheckRequest protobuf
        context: gRPC context

    Returns:
        HealthCheckResponse protobuf with the `healthy` flag and a
        human-readable message
    """
    is_healthy = not self.async_llm.errored
    # Fix: the healthy-path message previously read "Health" (typo).
    message = "Healthy" if is_healthy else "Engine is not alive"

    logger.info("HealthCheck request: healthy=%s, message=%s", is_healthy, message)

    return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message)

__init__

__init__(async_llm: AsyncLLM, start_time: float)

Initialize the servicer.

Parameters:

Name Type Description Default
async_llm AsyncLLM

The AsyncLLM instance

required
start_time float

The server start time, in seconds since epoch

required
Source code in vllm/entrypoints/grpc_server.py
def __init__(self, async_llm: AsyncLLM, start_time: float):
    """
    Wire the servicer up to an already-constructed engine.

    Args:
        async_llm: The AsyncLLM instance
        start_time: The server start time, in seconds since epoch
    """
    self.start_time = start_time
    self.async_llm = async_llm
    logger.info("VllmEngineServicer initialized")

_build_input_logprobs staticmethod

_build_input_logprobs(
    prompt_logprobs: PromptLogprobs | None,
    prompt_token_ids: list[int],
    num_top_logprobs: int | None,
) -> InputLogProbs | None

Convert vLLM PromptLogprobs to proto InputLogProbs.

Parameters:

Name Type Description Default
prompt_logprobs PromptLogprobs | None

vLLM prompt logprobs (list of dict[int, Logprob] | None)

required
prompt_token_ids list[int]

Prompt token IDs

required
num_top_logprobs int | None

Number of top logprobs to include

required

Returns:

Type Description
InputLogProbs | None

InputLogProbs proto or None

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _build_input_logprobs(
    prompt_logprobs: PromptLogprobs | None,
    prompt_token_ids: list[int],
    num_top_logprobs: int | None,
) -> vllm_engine_pb2.InputLogProbs | None:
    """
    Convert vLLM PromptLogprobs to proto InputLogProbs.

    Args:
        prompt_logprobs: vLLM prompt logprobs (list of dict[int, Logprob] | None)
        prompt_token_ids: Prompt token IDs
        num_top_logprobs: Number of top logprobs to include

    Returns:
        InputLogProbs proto, or None when there is nothing to report
    """
    if not prompt_logprobs:
        return None

    result = vllm_engine_pb2.InputLogProbs()

    for tok_id, entry in zip(prompt_token_ids, prompt_logprobs):
        per_token = vllm_engine_pb2.InputTokenLogProb()
        # The very first prompt token has no logprob (its entry is None),
        # in which case `value` stays unset.
        if entry is not None and tok_id in entry:
            per_token.value = entry[tok_id].logprob

        result.token_ids.append(tok_id)
        result.token_logprobs.append(per_token)
        result.top_logprobs.append(
            VllmEngineServicer._build_top_logprobs(entry, num_top_logprobs)
        )

    return result if result.token_ids else None

_build_output_logprobs staticmethod

_build_output_logprobs(
    logprobs: SampleLogprobs | None,
    token_ids: list[int],
    num_top_logprobs: int | None,
) -> OutputLogProbs | None

Convert vLLM SampleLogprobs to proto OutputLogProbs.

Parameters:

Name Type Description Default
logprobs SampleLogprobs | None

vLLM logprobs (list of dict[int, Logprob])

required
token_ids list[int]

Token IDs for each position

required
num_top_logprobs int | None

Number of top logprobs to include

required

Returns:

Type Description
OutputLogProbs | None

OutputLogProbs proto or None

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _build_output_logprobs(
    logprobs: SampleLogprobs | None,
    token_ids: list[int],
    num_top_logprobs: int | None,
) -> vllm_engine_pb2.OutputLogProbs | None:
    """
    Convert vLLM SampleLogprobs to proto OutputLogProbs.

    Args:
        logprobs: vLLM logprobs (list of dict[int, Logprob])
        token_ids: Token IDs for each position
        num_top_logprobs: Number of top logprobs to include

    Returns:
        OutputLogProbs proto, or None when there is nothing to report
    """
    if not logprobs:
        return None

    result = vllm_engine_pb2.OutputLogProbs()

    for tok_id, entry in zip(token_ids, logprobs):
        sampled = entry.get(tok_id)
        if not sampled:
            # No logprob recorded for the sampled token at this position.
            continue
        result.token_logprobs.append(sampled.logprob)
        result.token_ids.append(tok_id)
        if num_top_logprobs:
            result.top_logprobs.append(
                VllmEngineServicer._build_top_logprobs(entry, num_top_logprobs)
            )

    return result if result.token_ids else None

_build_preprocessed_mm_inputs

_build_preprocessed_mm_inputs(
    tokenized: TokenizedInput, mm_proto: MultimodalInputs
) -> MultiModalInputs

Build vLLM MultiModalInputs from preprocessed proto data.

Bypasses HF processor entirely — pixel values and model-specific tensors were already computed by the Rust router. Field layouts (batched / flat / shared) are also determined by the router via batched_keys and flat_keys proto fields.

Source code in vllm/entrypoints/grpc_server.py
def _build_preprocessed_mm_inputs(
    self,
    tokenized: vllm_engine_pb2.TokenizedInput,
    mm_proto: vllm_engine_pb2.MultimodalInputs,
) -> VllmMultiModalInputs:
    """Build vLLM MultiModalInputs from preprocessed proto data.

    Bypasses HF processor entirely — pixel values and model-specific
    tensors were already computed by the Rust router.  Field layouts
    (batched / flat / shared) are also determined by the router via
    ``batched_keys`` and ``flat_keys`` proto fields.

    Args:
        tokenized: Token IDs (placeholders already expanded) plus optional
            original text.
        mm_proto: Serialized tensors, placeholder ranges, hashes and
            field-layout hints produced by the router.

    Returns:
        Multimodal inputs dict ready to hand directly to the engine.
    """
    prompt_token_ids = list(tokenized.input_ids)
    # NOTE(review): assumes one placeholder range per image — confirm
    # against the router's serialization.
    num_images = len(mm_proto.mm_placeholders)

    # Deserialize all tensors from proto
    hf_dict: dict[str, torch.Tensor] = {
        "pixel_values": _tensor_from_proto(mm_proto.pixel_values),
    }
    for key, td in mm_proto.model_specific_tensors.items():
        hf_dict[key] = _tensor_from_proto(td)

    # Cast floating-point tensors to model dtype (e.g. bfloat16).
    # This mirrors _postprocess_output in multimodal/processing/context.py
    # which is skipped when bypassing the HF processor.
    model_dtype = self.async_llm.model_config.dtype
    for key in hf_dict:
        if hf_dict[key].is_floating_point():
            hf_dict[key] = hf_dict[key].to(dtype=model_dtype)

    cpu_keys = set(mm_proto.keep_on_cpu_keys)

    # Field configs are fully determined by the Rust router.
    batched = set(mm_proto.batched_keys)
    flat = dict(mm_proto.flat_keys)
    fields_config: dict[str, MultiModalFieldConfig] = {}
    for key in hf_dict:
        on_cpu = key in cpu_keys
        if key in batched:
            fields_config[key] = MultiModalFieldConfig.batched(
                "image", keep_on_cpu=on_cpu
            )
        elif key in flat:
            # flat_keys maps this tensor's key to the key of ANOTHER tensor
            # whose flattened values give the per-item slice sizes.
            sizes = hf_dict[flat[key]].flatten().to(torch.int64)
            fields_config[key] = MultiModalFieldConfig.flat_from_sizes(
                "image", sizes, keep_on_cpu=on_cpu
            )
        else:
            # Anything not listed is shared across all images.
            fields_config[key] = MultiModalFieldConfig.shared("image", num_images)

    batch_feature = BatchFeature(hf_dict, tensor_type="pt")
    mm_kwargs = MultiModalKwargsItems.from_hf_inputs(batch_feature, fields_config)

    # Build mm_hashes: dict[str, list[str]]
    mm_hashes: dict[str, list[str]] = {}
    if mm_proto.mm_hashes:
        mm_hashes["image"] = list(mm_proto.mm_hashes)

    # Build mm_placeholders: dict[str, list[PlaceholderRange]]
    # When structural tokens (e.g. <|image_start|>, separators) are present
    # in the placeholder range, we must set is_embed so vLLM only scatters
    # encoder embeddings into patch-token positions (im_token_id).
    mm_placeholders: dict[str, list[PlaceholderRange]] = {}
    if mm_proto.mm_placeholders:
        im_token_id = (
            mm_proto.im_token_id if mm_proto.HasField("im_token_id") else None
        )
        placeholders = []
        for p in mm_proto.mm_placeholders:
            is_embed = None
            if im_token_id is not None:
                # Mask marking which positions inside the range are actual
                # patch tokens (True) vs structural tokens (False).
                token_slice = prompt_token_ids[p.offset : p.offset + p.length]
                mask = [t == im_token_id for t in token_slice]
                # Only set is_embed when there are non-embed positions,
                # otherwise None means "all positions are embeds" which is
                # both correct and avoids unnecessary overhead.
                if not all(mask):
                    is_embed = torch.tensor(mask, dtype=torch.bool)
            placeholders.append(
                PlaceholderRange(
                    offset=p.offset, length=p.length, is_embed=is_embed
                )
            )
        mm_placeholders["image"] = placeholders

    return mm_inputs(
        prompt_token_ids=prompt_token_ids,
        mm_kwargs=mm_kwargs,
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
        prompt=tokenized.original_text or None,
    )

_build_top_logprobs staticmethod

_build_top_logprobs(
    logprob_entry: dict, num_top_logprobs: int | None
) -> TopLogProbs

Build TopLogProbs proto from a logprob entry dict.

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _build_top_logprobs(
    logprob_entry: dict,
    num_top_logprobs: int | None,
) -> vllm_engine_pb2.TopLogProbs:
    """Build TopLogProbs proto from a logprob entry dict.

    Args:
        logprob_entry: Mapping of token id -> Logprob for one position
            (may be None/empty, e.g. for the first prompt token).
        num_top_logprobs: Number of top entries to include; falsy means
            an empty proto is returned.

    Returns:
        TopLogProbs proto with the top-k entries sorted by descending
        logprob (empty when nothing is requested or available).
    """
    top = vllm_engine_pb2.TopLogProbs()
    if num_top_logprobs and logprob_entry:
        sorted_entries = sorted(
            logprob_entry.items(),
            key=lambda x: x[1].logprob,
            reverse=True,
        )
        # BUG FIX: `islice` lives in itertools, not functools — the previous
        # `functools.islice(...)` raised AttributeError whenever top logprobs
        # were requested. A plain list slice performs the same truncation.
        for tid, lp in sorted_entries[:num_top_logprobs]:
            top.token_ids.append(tid)
            top.values.append(lp.logprob)
    return top

_chunk_response staticmethod

_chunk_response(
    output: RequestOutput,
    completion: CompletionOutput | None = None,
    num_logprobs: int | None = None,
    num_prompt_logprobs: int | None = None,
    is_first_chunk: bool = False,
) -> GenerateResponse

Build a streaming chunk response from vLLM output. When output_kind=DELTA, vLLM returns only new tokens automatically.

Note: This sends DELTA logprobs (only for new tokens in this chunk). The Rust side is responsible for accumulating if needed.

Parameters:

Name Type Description Default
output RequestOutput

vLLM RequestOutput (with delta tokens when output_kind=DELTA)

required
completion CompletionOutput | None

Specific CompletionOutput to use (for n>1 support). If None, uses output.outputs[0] for backwards compatibility.

None
num_logprobs int | None

Number of top logprobs for output tokens

None
num_prompt_logprobs int | None

Number of top logprobs for prompt tokens

None
is_first_chunk bool

Whether this is the first chunk for this index (include input_logprobs only on first chunk)

False

Returns:

Type Description
GenerateResponse

GenerateResponse with chunk field set

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _chunk_response(
    output: RequestOutput,
    completion: "CompletionOutput | None" = None,
    num_logprobs: int | None = None,
    num_prompt_logprobs: int | None = None,
    is_first_chunk: bool = False,
) -> vllm_engine_pb2.GenerateResponse:
    """
    Build a streaming chunk response from vLLM output.

    With output_kind=DELTA, vLLM hands back only the new tokens, so each
    chunk carries delta data and delta logprobs; the Rust client is
    responsible for any accumulation.

    Args:
        output: vLLM RequestOutput (with delta tokens when output_kind=DELTA)
        completion: Specific CompletionOutput to use (for n>1 support).
                   If None, uses output.outputs[0] for backwards compatibility.
        num_logprobs: Number of top logprobs for output tokens
        num_prompt_logprobs: Number of top logprobs for prompt tokens
        is_first_chunk: Whether this is the first chunk for this index
                       (include input_logprobs only on first chunk)

    Returns:
        GenerateResponse with chunk field set
    """
    # Fall back to the first output when no explicit completion is given.
    if completion is None and output.outputs:
        completion = output.outputs[0]

    if completion is None:
        # No completion data at all: emit an empty chunk.
        empty = vllm_engine_pb2.GenerateStreamChunk(
            token_ids=[],
            prompt_tokens=0,
            completion_tokens=0,
            cached_tokens=0,
            index=0,
        )
        return vllm_engine_pb2.GenerateResponse(chunk=empty)

    # Delta logprobs for just this chunk's tokens (not cumulative).
    out_lp = VllmEngineServicer._build_output_logprobs(
        completion.logprobs, completion.token_ids, num_logprobs
    )

    # Prompt logprobs ride along only on the first chunk per index.
    in_lp = (
        VllmEngineServicer._build_input_logprobs(
            output.prompt_logprobs,
            output.prompt_token_ids,
            num_prompt_logprobs,
        )
        if is_first_chunk
        else None
    )

    prompt_len = len(output.prompt_token_ids) if output.prompt_token_ids else 0

    # completion.token_ids holds only new tokens under DELTA mode; the
    # client accumulates counts across chunks.
    chunk = vllm_engine_pb2.GenerateStreamChunk(
        token_ids=completion.token_ids,
        prompt_tokens=prompt_len,
        completion_tokens=len(completion.token_ids),  # Delta count
        cached_tokens=output.num_cached_tokens,
        output_logprobs=out_lp,
        input_logprobs=in_lp,
        index=completion.index,
    )
    return vllm_engine_pb2.GenerateResponse(chunk=chunk)

_complete_response staticmethod

_complete_response(
    output: RequestOutput,
    completion: CompletionOutput | None = None,
    num_logprobs: int | None = None,
    num_prompt_logprobs: int | None = None,
) -> GenerateResponse

Build a final completion response from vLLM output.

For non-streaming (FINAL_ONLY): completion has all tokens and logprobs. For streaming (DELTA): completion has last delta; Rust accumulates.

Parameters:

Name Type Description Default
output RequestOutput

vLLM RequestOutput (finished=True)

required
completion CompletionOutput | None

Specific CompletionOutput to use (for n>1 support). If None, uses output.outputs[0] for backwards compatibility.

None
num_logprobs int | None

Number of top logprobs for output tokens

None
num_prompt_logprobs int | None

Number of top logprobs for prompt tokens

None

Returns:

Type Description
GenerateResponse

GenerateResponse with complete field set

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _complete_response(
    output: RequestOutput,
    completion: "CompletionOutput | None" = None,
    num_logprobs: int | None = None,
    num_prompt_logprobs: int | None = None,
) -> vllm_engine_pb2.GenerateResponse:
    """
    Build a final completion response from vLLM output.

    For non-streaming (FINAL_ONLY): completion has all tokens and logprobs.
    For streaming (DELTA): completion has last delta; Rust accumulates.

    Args:
        output: vLLM RequestOutput (finished=True)
        completion: Specific CompletionOutput to use (for n>1 support).
                   If None, uses output.outputs[0] for backwards compatibility.
        num_logprobs: Number of top logprobs for output tokens
        num_prompt_logprobs: Number of top logprobs for prompt tokens

    Returns:
        GenerateResponse with complete field set
    """
    # Fall back to the first output when no explicit completion is given.
    if completion is None and output.outputs:
        completion = output.outputs[0]

    if completion is None:
        # Nothing was generated at all; emit an error-shaped completion.
        return vllm_engine_pb2.GenerateResponse(
            complete=vllm_engine_pb2.GenerateComplete(
                output_ids=[],
                finish_reason="error",
                prompt_tokens=0,
                completion_tokens=0,
                cached_tokens=0,
                index=0,
            ),
        )

    # Output logprobs: the full set when FINAL_ONLY; only the last delta
    # when streaming (the Rust side accumulates from chunks).
    out_lp = VllmEngineServicer._build_output_logprobs(
        completion.logprobs, completion.token_ids, num_logprobs
    )
    in_lp = VllmEngineServicer._build_input_logprobs(
        output.prompt_logprobs,
        output.prompt_token_ids,
        num_prompt_logprobs,
    )

    # Propagate KV transfer metadata when present (Mooncake PD).
    kv_params = None
    if output.kv_transfer_params:
        kv_params = vllm_engine_pb2.KvTransferParams(
            remote_host=output.kv_transfer_params.get("remote_host", ""),
            remote_port=output.kv_transfer_params.get("remote_port", 0),
        )

    prompt_len = len(output.prompt_token_ids) if output.prompt_token_ids else 0

    # completion.token_ids is the full sequence under FINAL_ONLY, or the
    # last delta under DELTA; the client accumulates counts when streaming.
    return vllm_engine_pb2.GenerateResponse(
        complete=vllm_engine_pb2.GenerateComplete(
            output_ids=completion.token_ids,
            finish_reason=completion.finish_reason or "stop",
            prompt_tokens=prompt_len,
            completion_tokens=len(completion.token_ids),
            cached_tokens=output.num_cached_tokens,
            output_logprobs=out_lp,
            input_logprobs=in_lp,
            index=completion.index,
            kv_transfer_params=kv_params,
        ),
    )

_sampling_params_from_proto staticmethod

_sampling_params_from_proto(
    params: SamplingParams,
    stream: bool = True,
    kv_transfer_params: KvTransferParams | None = None,
) -> SamplingParams

Convert protobuf SamplingParams to vLLM SamplingParams.

Parameters:

Name Type Description Default
params SamplingParams

Protobuf SamplingParams message

required
stream bool

Whether streaming is enabled

True
kv_transfer_params KvTransferParams | None

KV transfer params proto for Mooncake PD

None

Returns:

Type Description
SamplingParams

vLLM SamplingParams with detokenize=False and structured_outputs

Source code in vllm/entrypoints/grpc_server.py
@staticmethod
def _sampling_params_from_proto(
    params: vllm_engine_pb2.SamplingParams,
    stream: bool = True,
    kv_transfer_params: vllm_engine_pb2.KvTransferParams | None = None,
) -> SamplingParams:
    """
    Convert protobuf SamplingParams to vLLM SamplingParams.

    Proto3 scalar fields default to 0/0.0, so several fields use 0 as an
    "unset" sentinel and are replaced with vLLM's neutral defaults below.

    Args:
        params: Protobuf SamplingParams message
        stream: Whether streaming is enabled
        kv_transfer_params: KV transfer params proto for Mooncake PD

    Returns:
        vLLM SamplingParams with detokenize=False and structured_outputs
    """
    # Build stop sequences
    stop = list(params.stop) if params.stop else None
    stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None

    # Handle structured outputs constraints
    # (constraint is a proto oneof; at most one branch applies).
    structured_outputs = None
    constraint_field = params.WhichOneof("constraint")
    if constraint_field:
        if constraint_field == "json_schema":
            structured_outputs = StructuredOutputsParams(json=params.json_schema)
        elif constraint_field == "regex":
            structured_outputs = StructuredOutputsParams(regex=params.regex)
        elif constraint_field == "grammar":
            structured_outputs = StructuredOutputsParams(grammar=params.grammar)
        elif constraint_field == "structural_tag":
            structured_outputs = StructuredOutputsParams(
                structural_tag=params.structural_tag
            )
        elif constraint_field == "json_object":
            structured_outputs = StructuredOutputsParams(
                json_object=params.json_object
            )
        elif constraint_field == "choice":
            structured_outputs = StructuredOutputsParams(
                choice=list(params.choice.choices)
            )

    # Build extra_args for kv_transfer_params (Mooncake PD)
    extra_args = None
    if kv_transfer_params:
        extra_args = {
            "kv_transfer_params": {
                "remote_host": kv_transfer_params.remote_host,
                "remote_port": kv_transfer_params.remote_port,
            }
        }

    # Create SamplingParams
    # output_kind=DELTA: Return only new tokens in each chunk (for streaming)
    # Optional proto fields (temperature, max_tokens, logprobs, seed, ...)
    # use HasField; plain scalars use 0.0 as the "unset" sentinel.
    return SamplingParams(
        temperature=params.temperature if params.HasField("temperature") else 1.0,
        # top_p=0.0 means "unset" -> 1.0 (no nucleus filtering).
        top_p=params.top_p if params.top_p != 0.0 else 1.0,
        # top_k is forwarded as-is; proto default is 0 — presumably vLLM
        # treats that as "disabled", TODO confirm.
        top_k=params.top_k,
        min_p=params.min_p,
        frequency_penalty=params.frequency_penalty,
        presence_penalty=params.presence_penalty,
        # repetition_penalty=0.0 means "unset" -> neutral 1.0.
        repetition_penalty=params.repetition_penalty
        if params.repetition_penalty != 0.0
        else 1.0,
        max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
        min_tokens=params.min_tokens,
        stop=stop,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=params.skip_special_tokens,
        spaces_between_special_tokens=params.spaces_between_special_tokens,
        ignore_eos=params.ignore_eos,
        n=params.n if params.n > 0 else 1,
        logprobs=params.logprobs if params.HasField("logprobs") else None,
        prompt_logprobs=params.prompt_logprobs
        if params.HasField("prompt_logprobs")
        else None,
        seed=params.seed if params.HasField("seed") else None,
        include_stop_str_in_output=params.include_stop_str_in_output,
        logit_bias=dict(params.logit_bias) if params.logit_bias else None,
        structured_outputs=structured_outputs,
        extra_args=extra_args,
        # detokenize must be True if stop strings are used
        detokenize=bool(stop),
        output_kind=RequestOutputKind.DELTA
        if stream
        else RequestOutputKind.FINAL_ONLY,
    )

_tensor_from_proto

_tensor_from_proto(td: TensorData) -> Tensor

Deserialize a TensorData proto message into a torch.Tensor.

Source code in vllm/entrypoints/grpc_server.py
def _tensor_from_proto(td: vllm_engine_pb2.TensorData) -> torch.Tensor:
    """Deserialize a TensorData proto message into a torch.Tensor."""
    dtype = _PROTO_DTYPE_MAP.get(td.dtype)
    if dtype is None:
        raise ValueError(f"Unsupported proto tensor dtype: {td.dtype!r}")
    # np.frombuffer yields a read-only view over the proto's bytes; copy so
    # the returned tensor owns writable memory independent of the message.
    flat = np.frombuffer(td.data, dtype=dtype)
    return torch.from_numpy(flat.reshape(list(td.shape)).copy())

main

main()

Main entry point.

Source code in vllm/entrypoints/grpc_server.py
def main():
    """Main entry point."""

    def _build_parser() -> FlexibleArgumentParser:
        # Transport-level flags for this server; everything else (model,
        # parallelism, etc.) is contributed by the engine-args plugin below.
        p = FlexibleArgumentParser(
            description="vLLM gRPC Server",
        )
        p.add_argument(
            "--host",
            type=str,
            default="0.0.0.0",
            help="Host to bind gRPC server to",
        )
        p.add_argument(
            "--port",
            type=int,
            default=50051,
            help="Port to bind gRPC server to",
        )
        p.add_argument(
            "--disable-log-stats-server",
            action="store_true",
            help="Disable stats logging on server side",
        )
        # Merge in every vLLM engine CLI option.
        return AsyncEngineArgs.add_cli_args(p)

    cli_args = _build_parser().parse_args()

    # Run the async server under uvloop; any failure exits non-zero.
    try:
        uvloop.run(serve_grpc(cli_args))
    except Exception as exc:
        logger.exception("Server failed: %s", exc)
        sys.exit(1)

serve_grpc async

serve_grpc(args: Namespace)

Main serving function.

Parameters:

| Name | Type        | Description                   | Default  |
|------|-------------|-------------------------------|----------|
| args | `Namespace` | Parsed command line arguments | required |

Source code in vllm/entrypoints/grpc_server.py
async def serve_grpc(args: argparse.Namespace):
    """
    Main serving function.

    Args:
        args: Parsed command line arguments
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    logger.info("vLLM gRPC server args: %s", args)

    launch_time = time.time()

    # Turn CLI args into a full engine configuration.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    # Bring up the async engine that backs every RPC.
    async_llm = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=UsageContext.OPENAI_API_SERVER,
        enable_log_requests=args.enable_log_requests,
        disable_log_stats=args.disable_log_stats_server,
    )

    servicer = VllmEngineServicer(async_llm, launch_time)

    # -1 lifts gRPC's default message-size caps in both directions.
    grpc_server = grpc.aio.server(
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )

    vllm_engine_pb2_grpc.add_VllmEngineServicer_to_server(servicer, grpc_server)

    # Reflection lets grpcurl and similar tools discover the service.
    reflection.enable_server_reflection(
        (
            vllm_engine_pb2.DESCRIPTOR.services_by_name["VllmEngine"].full_name,
            reflection.SERVICE_NAME,
        ),
        grpc_server,
    )

    bind_addr = f"{args.host}:{args.port}"
    grpc_server.add_insecure_port(bind_addr)

    await grpc_server.start()
    logger.info("vLLM gRPC server started on %s", bind_addr)
    logger.info("Server is ready to accept requests")

    # Park until SIGTERM/SIGINT flips the shutdown event.
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def _on_signal():
        logger.info("Received shutdown signal")
        shutdown_event.set()

    for signum in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(signum, _on_signal)

    try:
        await shutdown_event.wait()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        logger.info("Shutting down vLLM gRPC server...")

        # Give in-flight RPCs up to 5 s to drain before forcing closure.
        await grpc_server.stop(grace=5.0)
        logger.info("gRPC server stopped")

        # Tear down the engine after the transport so no RPC races shutdown.
        async_llm.shutdown()
        logger.info("AsyncLLM engine stopped")

        logger.info("Shutdown complete")