Skip to content

Commit d407580

Browse files
authored
Add speedup metric for TorchAO (#6118)
This is my initial attempt to add the speedup metric for TorchAO. This is done by comparing the gain of `autoquant` vs. `noquant`. This is by no means the best approach because it requires custom logic for TorchAO on the dashboard. On the other hand, it's easy to implement, and I think it's better to have the UX done first to gather early feedback from @jerryzh168 and the rest of the ao team. IMO, better approaches would be to either 1) set the speedup metric on the TorchAO side or 2) compute the speedup metric on ClickHouse. Both are more involved and require further design discussion. ### Testing https://torchci-git-fork-huydhn-add-speedup-llm-dashboard-fbopensource.vercel.app/benchmark/llms?repoName=pytorch%2Fao
1 parent 6f108ab commit d407580

File tree

6 files changed

+134
-38
lines changed

6 files changed

+134
-38
lines changed

torchci/components/benchmark/llms/ModelGraphPanel.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import {
1818
TimeSeriesPanelWithData,
1919
} from "components/metrics/panels/TimeSeriesPanel";
2020
import dayjs from "dayjs";
21+
import { computeSpeedup } from "lib/benchmark/aoUtils";
2122
import { useBenchmark } from "lib/benchmark/llmUtils";
2223
import { BranchAndCommit } from "lib/types";
2324

@@ -26,6 +27,7 @@ const GRAPH_ROW_HEIGHT = 245;
2627
export function GraphPanel({
2728
queryParams,
2829
granularity,
30+
repoName,
2931
modelName,
3032
backendName,
3133
dtypeName,
@@ -36,6 +38,7 @@ export function GraphPanel({
3638
}: {
3739
queryParams: { [key: string]: any };
3840
granularity: Granularity;
41+
repoName: string;
3942
modelName: string;
4043
backendName: string;
4144
dtypeName: string;
@@ -65,6 +68,8 @@ export function GraphPanel({
6568
return <></>;
6669
}
6770

71+
const dataWithSpeedup = computeSpeedup(repoName, data);
72+
6873
// Clamp to the nearest granularity (e.g. nearest hour) so that the times will
6974
// align with the data we get from the database
7075
const startTime = dayjs(queryParams["startTime"]).startOf(granularity);
@@ -79,7 +84,7 @@ export function GraphPanel({
7984
const chartData: { [k: string]: any } = {};
8085
const graphSeries: { [k: string]: any } = {};
8186
metricNames.forEach((metric: string) => {
82-
chartData[metric] = data
87+
chartData[metric] = dataWithSpeedup
8388
.filter((record: LLMsBenchmarkData) => {
8489
return (
8590
record.model === modelName &&

torchci/components/benchmark/llms/SummaryPanel.tsx

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,43 +63,51 @@ export function SummaryPanel({
6363
},
6464
renderCell: (params: GridRenderCellParams<any>) => {
6565
const model = params.value.model;
66-
const dtype = params.value.dtype;
67-
const deviceArch = `${params.value.device} (${params.value.arch})`;
6866
if (model === undefined) {
6967
return `Invalid model name`;
7068
}
71-
if (dtype === undefined) {
72-
return `Invalid dtype for model ${model}`;
73-
}
7469

70+
const dtype =
71+
params.value.dtype !== undefined
72+
? `&dtypeName=${encodeURIComponent(params.value.dtype)}`
73+
: "";
7574
const backend =
7675
params.value.backend !== undefined
77-
? `&${encodeURIComponent(params.value.backend)}`
76+
? `&backendName=${encodeURIComponent(params.value.backend)}`
7877
: "";
78+
const deviceArch = `${params.value.device} (${params.value.arch})`;
79+
7980
const url = `/benchmark/llms?startTime=${startTime}&stopTime=${stopTime}&granularity=${granularity}&repoName=${encodeURIComponent(
8081
repoName
8182
)}&modelName=${encodeURIComponent(
8283
model
83-
)}${backend}&dtypeName=${encodeURIComponent(
84-
dtype
85-
)}&deviceName=${encodeURIComponent(deviceArch)}`;
84+
)}${backend}${dtype}&deviceName=${encodeURIComponent(deviceArch)}`;
8685

8786
const isNewModel = params.value.l === undefined ? "(NEW!) " : "";
8887
const isModelStopRunning = params.value.r === undefined ? "❌" : "";
8988

90-
const displayName = model.includes(dtype)
91-
? model
92-
: `${model} (${dtype})`;
9389
return (
9490
<a href={url}>
9591
{isNewModel}
96-
{isModelStopRunning}&nbsp;<b>{displayName}</b>
92+
{isModelStopRunning}&nbsp;<b>{model}</b>
9793
</a>
9894
);
9995
},
10096
},
10197
];
10298

99+
const hasDtype = data.length > 0 && "dtype" in data[0] ? true : false;
100+
if (hasDtype) {
101+
columns.push({
102+
field: "dtype",
103+
headerName: "Quantization",
104+
flex: 1,
105+
renderCell: (params: GridRenderCellParams<any>) => {
106+
return `${params.value}`;
107+
},
108+
});
109+
}
110+
103111
const hasBackend = data.length > 0 && "backend" in data[0] ? true : false;
104112
if (hasBackend) {
105113
columns.push({
@@ -155,18 +163,23 @@ export function SummaryPanel({
155163
return styles.error;
156164
}
157165

158-
// Higher value
159-
if (r - l > RELATIVE_THRESHOLD * l) {
160-
return IS_INCREASING_METRIC_VALUE_GOOD[metric]
161-
? styles.ok
162-
: styles.error;
163-
}
164-
165-
// Lower value
166-
if (l - r > RELATIVE_THRESHOLD * r) {
167-
return IS_INCREASING_METRIC_VALUE_GOOD[metric]
168-
? styles.error
169-
: styles.ok;
166+
if (metric in IS_INCREASING_METRIC_VALUE_GOOD) {
167+
// Higher value
168+
if (r - l > RELATIVE_THRESHOLD * l) {
169+
return IS_INCREASING_METRIC_VALUE_GOOD[metric]
170+
? styles.ok
171+
: styles.error;
172+
}
173+
174+
// Lower value
175+
if (l - r > RELATIVE_THRESHOLD * r) {
176+
return IS_INCREASING_METRIC_VALUE_GOOD[metric]
177+
? styles.error
178+
: styles.ok;
179+
}
180+
} else {
181+
// No data
182+
return "";
170183
}
171184
}
172185

torchci/components/benchmark/llms/common.tsx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { BranchAndCommit } from "lib/types";
22

3-
export const REPOS = ["pytorch/pytorch", "pytorch/executorch"];
3+
export const REPOS = ["pytorch/pytorch", "pytorch/executorch", "pytorch/ao"];
44
export const REPO_TO_BENCHMARKS: { [k: string]: string[] } = {
55
"pytorch/pytorch": ["PyTorch gpt-fast benchmark"],
66
"pytorch/executorch": ["ExecuTorch"],
@@ -23,6 +23,7 @@ export const IS_INCREASING_METRIC_VALUE_GOOD: { [k: string]: boolean } = {
2323
token_per_sec: true,
2424
flops_utilization: true,
2525
"compilation_time(s)": false,
26+
speedup: true,
2627
};
2728
export const METRIC_DISPLAY_SHORT_HEADERS: { [k: string]: string } = {
2829
"memory_bandwidth(GB/s)": "Bandwidth",
@@ -40,9 +41,9 @@ export const RELATIVE_THRESHOLD = 0.05;
4041
export interface LLMsBenchmarkData {
4142
granularity_bucket: string;
4243
model: string;
43-
backend?: string;
44+
backend: string;
4445
workflow_id: number;
45-
job_id?: number;
46+
job_id: number;
4647
metric: string;
4748
actual: number;
4849
target: number;

torchci/lib/benchmark/aoUtils.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
import { LLMsBenchmarkData } from "components/benchmark/llms/common";
12
import { BenchmarkData, CompilerPerformanceData } from "lib/types";
23

4+
export const TORCHAO_REPO = "pytorch/ao";
5+
// TODO (huydhn): Find a better way to abstract this baseline concept, for example,
6+
// this could be dtype noquant for TorchAO, or eager config for inductor
7+
export const TORCHAO_BASELINE = "noquant";
8+
// TODO (huydhn): The following are TorchAO speedup metrics. Check with ao team to
9+
// see if this information could be codified on the benchmark instead of keeping it
10+
// here on the dashboard
11+
const SPEEDUP_METRICS = ["tok/s", "time_ms(avg)", "time_s(avg)", "img_s(avg)"];
12+
313
// TODO (huydhn): Use this function to convert the generic benchmark data to the old
414
// CompilerPerformanceData format. This is needed until the TorchInductor dashboard
515
// is migrated to the new format
@@ -43,3 +53,50 @@ export function convertToCompilerPerformanceData(data: BenchmarkData[]) {
4353

4454
return Object.values(convertData);
4555
}
56+
57+
export function computeSpeedup(repoName: string, data: LLMsBenchmarkData[]) {
58+
if (repoName !== TORCHAO_REPO) {
59+
return data;
60+
}
61+
62+
const baselineMetrics: { [key: string]: LLMsBenchmarkData } = {};
63+
data.forEach((r: LLMsBenchmarkData) => {
64+
if (r.dtype !== TORCHAO_BASELINE) {
65+
return;
66+
}
67+
68+
const k = `${r.workflow_id} ${r.job_id} ${r.model} ${r.metric} ${r.device} ${r.arch}`;
69+
baselineMetrics[k] = r;
70+
});
71+
72+
const withSpeedup: LLMsBenchmarkData[] = [];
73+
data.forEach((r: LLMsBenchmarkData) => {
74+
if (r.dtype === TORCHAO_BASELINE) {
75+
return;
76+
}
77+
78+
if (SPEEDUP_METRICS.includes(r.metric)) {
79+
const k = `${r.workflow_id} ${r.job_id} ${r.model} ${r.metric} ${r.device} ${r.arch}`;
80+
if (
81+
k in baselineMetrics &&
82+
baselineMetrics[k].actual !== 0 &&
83+
r.actual !== 0
84+
) {
85+
const speedup = r.metric.includes("time")
86+
? baselineMetrics[k].actual / r.actual
87+
: r.actual / baselineMetrics[k].actual;
88+
89+
withSpeedup.push({
90+
...r,
91+
metric: "speedup",
92+
actual: Number(speedup.toFixed(4)),
93+
target: 0,
94+
});
95+
}
96+
}
97+
98+
withSpeedup.push(r);
99+
});
100+
101+
return withSpeedup;
102+
}

torchci/lib/benchmark/llmUtils.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ export function combineLeftAndRight(
118118
row["metadata"]["r"] ?? (hasR ? record["r"]["job_id"] : undefined);
119119
}
120120

121+
if (dtype !== "") {
122+
row["dtype"] = dtype;
123+
}
124+
121125
if (backend !== "") {
122126
row["backend"] = backend;
123127
}

torchci/pages/benchmark/llms.tsx

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import CopyLink from "components/CopyLink";
2121
import GranularityPicker from "components/GranularityPicker";
2222
import { Granularity } from "components/metrics/panels/TimeSeriesPanel";
2323
import dayjs from "dayjs";
24+
import { computeSpeedup, TORCHAO_BASELINE } from "lib/benchmark/aoUtils";
2425
import { useBenchmark } from "lib/benchmark/llmUtils";
2526
import { fetcher } from "lib/GeneralUtils";
2627
import { BranchAndCommit } from "lib/types";
@@ -81,22 +82,29 @@ function Report({
8182
);
8283
}
8384

85+
const lDataWithSpeedup = computeSpeedup(repoName, lData);
86+
const rDataWithSpeedup = computeSpeedup(repoName, rData);
87+
88+
if (repoName === "pytorch/ao") {
89+
metricNames = ["speedup", ...metricNames];
90+
}
91+
8492
return (
8593
<div>
8694
<CommitPanel
8795
repoName={repoName}
8896
lBranchAndCommit={{
8997
...rBranchAndCommit,
9098
date:
91-
rData !== undefined && rData.length !== 0
92-
? rData[0].granularity_bucket
99+
rDataWithSpeedup !== undefined && rDataWithSpeedup.length !== 0
100+
? rDataWithSpeedup[0].granularity_bucket
93101
: undefined,
94102
}}
95103
rBranchAndCommit={{
96104
...lBranchAndCommit,
97105
date:
98-
lData !== undefined && lData.length !== 0
99-
? lData[0].granularity_bucket
106+
lDataWithSpeedup !== undefined && lDataWithSpeedup.length !== 0
107+
? lDataWithSpeedup[0].granularity_bucket
100108
: undefined,
101109
}}
102110
workflowName={"inductor-micro-benchmark"}
@@ -106,6 +114,7 @@ function Report({
106114
<GraphPanel
107115
queryParams={queryParams}
108116
granularity={granularity}
117+
repoName={repoName}
109118
modelName={modelName}
110119
backendName={backendName}
111120
dtypeName={dtypeName}
@@ -124,11 +133,11 @@ function Report({
124133
metricNames={metricNames}
125134
lPerfData={{
126135
...lBranchAndCommit,
127-
data: lData,
136+
data: lDataWithSpeedup,
128137
}}
129138
rPerfData={{
130139
...rBranchAndCommit,
131-
data: rData,
140+
data: rDataWithSpeedup,
132141
}}
133142
/>
134143
</div>
@@ -237,7 +246,12 @@ export default function Page() {
237246
const queryName = "oss_ci_benchmark_names";
238247
const queryParams = {
239248
deviceArch: deviceName === DEFAULT_DEVICE_NAME ? "" : deviceName,
240-
dtypes: dtypeName === DEFAULT_DTYPE_NAME ? [] : [dtypeName],
249+
dtypes:
250+
dtypeName === DEFAULT_DTYPE_NAME
251+
? []
252+
: repoName !== "pytorch/ao"
253+
? [dtypeName]
254+
: [dtypeName, TORCHAO_BASELINE],
241255
excludedMetrics: EXCLUDED_METRICS,
242256
benchmarks: REPO_TO_BENCHMARKS[repoName],
243257
granularity: granularity,
@@ -274,7 +288,10 @@ export default function Page() {
274288
];
275289
const dtypeNames: string[] = _.compact([
276290
DEFAULT_DTYPE_NAME,
277-
...(_.uniq(data.map((r: any) => r.dtype)) as string[]),
291+
..._.filter(
292+
_.uniq(data.map((r: any) => r.dtype)) as string[],
293+
(r: string) => r !== TORCHAO_BASELINE
294+
),
278295
]);
279296
const metricNames: string[] = _.uniq(data.map((r: any) => r.metric));
280297

@@ -372,7 +389,6 @@ export default function Page() {
372389
useClickHouse={true}
373390
/>
374391
</Stack>
375-
376392
<Report
377393
queryParams={queryParams}
378394
startTime={startTime}

0 commit comments

Comments
 (0)