6969 DISABLED : UPSTREAM_DISABLED ,
7070}
7171
72+ WORKER_STATE_DISABLED = 'disabled'
73+ WORKER_STATE_ACTIVE = 'active'
74+
7275TASK_FAMILY_RE = re .compile (r'([^(_]+)[(_]' )
7376
7477RPC_METHODS = {}
@@ -319,6 +322,17 @@ def is_trivial_worker(self, state):
319322 def assistant (self ):
320323 return self .info .get ('assistant' , False )
321324
325+ @property
326+ def enabled (self ):
327+ return not self .disabled
328+
329+ @property
330+ def state (self ):
331+ if self .enabled :
332+ return WORKER_STATE_ACTIVE
333+ else :
334+ return WORKER_STATE_DISABLED
335+
322336 def __str__ (self ):
323337 return self .id
324338
@@ -527,7 +541,7 @@ def get_active_workers(self, last_active_lt=None, last_get_work_gt=None):
527541 for worker in six .itervalues (self ._active_workers ):
528542 if last_active_lt is not None and worker .last_active >= last_active_lt :
529543 continue
530- last_get_work = getattr ( worker , ' last_get_work' , None )
544+ last_get_work = worker . last_get_work
531545 if last_get_work_gt is not None and (
532546 last_get_work is None or last_get_work <= last_get_work_gt ):
533547 continue
@@ -554,10 +568,10 @@ def _remove_workers_from_tasks(self, workers, remove_stakeholders=True):
554568 task .stakeholders .difference_update (workers )
555569 task .workers .difference_update (workers )
556570
557- def disable_workers (self , workers ):
558- self ._remove_workers_from_tasks (workers , remove_stakeholders = False )
559- for worker in workers :
560- self .get_worker (worker ).disabled = True
571+ def disable_workers (self , worker_ids ):
572+ self ._remove_workers_from_tasks (worker_ids , remove_stakeholders = False )
573+ for worker_id in worker_ids :
574+ self .get_worker (worker_id ).disabled = True
561575
562576
563577class Scheduler (object ):
@@ -623,13 +637,12 @@ def _prune_tasks(self):
623637
624638 self ._state .inactivate_tasks (remove_tasks )
625639
626- def update (self , worker_id , worker_reference = None , get_work = False ):
627- """
628- Keep track of whenever the worker was last active.
629- """
640+ def _update_worker (self , worker_id , worker_reference = None , get_work = False ):
641+ # Keep track of whenever the worker was last active.
642+ # For convenience also return the worker object.
630643 worker = self ._state .get_worker (worker_id )
631644 worker .update (worker_reference , get_work = get_work )
632- return not getattr ( worker , 'disabled' , False )
645+ return worker
633646
634647 def _update_priority (self , task , prio , worker ):
635648 """
@@ -663,10 +676,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
663676 """
664677 assert worker is not None
665678 worker_id = worker
666- worker_enabled = self .update (worker_id )
679+ worker = self ._update_worker (worker_id )
667680 retry_policy = self ._generate_retry_policy (retry_policy_dict )
668681
669- if worker_enabled :
682+ if worker . enabled :
670683 _default_task = self ._make_task (
671684 task_id = task_id , status = PENDING , deps = deps , resources = resources ,
672685 priority = priority , family = family , module = module , params = params ,
@@ -676,7 +689,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
676689
677690 task = self ._state .get_task (task_id , setdefault = _default_task )
678691
679- if task is None or (task .status != RUNNING and not worker_enabled ):
692+ if task is None or (task .status != RUNNING and not worker . enabled ):
680693 return
681694
682695 # for setting priority, we'll sometimes create tasks with unset family and params
@@ -728,7 +741,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
728741 if resources is not None :
729742 task .resources = resources
730743
731- if worker_enabled and not assistant :
744+ if worker . enabled and not assistant :
732745 task .stakeholders .add (worker_id )
733746
734747 # Task dependencies might not exist yet. Let's create dummy tasks for them for now.
@@ -743,7 +756,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
743756 # before we know their retry_policy, we always set it here
744757 task .retry_policy = retry_policy
745758
746- if runnable and status != FAILED and worker_enabled :
759+ if runnable and status != FAILED and worker . enabled :
747760 task .workers .add (worker_id )
748761 self ._state .get_worker (worker_id ).tasks .add (task )
749762 task .runnable = runnable
@@ -837,8 +850,19 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
837850
838851 assert worker is not None
839852 worker_id = worker
840- # Return remaining tasks that have no FAILED descendants
841- self .update (worker_id , {'host' : host }, get_work = True )
853+ worker = self ._update_worker (
854+ worker_id ,
855+ worker_reference = {'host' : host },
856+ get_work = True )
857+ if not worker .enabled :
858+ reply = {'n_pending_tasks' : 0 ,
859+ 'running_tasks' : [],
860+ 'task_id' : None ,
861+ 'n_unique_pending' : 0 ,
862+ 'worker_state' : worker .state ,
863+ }
864+ return reply
865+
842866 if assistant :
843867 self .add_worker (worker_id , [('assistant' , assistant )])
844868
@@ -942,7 +966,9 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
942966 reply = {'n_pending_tasks' : locally_pending_tasks ,
943967 'running_tasks' : running_tasks ,
944968 'task_id' : None ,
945- 'n_unique_pending' : n_unique_pending }
969+ 'n_unique_pending' : n_unique_pending ,
970+ 'worker_state' : worker .state ,
971+ }
946972
947973 if len (batched_tasks ) > 1 :
948974 batch_string = '|' .join (task .id for task in batched_tasks )
@@ -976,7 +1002,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
9761002 @rpc_method (attempts = 1 )
9771003 def ping (self , ** kwargs ):
9781004 worker_id = kwargs ['worker' ]
979- self .update (worker_id )
1005+ self ._update_worker (worker_id )
9801006
9811007 def _upstream_status (self , task_id , upstream_status_table ):
9821008 if task_id in upstream_status_table :
@@ -1142,7 +1168,7 @@ def filter_func(t):
11421168 return all (term in t .pretty_id for term in terms )
11431169 for task in filter (filter_func , self ._state .get_active_tasks (status )):
11441170 if task .status != PENDING or not upstream_status or upstream_status == self ._upstream_status (task .id , upstream_status_table ):
1145- serialized = self ._serialize_task (task .id , False )
1171+ serialized = self ._serialize_task (task .id , include_deps = False )
11461172 result [task .id ] = serialized
11471173 if limit and len (result ) > (max_shown_tasks or self ._config .max_shown_tasks ):
11481174 return {'num_tasks' : len (result )}
@@ -1162,7 +1188,8 @@ def worker_list(self, include_running=True, **kwargs):
11621188 dict (
11631189 name = worker .id ,
11641190 last_active = worker .last_active ,
1165- started = getattr (worker , 'started' , None ),
1191+ started = worker .started ,
1192+ state = worker .state ,
11661193 first_task_display_name = self ._first_task_display_name (worker ),
11671194 ** worker .info
11681195 ) for worker in self ._state .get_active_workers ()]
@@ -1173,7 +1200,7 @@ def worker_list(self, include_running=True, **kwargs):
11731200 num_uniques = collections .defaultdict (int )
11741201 for task in self ._state .get_pending_tasks ():
11751202 if task .status == RUNNING and task .worker_running :
1176- running [task .worker_running ][task .id ] = self ._serialize_task (task .id , False )
1203+ running [task .worker_running ][task .id ] = self ._serialize_task (task .id , include_deps = False )
11771204 elif task .status == PENDING :
11781205 for worker in task .workers :
11791206 num_pending [worker ] += 1
@@ -1204,7 +1231,7 @@ def resource_list(self):
12041231 for task in self ._state .get_running_tasks ():
12051232 if task .status == RUNNING and task .resources :
12061233 for resource , amount in six .iteritems (task .resources ):
1207- consumers [resource ][task .id ] = self ._serialize_task (task .id , False )
1234+ consumers [resource ][task .id ] = self ._serialize_task (task .id , include_deps = False )
12081235 for resource in resources :
12091236 tasks = consumers [resource ['name' ]]
12101237 resource ['num_consumer' ] = len (tasks )
@@ -1235,7 +1262,7 @@ def task_search(self, task_str, **kwargs):
12351262 result = collections .defaultdict (dict )
12361263 for task in self ._state .get_active_tasks ():
12371264 if task .id .find (task_str ) != - 1 :
1238- serialized = self ._serialize_task (task .id , False )
1265+ serialized = self ._serialize_task (task .id , include_deps = False )
12391266 result [task .status ][task .id ] = serialized
12401267 return result
12411268
0 commit comments