Skip to content

Commit b4e5054

Browse files
authored
SQL AST: narrow result types with WHERE condition (#596)
1 parent 3a3ba02 commit b4e5054

File tree

5 files changed

+16918
-5455
lines changed

5 files changed

+16918
-5455
lines changed

src/SqlAst/ParserInference.php

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
use SqlFtw\Sql\Dml\TableReference\Join;
2020
use SqlFtw\Sql\Dml\TableReference\TableReferenceSubquery;
2121
use SqlFtw\Sql\Dml\TableReference\TableReferenceTable;
22+
use SqlFtw\Sql\Expression\Asterisk;
2223
use SqlFtw\Sql\Expression\Identifier;
24+
use SqlFtw\Sql\Expression\SimpleName;
2325
use SqlFtw\Sql\SqlSerializable;
2426
use staabm\PHPStanDba\QueryReflection\QueryReflection;
2527
use staabm\PHPStanDba\SchemaReflection\Join as SchemaJoin;
@@ -50,16 +52,18 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
5052
// returns a Generator. will not parse anything if you don't iterate over it
5153
$commands = $parser->parse($queryString);
5254

53-
$fromColumns = null;
55+
$selectColumns = null;
5456
$fromTable = null;
57+
$where = null;
5558
$joins = [];
5659
foreach ($commands as [$command]) {
5760
// Parser does not throw exceptions. this allows to parse partially invalid code and not fail on first error
5861
if ($command instanceof SelectCommand) {
59-
if (null === $fromColumns) {
60-
$fromColumns = $command->getColumns();
62+
if (null === $selectColumns) {
63+
$selectColumns = $command->getColumns();
6164
}
6265
$from = $command->getFrom();
66+
$where = $command->getWhere();
6367

6468
if (null === $from) {
6569
// no FROM clause, use an empty Table to signify this
@@ -128,13 +132,21 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
128132
// not parsable atm, return un-narrowed type
129133
return $resultType;
130134
}
131-
if (null === $fromColumns) {
135+
if (null === $selectColumns) {
132136
throw new ShouldNotHappenException();
133137
}
134138

135-
$queryScope = new QueryScope($fromTable, $joins);
139+
$queryScope = new QueryScope($fromTable, $joins, $where);
136140

137-
foreach ($fromColumns as $i => $column) {
141+
// If we're selecting '*', get the selected columns from the table
142+
if (\count($selectColumns) === 1 && $selectColumns[0]->getExpression() instanceof Asterisk) {
143+
$selectColumns = [];
144+
foreach ($fromTable->getColumns() as $column) {
145+
$selectColumns[] = new SelectExpression(new SimpleName($column->getName()));
146+
}
147+
}
148+
149+
foreach ($selectColumns as $i => $column) {
138150
$expression = $column->getExpression();
139151

140152
$offsetType = new ConstantIntegerType($i);

src/SqlAst/QueryScope.php

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
use PHPStan\Type\NullType;
1313
use PHPStan\Type\Type;
1414
use PHPStan\Type\TypeCombinator;
15+
use SqlFtw\Sql\Expression\BinaryOperator;
1516
use SqlFtw\Sql\Expression\BoolValue;
1617
use SqlFtw\Sql\Expression\CaseExpression;
1718
use SqlFtw\Sql\Expression\ComparisonOperator;
@@ -26,6 +27,7 @@
2627
use SqlFtw\Sql\Expression\Parentheses;
2728
use SqlFtw\Sql\Expression\SimpleName;
2829
use SqlFtw\Sql\Expression\StringValue;
30+
use SqlFtw\Sql\SqlSerializable;
2931
use staabm\PHPStanDba\SchemaReflection\Column;
3032
use staabm\PHPStanDba\SchemaReflection\Join;
3133
use staabm\PHPStanDba\SchemaReflection\Table;
@@ -47,13 +49,19 @@ final class QueryScope
4749
*/
4850
private $joinedTables;
4951

52+
/**
53+
* @var ?SqlSerializable
54+
*/
55+
private $whereCondition;
56+
5057
/**
5158
* @param list<Join> $joinedTables
5259
*/
53-
public function __construct(Table $fromTable, array $joinedTables)
60+
public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializable $whereCondition)
5461
{
5562
$this->fromTable = $fromTable;
5663
$this->joinedTables = $joinedTables;
64+
$this->whereCondition = $whereCondition;
5765

5866
$this->extensions = [
5967
new PositiveIntReturnTypeExtension(),
@@ -78,6 +86,24 @@ public function __construct(Table $fromTable, array $joinedTables)
7886
* @param Identifier|Literal|ExpressionNode $expression
7987
*/
8088
public function getType($expression): Type
89+
{
90+
$resultType = $this->resolveExpression($expression);
91+
92+
if ($this->whereCondition !== null && $expression instanceof SimpleName) {
93+
$resultType = $this->narrowWhereCondition(
94+
$this->whereCondition,
95+
$expression->getName(),
96+
$resultType
97+
);
98+
}
99+
100+
return $resultType;
101+
}
102+
103+
/**
104+
* @param Identifier|Literal|ExpressionNode $expression
105+
*/
106+
private function resolveExpression($expression): Type
81107
{
82108
if ($expression instanceof NullLiteral) {
83109
return new NullType();
@@ -219,4 +245,53 @@ private function narrowJoinCondition(Column $column, Join $join): ?Type
219245

220246
return null;
221247
}
248+
249+
private function narrowWhereCondition(SqlSerializable $op, string $name, Type $valueType): Type
250+
{
251+
// If condition is in parentheses, try again with its contents
252+
if ($op instanceof Parentheses) {
253+
return $this->narrowWhereCondition($op->getContents(), $name, $valueType);
254+
}
255+
256+
// Only binary ops are currently supported
257+
if (! ($op instanceof BinaryOperator)) {
258+
return $valueType;
259+
}
260+
261+
$left = $op->getLeft();
262+
$right = $op->getRight();
263+
$operator = $op->getOperator()->getValue();
264+
265+
// Handle compound conditions
266+
if ($operator === Operator::AND) {
267+
if ($left instanceof BinaryOperator) {
268+
$valueType = $this->narrowWhereCondition($left, $name, $valueType);
269+
}
270+
if ($right instanceof BinaryOperator) {
271+
$valueType = $this->narrowWhereCondition($right, $name, $valueType);
272+
}
273+
return $valueType;
274+
}
275+
276+
// Only simple names are currently supported
277+
if (! ($left instanceof SimpleName)) {
278+
return $valueType;
279+
}
280+
if ($left->getName() !== $name) {
281+
return $valueType;
282+
}
283+
284+
// Handle NULL comparisons
285+
if ($right instanceof NullLiteral) {
286+
if ($operator === Operator::IS_NOT) {
287+
return TypeCombinator::removeNull($valueType);
288+
}
289+
if ($operator === Operator::IS) {
290+
return TypeCombinator::intersect($valueType, new NullType());
291+
}
292+
}
293+
294+
// Unsupported operator
295+
return $valueType;
296+
}
222297
}

0 commit comments

Comments
 (0)