
Conversation

@wallashss (Collaborator) commented Aug 21, 2025

Description

This PR adds initial support for FP8 on continuous batching.

Changes

  • Included FP8 logic in spyre.py, which needs to set the scale on the weights for CB
  • [UPDATE] Added padding for bs=1
  • [EXTRA] Added decoding of the generated tokens in the scheduler test test_spyre_cb_scheduler_steps.py for easier debugging later

TODOs

  • Set a tolerance for the logprobs difference of quantized models during tests
  • Currently, the test matrix does not include tests/e2e/test_spyre_cb_scheduler_steps.py for FP8; we have to figure out a clean way to include it. Moreover, most of these tests are failing and need more thought before being activated.


👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

docs: improved docs

Signed-off-by: Wallas Santos <[email protected]>
Comment on lines 329 to 332
if self.model.model.dtype in [torch.float8_e4m3fn]:
mask = mask.to(torch.float32)
else:
mask = mask.to(self.model.model.dtype)
@prashantgupta24 (Collaborator) commented Aug 25, 2025:

Since we control the self.model.model.dtype (through the get_dtype function), can we not make sure that self.model.model.dtype is always what we want it to be?

@prashantgupta24 (Collaborator) commented Aug 25, 2025:

It would be good to keep such unintuitive code in one place (unintuitive because I would have expected fp8 to just work here, but it's not supposed to work that way).

@wallashss (Author) commented:

Which place? I didn't understand.

Can I put a TODO there and you fix it in your follow-up PR?

@prashantgupta24 (Collaborator) commented Aug 26, 2025:

I meant the get_dtype function in spyre.py - that's where we get the dtype that eventually gets plugged into self.model.model.dtype. I'm not 100% sure that's the correct way though; worth adding a TODO.

@prashantgupta24 (Collaborator) commented Aug 26, 2025:

So right now we have this for CB:

def get_dtype(self) -> torch.dtype:
    # Get the model's data type
    # This should be:
    # FP32 for un-quantized models on cpu
    # FP16 for un-quantized models on spyre
    # FP8 (float8_e4m3fn) for quantized models
    # (only fp8 quantization is supported)
    if self.model_config.quantization:
        return torch.float8_e4m3fn
    else:
        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
            return torch.float16
        else:
            return torch.float32

If we don't want to work with torch.float8_e4m3fn (since you explicitly check for it in your if condition above), maybe we shouldn't return it inside the if self.model_config.quantization branch and instead return fp32 directly? That way we can get rid of the if condition you wrote above.
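
Something along these lines, as a rough sketch of the idea (untested, just to illustrate the suggestion):

def get_dtype(self) -> torch.dtype:
    if self.model_config.quantization:
        # weights are fp8-quantized, but runtime tensors such as the
        # mask would stay fp32 (sketch of the suggestion above)
        return torch.float32
    if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
        return torch.float16
    return torch.float32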

@wallashss (Author) commented:

Hmmm, I think I understood your point, but changing the behavior of this method would be misleading. I did a quick search and this method is only used to set the dtype of the model in the constructor, so changing it to something else would be wrong. The check I added is only there to prevent the mask tensor from using torch.float8_e4m3fn, a corner case I identified during development; elsewhere it should be fine to keep the real dtype.

Collaborator commented:

Actually yeah the other place where self.dtype is used is in the scale of past_key_value_states - it would be ugly if that needs fp8 and this needs fp32 :(

Collaborator commented:

I thought that for running on spyre we wanted to set the mask as fp16 though, not necessarily fp32?

Maybe we should scrap model.dtype completely, and instead specify the dtypes that we need for specific tensors. e.g.

def get_mask_dtype(self) -> torch.dtype:
    # fp16 when running on spyre, fp32 otherwise
    return (torch.float16
            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST
            else torch.float32)

def get_kv_cache_dtype(self) -> torch.dtype:
    # fp8 kv cache only for fp8-quantized models
    return torch.float8_e4m3fn if self.is_fp8_model else torch.float16

@wallashss (Author) commented:

Sorry, but I missed the last comment from Joe.

I changed to something similar to his suggestion. Running the tests to check if everything's still alright.

BTW, this change breaks the cache.

@@ -430,25 +438,32 @@ def _set_past_key_value_states(self, num_blocks) -> None:
# TODO: This does not work yet. The scale needs to be handled, see:
# https://github.com/foundation-model-stack/aiu-fms-testing-utils/blob/v0.1.0rc3/aiu_fms_testing_utils/utils/paged.py#L306-L319
from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
batch_size = max(2, self.scheduler_config.max_num_seqs)
Collaborator commented:

Is FP8 only supported with batch size >= 2?

@wallashss (Author) commented:

At least that batch size, yes. Currently I found that I had to set bs=2, but maybe it's not necessary for now. I'll revert it.

Collaborator commented:

Then maybe we should update the scheduler_config in Platform.check_and_update_config.

@wallashss (Author) commented Sep 2, 2025:

IMHO, the intention here is to fix/work around the compiler limitation, whereas in platform we change the behavior of the system as a whole. There's something similar there already:

# min value 2 needed for VLLM_DT_MAX_BATCH_SIZE (compiler constraint)
# Note that we can still have decodes of batch size 1 as the env var
# only concerns the max batch size.
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
    max(vllm_config.scheduler_config.max_num_seqs, 2))

But setting this env var does NOT change the original setup of vLLM.

Collaborator commented:

Either way is probably fine, but does everything still work if the user sets --max-num-seqs 1? If not, then I'd prefer overriding it in platform.py.
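
For example, something along these lines in Platform.check_and_update_config (sketch only, untested; the exact config fields are assumed from the snippet above):

if vllm_config.model_config.quantization:
    # the fp8 path seems to need a minimum batch size of 2
    # (compiler constraint), so enforce it here rather than
    # inside the model runner
    vllm_config.scheduler_config.max_num_seqs = max(
        2, vllm_config.scheduler_config.max_num_seqs)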

test: minor improvement on tolerance for quantized models

Signed-off-by: Wallas Santos <[email protected]>
@wallashss changed the title from "[WIP] feat: FP8 support on continuous batch" to "[WIP] feat: FP8 initial support on continuous batching" on Aug 26, 2025
@wallashss changed the title from "[WIP] feat: FP8 initial support on continuous batching" to "feat: FP8 initial support on continuous batching" on Aug 26, 2025
@wallashss wallashss marked this pull request as ready for review August 26, 2025 17:50
@@ -212,7 +221,12 @@ def check_scheduler_inference_steps(
new_token_ids[0])
collected_outputs[output.request_id]["logprobs"].append(
new_logprobs[0][0])
collected_outputs[output.request_id]["tokens"].append(
Collaborator commented:

I'm not sure I understand why the decoding is needed. Is it just to print text for debugging instead of token indices?

@wallashss (Author) commented:

Yes. These tests were failing for me, and just from the logs I couldn't tell why. For example, the generation was diverging due to a different token choice, but I couldn't tell whether the output was gibberish or something reasonable given the difference in logprobs. It is also helpful to get the exact prompt and test it in a different environment to see the response outside of a batch. For instance, the prompts in these tests are slightly different from the chicken soup prompts; they were truncated to an exact token count.

That's why I think the decoding is helpful for debugging these tests.


@@ -24,6 +25,9 @@
ISCLOSE_REL_TOL_CPU = 0.35
ISCLOSE_REL_TOL_SPYRE = 0.35

# TODO: improve this
ISCLOSE_REL_TOL_QUANTIZATION = 0.451
Collaborator commented:

uhhh.... at what point is this tolerance just too loose?

isclose takes an absolute tolerance as well; maybe we should instead start maintaining both tolerances. For example, if the two logprobs we're comparing are -9.1 and -15.2 we should fail, but with -0.000001 and -0.000002 maybe we can pass.

If we add ISCLOSE_ABS_TOL = 0.0001, how tight can we make the relative tolerance again?
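
For reference, this is roughly how the two tolerances combine with math.isclose (the 0.2 rel_tol below is just an illustrative guess, not a proposal):

import math

REL_TOL = 0.2     # hypothetical tighter relative tolerance
ABS_TOL = 0.0001  # proposed ISCLOSE_ABS_TOL

# math.isclose(a, b) passes when
# abs(a - b) <= max(REL_TOL * max(abs(a), abs(b)), ABS_TOL)
print(math.isclose(-9.1, -15.2, rel_tol=REL_TOL, abs_tol=ABS_TOL))
# False: clearly-different logprobs still fail
print(math.isclose(-1e-6, -2e-6, rel_tol=REL_TOL, abs_tol=ABS_TOL))
# True: near-zero logprobs pass thanks to the absolute tolerance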

@joerunde (Collaborator) left a comment:

lgtm! With a slight preference for adding an absolute tolerance so we don't have to relax the relative one so much.
