1212use PHPStan \Type \NullType ;
1313use PHPStan \Type \Type ;
1414use PHPStan \Type \TypeCombinator ;
15+ use SqlFtw \Sql \Expression \BinaryOperator ;
1516use SqlFtw \Sql \Expression \BoolValue ;
1617use SqlFtw \Sql \Expression \CaseExpression ;
1718use SqlFtw \Sql \Expression \ComparisonOperator ;
2627use SqlFtw \Sql \Expression \Parentheses ;
2728use SqlFtw \Sql \Expression \SimpleName ;
2829use SqlFtw \Sql \Expression \StringValue ;
30+ use SqlFtw \Sql \SqlSerializable ;
2931use staabm \PHPStanDba \SchemaReflection \Column ;
3032use staabm \PHPStanDba \SchemaReflection \Join ;
3133use staabm \PHPStanDba \SchemaReflection \Table ;
@@ -47,13 +49,19 @@ final class QueryScope
4749 */
4850 private $ joinedTables ;
4951
52+ /**
53+ * @var ?SqlSerializable
54+ */
55+ private $ whereCondition ;
56+
5057 /**
5158 * @param list<Join> $joinedTables
5259 */
53- public function __construct (Table $ fromTable , array $ joinedTables )
60+ public function __construct (Table $ fromTable , array $ joinedTables, ? SqlSerializable $ whereCondition )
5461 {
5562 $ this ->fromTable = $ fromTable ;
5663 $ this ->joinedTables = $ joinedTables ;
64+ $ this ->whereCondition = $ whereCondition ;
5765
5866 $ this ->extensions = [
5967 new PositiveIntReturnTypeExtension (),
@@ -78,6 +86,24 @@ public function __construct(Table $fromTable, array $joinedTables)
7886 * @param Identifier|Literal|ExpressionNode $expression
7987 */
8088 public function getType ($ expression ): Type
89+ {
90+ $ resultType = $ this ->resolveExpression ($ expression );
91+
92+ if ($ this ->whereCondition !== null && $ expression instanceof SimpleName) {
93+ $ resultType = $ this ->narrowWhereCondition (
94+ $ this ->whereCondition ,
95+ $ expression ->getName (),
96+ $ resultType
97+ );
98+ }
99+
100+ return $ resultType ;
101+ }
102+
103+ /**
104+ * @param Identifier|Literal|ExpressionNode $expression
105+ */
106+ private function resolveExpression ($ expression ): Type
81107 {
82108 if ($ expression instanceof NullLiteral) {
83109 return new NullType ();
@@ -219,4 +245,53 @@ private function narrowJoinCondition(Column $column, Join $join): ?Type
219245
220246 return null ;
221247 }
248+
249+ private function narrowWhereCondition (SqlSerializable $ op , string $ name , Type $ valueType ): Type
250+ {
251+ // If condition is in parentheses, try again with its contents
252+ if ($ op instanceof Parentheses) {
253+ return $ this ->narrowWhereCondition ($ op ->getContents (), $ name , $ valueType );
254+ }
255+
256+ // Only binary ops are currently supported
257+ if (! ($ op instanceof BinaryOperator)) {
258+ return $ valueType ;
259+ }
260+
261+ $ left = $ op ->getLeft ();
262+ $ right = $ op ->getRight ();
263+ $ operator = $ op ->getOperator ()->getValue ();
264+
265+ // Handle compound conditions
266+ if ($ operator === Operator::AND ) {
267+ if ($ left instanceof BinaryOperator) {
268+ $ valueType = $ this ->narrowWhereCondition ($ left , $ name , $ valueType );
269+ }
270+ if ($ right instanceof BinaryOperator) {
271+ $ valueType = $ this ->narrowWhereCondition ($ right , $ name , $ valueType );
272+ }
273+ return $ valueType ;
274+ }
275+
276+ // Only simple names are currently supported
277+ if (! ($ left instanceof SimpleName)) {
278+ return $ valueType ;
279+ }
280+ if ($ left ->getName () !== $ name ) {
281+ return $ valueType ;
282+ }
283+
284+ // Handle NULL comparisons
285+ if ($ right instanceof NullLiteral) {
286+ if ($ operator === Operator::IS_NOT ) {
287+ return TypeCombinator::removeNull ($ valueType );
288+ }
289+ if ($ operator === Operator::IS ) {
290+ return TypeCombinator::intersect ($ valueType , new NullType ());
291+ }
292+ }
293+
294+ // Unsupported operator
295+ return $ valueType ;
296+ }
222297}
0 commit comments