feat: FP8 initial support on continuous batching #402
base: main
Conversation
👋 Hi! Thank you for contributing to vLLM support on Spyre.
Now you are good to go 🚀
docs: improved docs
if self.model.model.dtype in [torch.float8_e4m3fn]:
    mask = mask.to(torch.float32)
else:
    mask = mask.to(self.model.model.dtype)
Since we control the self.model.model.dtype (through the get_dtype function), can we not make sure that self.model.model.dtype is always what we want it to be?
It would be good to keep such unintuitive code in one place (unintuitive because I would have expected fp8 to work here, but apparently it's not supposed to work that way).
Which place? I didn't understand.
Can I put a TODO there and you fix that in your following PR?
I meant the get_dtype function in spyre.py - that's where we get the dtype which eventually gets plugged into self.model.model.dtype. I'm not 100% sure that is the correct way though, worth adding a TODO.
So right now we have this for CB:
def get_dtype(self) -> torch.dtype:
    # Get the model's data type
    # This should be:
    #   FP32 for un-quantized models on cpu
    #   FP16 for un-quantized models on spyre
    #   FP8 (float8_e4m3fn) for quantized models
    #   (only fp8 quantization is supported)
    if self.model_config.quantization:
        return torch.float8_e4m3fn
    else:
        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
            return torch.float16
        else:
            return torch.float32
If we don't want to work with torch.float8_e4m3fn (since you explicitly check in your if condition above), maybe we don't return that as the value within if self.model_config.quantization and instead return fp32 directly? That way we can get rid of the if condition you wrote above.
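For illustration, here is a minimal sketch of that suggestion, reusing the names from the get_dtype snippet above (hypothetical; not what the PR currently does):

def get_dtype(self) -> torch.dtype:
    # Hypothetical variant: for quantized (fp8) models, report fp32 here so
    # the mask cast above would not need a special case.
    if self.model_config.quantization:
        return torch.float32
    if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
        return torch.float16
    return torch.float32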
Hmmm, I think I understood your point, but I think changing the behavior of this method is misleading. I did a quick search and this method is only used to set the dtype of the model in the constructor; changing it to something else would be wrong. The check I added is only to prevent the mask tensor from using torch.float8_e4m3fn, a corner case that I identified during development; elsewhere it is fine to keep the right dtype.
Actually yeah, the other place where self.dtype is used is in the scale of past_key_value_states - it would be ugly if that needs fp8 and this needs fp32 :(
I thought that for running on spyre we wanted to set the mask as fp16 though, not necessarily fp32? Maybe we should scrap model.dtype completely, and instead specify the dtypes that we need for specific tensors, e.g.:
def get_mask_dtype(self):
    return torch.float16 if spyre else torch.float32

def get_kv_cache_dtype(self):
    return torch.float8_e4m3fn if self.is_fp8_model else torch.float16
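A hypothetical call site for such helpers (assuming the sketched names above exist; kv_shape is a placeholder for the real cache shape) would then avoid the fp8 special case on the mask:

# Hypothetical usage of the helpers sketched above.
mask = mask.to(self.get_mask_dtype())
key_cache = torch.zeros(kv_shape, dtype=self.get_kv_cache_dtype())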
Sorry, but I missed the last comment from Joe.
I changed it to something similar to his suggestion. I'm running the tests to check that everything's still alright.
BTW, this change breaks the cache.
@@ -430,25 +438,32 @@ def _set_past_key_value_states(self, num_blocks) -> None:
    # TODO: This does not work yet. The scale needs to be handled, see:
    # https://github.com/foundation-model-stack/aiu-fms-testing-utils/blob/v0.1.0rc3/aiu_fms_testing_utils/utils/paged.py#L306-L319
    from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
    batch_size = max(2, self.scheduler_config.max_num_seqs)
Is FP8 only supported with batch size <= 2?
At least batch size 2. Currently I found that I had to set bs=2, but maybe it is not necessary for now. I'll revert it.
Then maybe we should update the scheduler_config in Platform.check_and_update_config.
IMHO, the intention here is to fix/work around the compiler limitation, while in platform we change the behavior of the system as a whole. There's something similar there:
# min value 2 needed for VLLM_DT_MAX_BATCH_SIZE (compiler constraint)
# Note that we can still have decodes of batch size 1 as the env var
# only concerns the max batch size.
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
    max(vllm_config.scheduler_config.max_num_seqs, 2))
But setting this env var does NOT change the original setup of vLLM.
Either way is probably fine, but does everything still work if the user sets --max-num-seqs 1? If not, then I'd prefer overriding it in platform.py.
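For reference, a rough sketch of what overriding it in platform.py could look like (hypothetical; check_and_update_config and max_num_seqs follow the vLLM platform/scheduler interfaces, but the exact code may differ):

@classmethod
def check_and_update_config(cls, vllm_config) -> None:
    # Hypothetical sketch: enforce the compiler constraint (batch size >= 2)
    # on the scheduler config itself instead of only on the env var.
    scheduler_config = vllm_config.scheduler_config
    if scheduler_config.max_num_seqs < 2:
        scheduler_config.max_num_seqs = 2
    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(scheduler_config.max_num_seqs)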
test: minor improvement on tolerance for quantized models
@@ -212,7 +221,12 @@ def check_scheduler_inference_steps(
    new_token_ids[0])
collected_outputs[output.request_id]["logprobs"].append(
    new_logprobs[0][0])
collected_outputs[output.request_id]["tokens"].append(
I'm not sure I understand why the decoding is needed. Is it just to print text for debugging instead of token indices?
Yes, these tests were failing for me, and from the logs alone I didn't know why. For example, the generation was diverging due to a different token choice, but I didn't know if it was gibberish or something reasonable because of the difference in logprobs. Also, it is helpful to get the exact prompt and test it in a different environment to see the response outside of a batch. For instance, the prompts of these tests are slightly different from the chicken soup prompts; they were truncated to have an exact count of tokens.
That's why I think the decoding is helpful for debugging these tests.
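As an illustration of the kind of debug decoding described here (hypothetical sketch; the PR's actual test code may differ, and model_name is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)  # model_name: placeholder
# Decode the newly sampled token so the test logs show text instead of ids.
decoded_token = tokenizer.decode([new_token_ids[0]], skip_special_tokens=True)
collected_outputs[output.request_id]["tokens"].append(decoded_token)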
tests/spyre_util.py
Outdated
@@ -24,6 +25,9 @@
ISCLOSE_REL_TOL_CPU = 0.35
ISCLOSE_REL_TOL_SPYRE = 0.35

# TODO: improve this
ISCLOSE_REL_TOL_QUANTIZATION = 0.451
uhhh.... at what point is this tolerance just too loose?
isclose takes an absolute tolerance as well, so maybe we should instead start maintaining both tolerances. For example, if the two logprobs we're comparing are -9.1 and -15.2 then we should fail, but with -0.000001 and -0.000002 maybe we can pass.
If we add ISCLOSE_ABS_TOL = 0.0001, then how tight can we make the relative tolerance again?
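To make that concrete, here is a small self-contained example combining a relative and an absolute tolerance with math.isclose, using the logprob values mentioned above (the 0.1 rel_tol is just an assumed placeholder):

import math

ISCLOSE_REL_TOL_QUANTIZATION = 0.1  # assumed value, tighter than 0.451
ISCLOSE_ABS_TOL = 0.0001

# Far-apart logprobs fail the check...
print(math.isclose(-9.1, -15.2,
                   rel_tol=ISCLOSE_REL_TOL_QUANTIZATION,
                   abs_tol=ISCLOSE_ABS_TOL))  # False
# ...while near-zero logprobs with a large relative gap still pass.
print(math.isclose(-0.000001, -0.000002,
                   rel_tol=ISCLOSE_REL_TOL_QUANTIZATION,
                   abs_tol=ISCLOSE_ABS_TOL))  # True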
lgtm! With a slight preference for adding an absolute tolerance so we don't have to relax the relative one so much.
Description
This PR adds initial support for FP8 on Continuous Batching.
Changes
- Updated test_spyre_cb_scheduler_steps.py for better debugging later
TODOs
- Tests in tests/e2e/test_spyre_cb_scheduler_steps.py are not enabled for FP8; we have to figure out a clean way to include them. Moreover, most of these tests are failing and need more thought before activating them.