@@ -1436,33 +1436,23 @@ def enable_group_offload(
14361436 f"The following modules are not present in pipeline: { ', ' .join (unknown )} . Ignore if this is expected."
14371437 )
14381438
1439+ group_offload_kwargs = {
1440+ "onload_device" : onload_device ,
1441+ "offload_device" : offload_device ,
1442+ "offload_type" : offload_type ,
1443+ "num_blocks_per_group" : num_blocks_per_group ,
1444+ "non_blocking" : non_blocking ,
1445+ "use_stream" : use_stream ,
1446+ "record_stream" : record_stream ,
1447+ "low_cpu_mem_usage" : low_cpu_mem_usage ,
1448+ "offload_to_disk_path" : offload_to_disk_path ,
1449+ }
14391450 for name , component in self .components .items ():
14401451 if name not in exclude_modules and isinstance (component , torch .nn .Module ):
14411452 if hasattr (component , "enable_group_offload" ):
1442- component .enable_group_offload (
1443- onload_device = onload_device ,
1444- offload_device = offload_device ,
1445- offload_type = offload_type ,
1446- num_blocks_per_group = num_blocks_per_group ,
1447- non_blocking = non_blocking ,
1448- use_stream = use_stream ,
1449- record_stream = record_stream ,
1450- low_cpu_mem_usage = low_cpu_mem_usage ,
1451- offload_to_disk_path = offload_to_disk_path ,
1452- )
1453+ component .enable_group_offload (** group_offload_kwargs )
14531454 else :
1454- apply_group_offloading (
1455- module = component ,
1456- onload_device = onload_device ,
1457- offload_device = offload_device ,
1458- offload_type = offload_type ,
1459- num_blocks_per_group = num_blocks_per_group ,
1460- non_blocking = non_blocking ,
1461- use_stream = use_stream ,
1462- record_stream = record_stream ,
1463- low_cpu_mem_usage = low_cpu_mem_usage ,
1464- offload_to_disk_path = offload_to_disk_path ,
1465- )
1455+ apply_group_offloading (module = component , ** group_offload_kwargs )
14661456
14671457 if exclude_modules :
14681458 for module_name in exclude_modules :
0 commit comments