TorchProfile counts the number of MACs (multiply-accumulate operations) in a PyTorch model. It works by tracing the computation graph with torch.jit.trace, making it more accurate than hook-based profilers and more general than ONNX-based ones.
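To make the unit concrete, here is a minimal sketch of how MACs can be counted by hand for two common layers. The helper names `linear_macs` and `conv2d_macs` are hypothetical and are not part of TorchProfile. The formulas are the standard textbook ones and ignore bias additions, so numbers reported by a profiler may differ depending on which operators it counts.

```python
# Hand-computed MAC counts for two common layers (illustrative only;
# these helpers are hypothetical and not part of TorchProfile).

def linear_macs(batch: int, in_features: int, out_features: int) -> int:
    """Each output element needs in_features multiply-accumulates."""
    return batch * in_features * out_features


def conv2d_macs(batch: int, in_channels: int, out_channels: int,
                kernel_h: int, kernel_w: int,
                out_h: int, out_w: int, groups: int = 1) -> int:
    """Each output pixel needs (in_channels / groups) * kernel_h * kernel_w MACs."""
    macs_per_output = (in_channels // groups) * kernel_h * kernel_w
    return batch * out_channels * out_h * out_w * macs_per_output


# A 1024 -> 4096 linear layer applied to a single input vector.
print(linear_macs(1, 1024, 4096))             # 4194304
# A 7x7 convolution, 3 -> 64 channels, producing a 112x112 output map.
print(conv2d_macs(1, 3, 64, 7, 7, 112, 112))  # 118013952
```

Summing such per-layer counts over a whole model gives the kind of total that a MAC profiler reports.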
Install it with pip:

    pip install torchprofile

Then pass a model and a sample input to profile_macs:

    import torch
    from transformers import AutoModel
    from torchprofile import profile_macs

    # Load the model in eval mode before tracing.
    model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B").eval()
    # Dummy batch: one sequence of 128 random token IDs.
    inputs = torch.randint(0, model.config.vocab_size, (1, 128))

    macs = profile_macs(model, inputs)
    print(f"{macs / 1e9:.2f} GMACs")

To get a per-operator breakdown, pass reduction=None:
    # With reduction=None, the result maps each traced node to its MAC count.
    results = profile_macs(model, inputs, reduction=None)
    for node, macs in results.items():
        if macs > 0:  # skip operators that perform no MACs
            print(f"{node.scope:40s} {node.operator:30s} {macs / 1e6:>8.2f} MMACs")

This repository is released under the MIT license. See LICENSE for additional details.