[PROTON] Simplify proton viewer APIs for bench_mlp analysis #6452
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,4 +9,4 @@ | |
| profile, | ||
| DEFAULT_PROFILE_NAME, | ||
| ) | ||
| - from . import context | ||
| + from . import context, specs | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| flops_by_device = { | ||
| "CUDA": { | ||
| "80": | ||
| lambda width, **kwargs: 624e12 / (width / 8), | ||
| "89": | ||
| lambda width, **kwargs: (330.3 * 1e12) / (width / 8), # TODO(Keren): Implement fp16 acc-> 660.6 fp8 | ||
| "90": | ||
| lambda width, num_sms, clock_rate, **kwargs: ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / | ||
| (width / 8), | ||
| "100": | ||
| lambda width, num_sms, clock_rate, **kwargs: (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8), | ||
| }, | ||
| "HIP": { | ||
| "gfx90a": lambda width, **kwargs: 383e12 / (width / 8), | ||
| "gfx942": lambda width, **kwargs: 2614.9e12 / (width / 8), | ||
|
Contributor
Author
@antiagainst shall we drop a
Collaborator
Yeah. The spec is unavailable right now. For mi300x and mi325x you can see the spec in https://github.com/triton-lang/triton/pull/6513/files#diff-5e6a8d3fc5ad9de85fc09ead926355cc19497d3056c368d463bf4c626ce68540. Can drop
Contributor
Author
So only "gfx90a" and "gfx942" at this moment?
Contributor
Author
Also cc @ptillet for any additional comments
Collaborator
Yup-- |
||
| }, | ||
| } | ||
|
|
||
|
|
||
| def max_flops(device_type, arch, width, num_sms, clock_rate): | ||
| """ | ||
| Calculate the maximum FLOPS for a given device type and width. | ||
|
|
||
| Args: | ||
| device_type (str): The type of device (e.g., "CUDA", "HIP"). | ||
| arch (str): The architecture of the device (e.g., "80", "90"). | ||
| width (int): The width in bits. | ||
| num_sms (int): The number of streaming multiprocessors. | ||
| clock_rate (float): The clock rate in kHz. | ||
|
|
||
| Returns: | ||
| float: The maximum FLOPS for the given device type and width. | ||
| """ | ||
| if device_type not in flops_by_device: | ||
| raise ValueError(f"Unsupported device type: {device_type}") | ||
|
|
||
| if arch not in flops_by_device[device_type]: | ||
| raise ValueError(f"Unsupported architecture: {arch}") | ||
|
|
||
| flops_func = flops_by_device[device_type][arch] | ||
|
|
||
| return flops_func(width, num_sms=num_sms, clock_rate=clock_rate) | ||
|
|
||
|
|
||
| def max_bps(bus_width, memory_clock_rate): | ||
| """ | ||
| Calculate the maximum bytes per second for a given bus width and memory clock rate. | ||
|
|
||
| Args: | ||
| bus_width (int): The bus width in bits. | ||
| memory_clock_rate (float): The memory clock rate in kHz. | ||
|
|
||
| Returns: | ||
| float: The maximum bytes per second. | ||
| """ | ||
| return 2 * bus_width * memory_clock_rate * 1e3 / 8 | ||
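For reviewers who want to sanity-check the formulas, here is a minimal standalone sketch that copies a trimmed version of the `flops_by_device` table and the two helpers from the diff and evaluates them by hand. It assumes `clock_rate` and `memory_clock_rate` are passed in kHz, which is what the `1755 * 1e3` reference clock and the `* 1e3` factor in `max_bps` imply. The device parameters in the calls (SM count, clocks, bus width) are illustrative inputs chosen for this example, not values taken from the PR.

```python
# Trimmed copy of the table from the diff: only three entries are kept here.
flops_by_device = {
    "CUDA": {
        # Fixed peak: 624e12 ops/s at 8-bit width, scaled down for wider types.
        "80": lambda width, **kwargs: 624e12 / (width / 8),
        # Scaled from a reference point of 114 SMs at 1755 * 1e3 (assumed kHz).
        "90": lambda width, num_sms, clock_rate, **kwargs:
            ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8),
    },
    "HIP": {
        "gfx90a": lambda width, **kwargs: 383e12 / (width / 8),
    },
}


def max_flops(device_type, arch, width, num_sms, clock_rate):
    """Look up the peak-FLOPS formula for (device_type, arch) and evaluate it."""
    if device_type not in flops_by_device:
        raise ValueError(f"Unsupported device type: {device_type}")
    if arch not in flops_by_device[device_type]:
        raise ValueError(f"Unsupported architecture: {arch}")
    flops_func = flops_by_device[device_type][arch]
    return flops_func(width, num_sms=num_sms, clock_rate=clock_rate)


def max_bps(bus_width, memory_clock_rate):
    """Peak bytes/s: 2 transfers per clock * bits per transfer * Hz / 8 bits per byte."""
    return 2 * bus_width * memory_clock_rate * 1e3 / 8


# Arch "80" ignores num_sms and clock_rate; 16-bit width halves the 8-bit peak.
sm80_fp16 = max_flops("CUDA", "80", width=16, num_sms=108, clock_rate=1_410_000)
# -> 624e12 / 2 = 3.12e14

# Arch "90" at exactly the reference point reproduces the 1513e12 constant.
sm90_fp8 = max_flops("CUDA", "90", width=8, num_sms=114, clock_rate=1_755_000)
# -> 1.513e15

# Halving the SM count halves the estimate, since the formula is linear in num_sms.
sm90_fp8_half = max_flops("CUDA", "90", width=8, num_sms=57, clock_rate=1_755_000)

# Illustrative memory inputs: 5120-bit bus, 1_215_000 kHz memory clock.
bandwidth = max_bps(bus_width=5120, memory_clock_rate=1_215_000)
# -> 1.5552e12 bytes/s
```

Two things stand out from the sketch: entries for "80", "89", and the HIP architectures are constants that ignore `num_sms` and `clock_rate`, while "90" and "100" scale with both, so a caller that passes clocks in the wrong unit only sees wrong numbers on the newer architectures.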
This dependency is causing problems internally because it depends on numpy 1. It appears that hatchet is not actually used inside triton_kernels here, so I think it should be possible to remove this dependency.