@@ -356,6 +356,100 @@ def test_api(self):
356356 self .assertEqual (variables , expression .get_variables ())
357357
358358
359+ class TestCounterfactual (unittest .TestCase ):
360+ """Tests for counterfactuals."""
361+
362+ def test_event_semantics (self ):
363+ """Check DSL semantics."""
364+ for expr , counterfactual_star , intervention_star in [
365+ (X @ X , None , False ),
366+ (- X @ X , False , False ),
367+ (+ X @ X , True , False ),
368+ (~ X @ X , True , False ),
369+ ]:
370+ with self .subTest (expr = expr .to_y0 ()):
371+ self .assertIsInstance (expr , CounterfactualVariable )
372+ self .assertEqual (counterfactual_star , expr .star )
373+ self .assertEqual (1 , len (expr .interventions ))
374+ self .assertEqual (intervention_star , expr .interventions [0 ].star )
375+
376+ def test_event_failures (self ):
377+ """Check for failure to determine tautology/inconsistent."""
378+ for expr in [
379+ # Opposite variable
380+ X @ Y ,
381+ X @ + Y ,
382+ X @ - Y ,
383+ X @ ~ Y ,
384+ # Same variable
385+ X @ X ,
386+ X @ + X ,
387+ X @ - X ,
388+ X @ ~ X ,
389+ ]:
390+ with self .subTest (expr = expr .to_y0 ()):
391+ self .assertIsInstance (expr , CounterfactualVariable )
392+ self .assertFalse (expr .is_event ())
393+ with self .assertRaises (ValueError ):
394+ expr .has_tautology ()
395+ with self .assertRaises (ValueError ):
396+ expr .is_inconsistent ()
397+
398+ def test_tautology (self ):
399+ """Check for tautologies."""
400+ for expr , status in [
401+ # Different Variable
402+ (~ X @ Y , False ),
403+ (~ X @ ~ Y , False ),
404+ # Same variable, self.star is False
405+ (- X @ X , True ),
406+ (- X @ - X , True ),
407+ (- X @ + X , False ),
408+ (- X @ ~ X , False ),
409+ # Same variable, self.star is True
410+ (+ X @ X , False ),
411+ (+ X @ - X , False ),
412+ (+ X @ + X , True ),
413+ (+ X @ ~ X , True ),
414+ # Same variable, self.star is True
415+ (~ X @ X , False ),
416+ (~ X @ - X , False ),
417+ (~ X @ + X , True ),
418+ (~ X @ ~ X , True ),
419+ ]:
420+ with self .subTest (expr = expr .to_y0 ()):
421+ self .assertIsInstance (expr , CounterfactualVariable )
422+ self .assertTrue (expr .is_event ())
423+ self .assertEqual (status , expr .has_tautology ())
424+
425+ def test_inconsistent (self ):
426+ """Check for tautologies."""
427+ for expr , status in [
428+ # Different Variable
429+ (~ X @ Y , False ),
430+ (~ X @ ~ Y , False ),
431+ # Same variable, self.star is False
432+ (- X @ X , False ),
433+ (- X @ - X , False ),
434+ (- X @ + X , True ),
435+ (- X @ ~ X , True ),
436+ # Same variable, self.star is True
437+ (+ X @ X , True ),
438+ (+ X @ - X , True ),
439+ (+ X @ + X , False ),
440+ (+ X @ ~ X , False ),
441+ # Same variable, self.star is True
442+ (~ X @ X , True ),
443+ (~ X @ - X , True ),
444+ (~ X @ + X , False ),
445+ (~ X @ ~ X , False ),
446+ ]:
447+ with self .subTest (expr = expr .to_y0 ()):
448+ self .assertIsInstance (expr , CounterfactualVariable )
449+ self .assertTrue (expr .is_event ())
450+ self .assertEqual (status , expr .is_inconsistent ())
451+
452+
359453class TestSafeConstructors (unittest .TestCase ):
360454 """Test that the .safe() constructors work properly."""
361455
0 commit comments