44
55from datetime import datetime
66
7- import json
87import logging
98import numpy as np
109import os
1110import pickle
12- import sys
1311import tempfile
1412import time
1513import uuid
1614
1715import tensorflow as tf
16+ from ray .tune .logger import UnifiedLogger
1817from ray .tune .result import TrainingResult
1918
20- if sys .version_info [0 ] == 2 :
21- import cStringIO as StringIO
22- elif sys .version_info [0 ] == 3 :
23- import io as StringIO
24-
2519logger = logging .getLogger (__name__ )
2620logger .setLevel (logging .INFO )
2721
@@ -39,24 +33,18 @@ class Agent(object):
3933 """
4034
4135 _allow_unknown_configs = False
36+ _default_logdir = "/tmp/ray"
4237
4338 def __init__ (
44- self , env_creator , config , local_dir = '/tmp/ray' ,
45- upload_dir = None , experiment_tag = None ):
39+ self , env_creator , config , logger_creator = None ):
4640 """Initialize an RLLib agent.
4741
4842 Args:
4943 env_creator (str|func): Name of the OpenAI gym environment to train
5044 against, or a function that creates such an env.
51- config (obj): Algorithm-specific configuration data.
52- local_dir (str): Directory where results and temporary files will
53- be placed.
54- upload_dir (str): Optional remote URI like s3://bucketname/ where
55- results will be uploaded.
56- experiment_tag (str): Optional string containing extra metadata
57- about the experiment, e.g. a summary of parameters. This string
58- will be included in the logdir path and when displaying agent
59- progress.
45+ config (dict): Algorithm-specific configuration data.
46+ logger_creator (func): Function that creates a ray.tune.Logger
47+ object. If unspecified, a default logger is created.
6048 """
6149 self ._initialize_ok = False
6250 self ._experiment_id = uuid .uuid4 ().hex
@@ -79,40 +67,20 @@ def __init__(
7967 "Unknown agent config `{}`, "
8068 "all agent configs: {}" .format (k , self .config .keys ()))
8169 self .config .update (config )
82- self .config .update ({
83- "experiment_tag" : experiment_tag ,
84- "alg" : self ._agent_name ,
85- "env_name" : env_name ,
86- "experiment_id" : self ._experiment_id ,
87- })
88-
89- logdir_suffix = "{}_{}_{}" .format (
90- env_name ,
91- self ._agent_name ,
92- experiment_tag or datetime .today ().strftime ("%Y-%m-%d_%H-%M-%S" ))
93-
94- if not os .path .exists (local_dir ):
95- os .makedirs (local_dir )
9670
97- self .logdir = tempfile .mkdtemp (prefix = logdir_suffix , dir = local_dir )
98-
99- if upload_dir :
100- log_upload_uri = os .path .join (upload_dir , logdir_suffix )
71+ if logger_creator :
72+ self ._result_logger = logger_creator (self .config )
73+ self .logdir = self ._result_logger .logdir
10174 else :
102- log_upload_uri = None
103-
104- # TODO(ekl) consider inlining config into the result jsons
105- config_out = os .path .join (self .logdir , "config.json" )
106- with open (config_out , "w" ) as f :
107- json .dump (self .config , f , sort_keys = True , cls = _Encoder )
108- logger .info (
109- "%s agent created with logdir '%s' and upload uri '%s'" ,
110- self .__class__ .__name__ , self .logdir , log_upload_uri )
111-
112- self ._result_logger = _Logger (
113- os .path .join (self .logdir , "result.json" ),
114- log_upload_uri and os .path .join (log_upload_uri , "result.json" ))
115- self ._file_writer = tf .summary .FileWriter (self .logdir )
75+ logdir_suffix = "{}_{}_{}" .format (
76+ env_name ,
77+ self ._agent_name ,
78+ datetime .today ().strftime ("%Y-%m-%d_%H-%M-%S" ))
79+ if not os .path .exists (self ._default_logdir ):
80+ os .makedirs (self ._default_logdir )
81+ self .logdir = tempfile .mkdtemp (
82+ prefix = logdir_suffix , dir = self ._default_logdir )
83+ self ._result_logger = UnifiedLogger (self .config , self .logdir , None )
11684
11785 self ._iteration = 0
11886 self ._time_total = 0.0
@@ -161,29 +129,10 @@ def train(self):
161129 pid = os .getpid (),
162130 hostname = os .uname ()[1 ])
163131
164- self ._log_result (result )
132+ self ._result_logger . on_result (result )
165133
166134 return result
167135
168- def _log_result (self , result ):
169- """Appends the given result to this agent's log dir."""
170-
171- # We need to use a custom json serializer class so that NaNs get
172- # encoded as null as required by Athena.
173- json .dump (result ._asdict (), self ._result_logger , cls = _Encoder )
174- self ._result_logger .write ("\n " )
175- attrs_to_log = [
176- "time_this_iter_s" , "mean_loss" , "mean_accuracy" ,
177- "episode_reward_mean" , "episode_len_mean" ]
178- values = []
179- for attr in attrs_to_log :
180- if getattr (result , attr ) is not None :
181- values .append (tf .Summary .Value (
182- tag = "ray/tune/{}" .format (attr ),
183- simple_value = getattr (result , attr )))
184- train_stats = tf .Summary (value = values )
185- self ._file_writer .add_summary (train_stats , result .training_iteration )
186-
187136 def save (self ):
188137 """Saves the current model state to a checkpoint.
189138
@@ -214,7 +163,7 @@ def restore(self, checkpoint_path):
214163 def stop (self ):
215164 """Releases all resources used by this agent."""
216165
217- self ._file_writer .close ()
166+ self ._result_logger .close ()
218167
219168 def compute_action (self , observation ):
220169 """Computes an action using the current trained policy."""
@@ -255,61 +204,6 @@ def _restore(self):
255204 raise NotImplementedError
256205
257206
258- class _Encoder (json .JSONEncoder ):
259-
260- def __init__ (self , nan_str = "null" , ** kwargs ):
261- super (_Encoder , self ).__init__ (** kwargs )
262- self .nan_str = nan_str
263-
264- def iterencode (self , o , _one_shot = False ):
265- if self .ensure_ascii :
266- _encoder = json .encoder .encode_basestring_ascii
267- else :
268- _encoder = json .encoder .encode_basestring
269-
270- def floatstr (o , allow_nan = self .allow_nan , nan_str = self .nan_str ):
271- return repr (o ) if not np .isnan (o ) else nan_str
272-
273- _iterencode = json .encoder ._make_iterencode (
274- None , self .default , _encoder , self .indent , floatstr ,
275- self .key_separator , self .item_separator , self .sort_keys ,
276- self .skipkeys , _one_shot )
277- return _iterencode (o , 0 )
278-
279- def default (self , value ):
280- if np .isnan (value ):
281- return None
282- if np .issubdtype (value , float ):
283- return float (value )
284- if np .issubdtype (value , int ):
285- return int (value )
286-
287-
288- class _Logger (object ):
289- """Writing small amounts of data to S3 with real-time updates.
290- """
291-
292- def __init__ (self , local_file , uri = None ):
293- self .local_out = open (local_file , "w" )
294- self .result_buffer = StringIO .StringIO ()
295- self .uri = uri
296- if self .uri :
297- import smart_open
298- self .smart_open = smart_open .smart_open
299-
300- def write (self , b ):
301- self .local_out .write (b )
302- self .local_out .flush ()
303- # TODO(pcm): At the moment we are writing the whole results output from
304- # the beginning in each iteration. This will write O(n^2) bytes where n
305- # is the number of bytes printed so far. Fix this! This should at least
306- # only write the last 5MBs (S3 chunksize).
307- if self .uri :
308- with self .smart_open (self .uri , "w" ) as f :
309- self .result_buffer .write (b )
310- f .write (self .result_buffer .getvalue ())
311-
312-
313207class _MockAgent (Agent ):
314208 """Mock agent for use in tests"""
315209
0 commit comments