@@ -81,6 +81,61 @@ def test_log_dict(self: TensorBoardLoggerTest) -> None:
81
81
)
82
82
self .assertEqual (tensor_tag .step , 1 )
83
83
84
+ def test_log_histogram_raw (self : TensorBoardLoggerTest ) -> None :
85
+ with tempfile .TemporaryDirectory () as log_dir :
86
+ logger = TensorBoardLogger (path = log_dir )
87
+
88
+ # generate a histogram with 4 bins in the range [0, 1]
89
+ data_range = [0.0 , 1.0 ]
90
+ bucket_counts = [1 , 3 , 5 , 4 ]
91
+ bucket_width = (data_range [1 ] - data_range [0 ]) / len (bucket_counts )
92
+ bucket_limits = [
93
+ ix * bucket_width + data_range [0 ]
94
+ for ix in range (len (bucket_counts ) + 1 )
95
+ ]
96
+ bucket_centers = [
97
+ (lower + upper ) / 2
98
+ for lower , upper in zip (bucket_limits [:- 1 ], bucket_limits [1 :])
99
+ ]
100
+ # sum of the binned values
101
+ value_sum = float (
102
+ sum (
103
+ value * count for value , count in zip (bucket_centers , bucket_counts )
104
+ )
105
+ )
106
+
107
+ logger .log_histogram_raw (
108
+ "histogram_raw" ,
109
+ min = 0 ,
110
+ max = 1 ,
111
+ num = sum (bucket_counts ),
112
+ sum = value_sum ,
113
+ sum_squares = value_sum ** 2 ,
114
+ bucket_limits = bucket_limits ,
115
+ # add an extra leading 0 to match the format of the histogram_raw
116
+ bucket_counts = [0 ] + bucket_counts ,
117
+ )
118
+ logger .close ()
119
+
120
+ acc = EventAccumulator (log_dir )
121
+ acc .Reload ()
122
+
123
+ # check that the histogram is logged correctly
124
+ self .assertIn ("histogram_raw" , acc .Tags ()["histograms" ])
125
+ # ensure that we logged exactly one histogram
126
+ self .assertEqual (len (acc .Histograms ("histogram_raw" )), 1 )
127
+ histogram_event = acc .Histograms ("histogram_raw" )[0 ]
128
+ histogram_value = histogram_event .histogram_value
129
+ # check that the histogram is logged correctly
130
+ self .assertEqual (histogram_value .min , 0 )
131
+ self .assertEqual (histogram_value .max , 1 )
132
+ self .assertEqual (histogram_value .num , sum (bucket_counts ))
133
+ self .assertEqual (histogram_value .sum , value_sum )
134
+ self .assertEqual (histogram_value .sum_squares , value_sum ** 2 )
135
+ self .assertListEqual (histogram_value .bucket_limit , bucket_limits )
136
+ self .assertListEqual (histogram_value .bucket [1 :], bucket_counts )
137
+ self .assertEqual (histogram_value .bucket [0 ], 0 )
138
+
84
139
def test_log_text (self : TensorBoardLoggerTest ) -> None :
85
140
with tempfile .TemporaryDirectory () as log_dir :
86
141
logger = TensorBoardLogger (path = log_dir )
0 commit comments