File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed
Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change 44
55import copy
66import json
7+ import math
78import os
89from typing import Any , Iterator , Literal
910
@@ -91,19 +92,24 @@ class MetricInfo(BaseModel):
9192 as output by CmdStan"""
9293
9394 chain_id : int = Field (gt = 0 )
94- stepsize : float = Field ( gt = 0 )
95+ stepsize : float
9596 metric_type : Literal ["diag_e" , "dense_e" , "unit_e" ]
9697 inv_metric : np .ndarray
9798
9899 model_config = {"arbitrary_types_allowed" : True }
99100
100101 @field_validator ("inv_metric" , mode = "before" )
101- def convert_inv_metric ( # pylint: disable=no-self-argument
102- cls , v : Any
103- ) -> np .ndarray :
104-
102+ @classmethod
103+ def convert_inv_metric (cls , v : Any ) -> np .ndarray :
105104 return np .asarray (v )
106105
106+ @field_validator ("stepsize" )
107+ @classmethod
108+ def validate_stepsize (cls , v : float ) -> float :
109+ if not math .isnan (v ) and v <= 0 :
110+ raise ValueError ("stepsize must be greater than 0 or NaN" )
111+ return v
112+
107113 @model_validator (mode = "after" )
108114 def validate_inv_metric_shape (self ) -> MetricInfo :
109115 if (
You can’t perform that action at this time.
0 commit comments