@@ -42,6 +42,10 @@ class RaySampler(Sampler):
4242 the Worker interface.
4343 worker_args (dict or None): Additional arguments that should be passed
4444 to the worker.
45+ n_gpus (int or float): Number of GPUs to use in total for sampling.
46+ If `n_workers` is not a power of two, this may need to be set
47+ slightly below the true value, since `n_gpus / n_workers` GPUs are
48+ allocated to each worker.
4549
4650 """
4751
@@ -56,8 +60,8 @@ def __init__(
5660 seed = get_seed (),
5761 n_workers = psutil .cpu_count (logical = False ),
5862 worker_class = DefaultWorker ,
59- worker_args = None ):
60- # pylint: disable=super-init-not-called
63+ worker_args = None ,
64+ n_gpus = 0 ):
6165 if not ray .is_initialized ():
6266 ray .init (log_to_driver = False , ignore_reinit_error = True )
6367 if worker_factory is None and max_episode_length is None :
@@ -73,7 +77,8 @@ def __init__(
7377 n_workers = n_workers ,
7478 worker_class = worker_class ,
7579 worker_args = worker_args )
76- self ._sampler_worker = ray .remote (SamplerWorker )
80+ remote_wrapper = ray .remote (num_gpus = n_gpus / n_workers )
81+ self ._sampler_worker = remote_wrapper (SamplerWorker )
7782 self ._agents = agents
7883 self ._envs = self ._worker_factory .prepare_worker_messages (envs )
7984 self ._all_workers = defaultdict (None )
@@ -103,7 +108,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
103108 Sampler: An instance of `cls`.
104109
105110 """
106- return cls (agents , envs , worker_factory = worker_factory )
111+ return cls (agents ,
112+ envs ,
113+ worker_factory = worker_factory ,
114+ n_workers = worker_factory .n_workers )
107115
108116 def start_worker (self ):
109117 """Initialize a new ray worker."""
0 commit comments