Skip to content

Commit 0a5233d

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Added loglevel configuration for running benchmarks (#3171)
Summary: Pull Request resolved: #3171 - Added `--loglevel` argument to `benchmark_utils.py` for setting the log level. - Allow argument duplicates with a warning instead of giving an error. Reviewed By: TroyGarden Differential Revision: D77950216 fbshipit-source-id: 71f9538e40a8149b292116619f0b8217f2e63bc9
1 parent 51078e8 commit 0a5233d

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,14 @@ def wrapper() -> Any:
477477
sig = inspect.signature(func)
478478
parser = argparse.ArgumentParser(func.__doc__)
479479

480+
# Add loglevel argument with current logger level as default
481+
parser.add_argument(
482+
"--loglevel",
483+
type=str,
484+
default=logging._levelToName[logger.level],
485+
help="Set the logging level (e.g. info, debug, warning, error)",
486+
)
487+
480488
seen_args = set() # track all --<name> we've added
481489

482490
for _name, param in sig.parameters.items():
@@ -487,7 +495,8 @@ def wrapper() -> Any:
487495
for f in fields(cls):
488496
arg_name = f.name
489497
if arg_name in seen_args:
490-
parser.error(f"Duplicate argument {arg_name}")
498+
logger.warning(f"WARNING: duplicate argument {arg_name}")
499+
continue
491500
seen_args.add(arg_name)
492501

493502
ftype = f.type
@@ -521,14 +530,20 @@ def wrapper() -> Any:
521530
parser.add_argument(f"--{arg_name}", **arg_kwargs)
522531

523532
args = parser.parse_args()
533+
logger.setLevel(logging.INFO)
524534

525535
# Build the dataclasses
526536
kwargs = {}
527537
for name, param in sig.parameters.items():
528538
cls = param.annotation
529539
if is_dataclass(cls):
530540
data = {f.name: getattr(args, f.name) for f in fields(cls)}
531-
kwargs[name] = cls(**data) # pyre-ignore [29]
541+
config_instance = cls(**data) # pyre-ignore [29]
542+
kwargs[name] = config_instance
543+
logger.info(config_instance)
544+
545+
loglevel = logging._nameToLevel[args.loglevel.upper()]
546+
logger.setLevel(loglevel)
532547

533548
return func(**kwargs)
534549

0 commit comments

Comments
 (0)