@@ -207,6 +207,33 @@ def test_model_info(self):
207207 self .assertIn ('theta' , model_info_include ['parameters' ])
208208 self .assertIn ('included_files' , model_info_include )
209209
210+ def test_compile_with_bad_includes (self ):
211+ # Ensure compilation fails if we break an included file.
212+ stan_file = os .path .join (DATAFILES_PATH , "add_one_model.stan" )
213+ exe_file = os .path .splitext (stan_file )[0 ] + EXTENSION
214+ if os .path .isfile (exe_file ):
215+ os .unlink (exe_file )
216+ with tempfile .TemporaryDirectory () as include_path :
217+ include_source = os .path .join (
218+ DATAFILES_PATH , "include-path" , "add_one_function.stan"
219+ )
220+ include_target = os .path .join (include_path , "add_one_function.stan" )
221+ shutil .copy (include_source , include_target )
222+ model = CmdStanModel (
223+ stan_file = stan_file ,
224+ compile = False ,
225+ stanc_options = {"include-paths" : [include_path ]},
226+ )
227+ with LogCapture (level = logging .INFO ) as log :
228+ model .compile ()
229+ log .check_present (
230+ ('cmdstanpy' , 'INFO' , StringComparison ('compiling stan file' ))
231+ )
232+ with open (include_target , "w" ) as fd :
233+ fd .write ("gobbledygook" )
234+ with pytest .raises (ValueError , match = "Failed to get source info" ):
235+ model .compile ()
236+
210237 def test_compile_with_includes (self ):
211238 getmtime = os .path .getmtime
212239 configs = [
@@ -215,6 +242,9 @@ def test_compile_with_includes(self):
215242 ]
216243 for stan_file , include_paths in configs :
217244 stan_file = os .path .join (DATAFILES_PATH , stan_file )
245+ exe_file = os .path .splitext (stan_file )[0 ] + EXTENSION
246+ if os .path .isfile (exe_file ):
247+ os .unlink (exe_file )
218248 include_paths = [
219249 os .path .join (DATAFILES_PATH , path ) for path in include_paths
220250 ]
@@ -348,6 +378,10 @@ def test_model_syntax_error(self):
348378 with self .assertRaisesRegex (ValueError , r'.*Syntax error.*' ):
349379 CmdStanModel (stan_file = stan )
350380
381+ def test_model_syntax_error_without_compile (self ):
382+ stan = os .path .join (DATAFILES_PATH , 'bad_syntax.stan' )
383+ CmdStanModel (stan_file = stan , compile = False )
384+
351385 def test_repr (self ):
352386 model = CmdStanModel (stan_file = BERN_STAN )
353387 model_repr = repr (model )
0 commit comments