Skip to content

Swa padding brcm#4

Draft
sudhakarsingh27 wants to merge 31 commits intomain_baseline_for_swa_padding_brcmfrom
swa_padding_brcm
Draft

Swa padding brcm#4
sudhakarsingh27 wants to merge 31 commits intomain_baseline_for_swa_padding_brcmfrom
swa_padding_brcm

Conversation

@sudhakarsingh27
Copy link
Owner

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa and others added 30 commits December 11, 2024 21:30
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* fix ctx.aval_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance
---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Add Jeremy to ci users

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* softmax custom calls with correct encapsulates

* rm jax deprecated features

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
…VIDIA#1358)

* draft implementation of fsdp2 fp8 all gather

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* fix the convergence issue

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* Add warning

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable lint error

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the lint error

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* fix lint error

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint error

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint error

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* add comments

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* add ref

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* add related tests

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
add max_t for KV

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Add util functions to attn_mask_type

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add util functions to qkv_layout

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix THD cross reference code

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove explicit segment_pad, encoding it to segment_ids

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add jax.jit, replace _token with segment_ids, rename bias shape enum

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add comment for make_mask

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Clean code

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add doc strings for the added functions

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove cache for fa deterministic which causes UT failed

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename fixture to avoid conflict

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
add weights_only=False for torch.load

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* WIP: fix get_swa_mask for padding

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix mask type setting

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix the order of checking valid swa and changing mask type

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revamp to get full mask

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…#1356)

* Move test distributed encoder to L0 distributed test suit

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Reese Wang <rewang@nvidia.com>
…sal (NVIDIA#1378)

* add swa (left,0) + padding + brcm support

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* final fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* upgrade to FE 1.9-rc

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants