@@ -530,7 +530,55 @@ def torch_export_rewrite(
530530 :param rewrite_methods: methods to rewrite, if not empty, the function may try
531531 to discover them, a method is defined by its class (a type) and its name
532532 if the class is local, by itself otherwise
533- :param verbose: verbosity, up to 10, 10 shows the rewritten code
533+ :param verbose: verbosity, up to 10, 10 shows the rewritten code,
534+ ``verbose=1`` shows the rewritten function,
535+ ``verbose=2`` shows the rewritten code as well
536+
537+ Example:
538+
539+ .. code-block:: python
540+
541+ class Model(torch.nn.Module):
542+ def forward(self, x, y):
543+ if x.sum() > 0:
544+ return x + y
545+ else:
546+ return torch.abs(x) + y + 1
547+
548+ model = Model()
549+ x, y = torch.rand((4, 5)), torch.rand((4, 5))
550+ DYN = torch.export.Dim.DYNAMIC
551+ ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
552+ with torch_export_rewrite(rewrite_methods=[(Model, "forward")]):
553+ ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
554+
555+ If the method to rewrite is not local, then the following can be used:
556+
557+ .. code-block:: python
558+
559+ with torch_export_rewrite(rewrite_methods=[Model.forward]):
560+ ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
561+
562+ Functions (if not local) can also be rewritten:
563+
564+ .. code-block:: python
565+
566+ def outside(x, y):
567+ if x.sum() > 0:
568+ return x + y
569+ else:
570+ return torch.abs(x) + y + 1
571+
572+ class Model(torch.nn.Module):
573+ def forward(self, x, y):
574+ return outside(x, y)
575+
576+ model = Model()
577+ x, y = torch.rand((4, 5)), torch.rand((4, 5))
578+ DYN = torch.export.Dim.DYNAMIC
579+ ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
580+ with torch_export_rewrite(rewrite_methods=[outside]):
581+ ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
534582 """
535583 assert (
536584 rewrite_methods
0 commit comments