
Commit 0899f34

Update torchao to 0.4.0 and fix GPU quantization tutorial
Parent: 01d2270

File tree

2 files changed (+6, -6 lines)


.ci/docker/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.0.3
+torchao==0.4.0
 segment_anything==1.0
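
As a quick sanity check for the new pin (not part of the commit; a minimal sketch assuming torchao is already installed in the current environment), the installed version can be compared against the requirements.txt entry:

    from importlib.metadata import version

    # Should print 0.4.0 once the updated requirements.txt pin is installed
    print(version("torchao"))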

prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 5 additions & 5 deletions
@@ -44,7 +44,7 @@
 #
 
 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 
@@ -156,9 +156,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight``,
+# ``int8_dynamic_activation_int8_semi_sparse_weight`` or
+# ``int8_dynamic_activation_int4_weight``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -185,7 +185,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
