diff --git a/advanced_source/semi_structured_sparse.py b/advanced_source/semi_structured_sparse.py index e4bca79b9a..be563db384 100644 --- a/advanced_source/semi_structured_sparse.py +++ b/advanced_source/semi_structured_sparse.py @@ -55,6 +55,9 @@ from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor from torch.utils.benchmark import Timer +# the following line may need to be enabled to see a speedup +# SparseSemiStructuredTensor._FORCE_CUTLASS = True + # mask Linear weight to be 2:4 sparse mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool() linear = torch.nn.Linear(10240, 3072).half().cuda().eval()