
zhijian-liu/torchprofile


TorchProfile


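TorchProfile counts the number of MACs (multiply-accumulate operations) in a PyTorch model. It works by tracing the computation graph with torch.jit.trace. Because the trace records every operator, including functional calls such as torch.matmul that are not wrapped in an nn.Module, it is more accurate than hook-based profilers, which only see module calls. Because it does not require exporting the model to ONNX, it is also more general than ONNX-based profilers.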
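For intuition about what a MAC count means, here is a minimal sketch of the textbook formulas for a few common operators. These follow the usual convention (one multiply plus one add per output element per reduction step, bias ignored) and are assumptions for illustration only; TorchProfile's own per-operator handlers may differ in details such as bias terms. The helper names are hypothetical and not part of the library.

```python
from math import prod


def linear_macs(batch_shape, in_features, out_features):
    """MACs for y = x @ W.T (bias ignored): each output element needs in_features multiply-adds."""
    return prod(batch_shape) * in_features * out_features


def matmul_macs(n, k, m):
    """MACs for an (n, k) @ (k, m) matrix product."""
    return n * k * m


def conv2d_macs(out_channels, out_h, out_w, in_channels, kernel_h, kernel_w, groups=1):
    """MACs for a 2-D convolution with batch size 1 (bias ignored)."""
    # Each output element reduces over its group's input channels and the kernel window.
    per_output = (in_channels // groups) * kernel_h * kernel_w
    return out_channels * out_h * out_w * per_output


# A 3x3 convolution mapping 64 -> 128 channels, producing a 56x56 output map
print(conv2d_macs(128, 56, 56, 64, 3, 3))  # 231211008

# A linear layer 768 -> 3072 applied to 128 tokens
print(linear_macs((1, 128), 768, 3072))  # 301989888
```

Summing such per-operator counts over the whole traced graph is, conceptually, what the total returned by profile_macs represents.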

Installation

pip install torchprofile

Quick Start

import torch
from transformers import AutoModel
from torchprofile import profile_macs

model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B").eval()
inputs = torch.randint(0, model.config.vocab_size, (1, 128))

macs = profile_macs(model, inputs)
print(f"{macs / 1e9:.2f} GMACs")
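Many papers report FLOPs rather than MACs. Under the common convention that one MAC is one multiply plus one add, FLOPs ≈ 2 × MACs. The sketch below shows that conversion together with human-readable formatting; the helpers `macs_to_flops` and `format_count` are hypothetical and not part of TorchProfile, and the MAC value is a placeholder rather than a measured result.

```python
def macs_to_flops(macs: int) -> int:
    """Convert a MAC count to FLOPs, assuming 1 MAC = 1 multiply + 1 add."""
    return 2 * macs


def format_count(value: float, unit: str) -> str:
    """Render a large count with a T/G/M/K prefix, e.g. 1.234e9 -> '1.23 G<unit>'."""
    for scale, prefix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")):
        if abs(value) >= scale:
            return f"{value / scale:.2f} {prefix}{unit}"
    return f"{value:.0f} {unit}"


macs = 1_234_000_000  # placeholder, e.g. the value returned by profile_macs(model, inputs)
print(format_count(macs, "MACs"))                  # 1.23 GMACs
print(format_count(macs_to_flops(macs), "FLOPs"))  # 2.47 GFLOPs
```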

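To get a per-operator breakdown, pass reduction=None. Instead of a single total, profile_macs then returns a dict mapping each traced node to its MAC count: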

results = profile_macs(model, inputs, reduction=None)
# Use a separate loop variable so the total `macs` from above is not overwritten
for node, node_macs in results.items():
    if node_macs > 0:
        print(f"{node.scope:40s} {node.operator:30s} {node_macs / 1e6:>8.2f} MMACs")
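The per-node dict can also be aggregated, for example to see which operator types dominate. The sketch below sums MACs by operator name; it assumes only that each key exposes an `.operator` attribute, as in the loop above. The helper `macs_by_operator` and the `FakeNode` stand-in are hypothetical, used here so the example runs without tracing a real model.

```python
from collections import defaultdict, namedtuple


def macs_by_operator(results):
    """Sum per-node MAC counts by operator name, largest contributors first.

    Assumes `results` maps node objects with an `.operator` attribute to MAC
    counts, such as the dict returned by profile_macs(..., reduction=None).
    """
    totals = defaultdict(int)
    for node, node_macs in results.items():
        totals[node.operator] += node_macs
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Hypothetical stand-in for traced nodes, used only for this demonstration
FakeNode = namedtuple("FakeNode", ["scope", "operator"])
demo = {
    FakeNode("layers.0.mlp", "aten::linear"): 300,
    FakeNode("layers.0.attn", "aten::matmul"): 50,
    FakeNode("layers.1.mlp", "aten::linear"): 300,
}

for op, total in macs_by_operator(demo):
    print(f"{op:30s} {total}")
```

With a real model, `macs_by_operator(results)` would be called on the dict produced above.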

License

This repository is released under the MIT license. See LICENSE for additional details.
