Skip to content

Commit e01951c

Browse files
authored
SQL AST: fix aggregate functions with GROUP BY (#599)
1 parent b4e5054 commit e01951c

File tree

8 files changed

+4428
-1325
lines changed

8 files changed

+4428
-1325
lines changed

src/SqlAst/AvgReturnTypeExtension.php

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515

1616
final class AvgReturnTypeExtension implements QueryFunctionReturnTypeExtension
1717
{
18+
/**
19+
* @var bool
20+
*/
21+
private $hasGroupBy;
22+
23+
public function __construct(bool $hasGroupBy)
24+
{
25+
$this->hasGroupBy = $hasGroupBy;
26+
}
27+
1828
public function isFunctionSupported(FunctionCall $expression): bool
1929
{
2030
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::AVG], true);
@@ -33,6 +43,7 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
3343
if ($argType->isNull()->yes()) {
3444
return $argType;
3545
}
46+
$containsNull = TypeCombinator::containsNull($argType);
3647
$argType = TypeCombinator::removeNull($argType);
3748

3849
if ($argType instanceof UnionType) {
@@ -45,7 +56,10 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
4556
$newType = $this->convertToAvgType($argType);
4657
}
4758

48-
return TypeCombinator::addNull($newType);
59+
if ($containsNull || ! $this->hasGroupBy) {
60+
$newType = TypeCombinator::addNull($newType);
61+
}
62+
return $newType;
4963
}
5064

5165
private function convertToAvgType(Type $argType): Type

src/SqlAst/MinMaxReturnTypeExtension.php

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111

1212
final class MinMaxReturnTypeExtension implements QueryFunctionReturnTypeExtension
1313
{
14+
/**
15+
* @var bool
16+
*/
17+
private $hasGroupBy;
18+
19+
public function __construct(bool $hasGroupBy)
20+
{
21+
$this->hasGroupBy = $hasGroupBy;
22+
}
23+
1424
public function isFunctionSupported(FunctionCall $expression): bool
1525
{
1626
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::MIN, BuiltInFunction::MAX], true);
@@ -26,6 +36,9 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
2636

2737
$argType = $scope->getType($args[0]);
2838

29-
return TypeCombinator::addNull($argType);
39+
if (! $this->hasGroupBy) {
40+
$argType = TypeCombinator::addNull($argType);
41+
}
42+
return $argType;
3043
}
3144
}

src/SqlAst/ParserInference.php

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
5555
$selectColumns = null;
5656
$fromTable = null;
5757
$where = null;
58+
$groupBy = null;
5859
$joins = [];
5960
foreach ($commands as [$command]) {
6061
// Parser does not throw exceptions. this allows to parse partially invalid code and not fail on first error
@@ -64,6 +65,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
6465
}
6566
$from = $command->getFrom();
6667
$where = $command->getWhere();
68+
$groupBy = $command->getGroupBy();
6769

6870
if (null === $from) {
6971
// no FROM clause, use an empty Table to signify this
@@ -136,7 +138,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
136138
throw new ShouldNotHappenException();
137139
}
138140

139-
$queryScope = new QueryScope($fromTable, $joins, $where);
141+
$queryScope = new QueryScope($fromTable, $joins, $where, $groupBy !== null);
140142

141143
// If we're selecting '*', get the selected columns from the table
142144
if (\count($selectColumns) === 1 && $selectColumns[0]->getExpression() instanceof Asterisk) {

src/SqlAst/QueryScope.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ final class QueryScope
5757
/**
5858
* @param list<Join> $joinedTables
5959
*/
60-
public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializable $whereCondition)
60+
public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializable $whereCondition, bool $hasGroupBy)
6161
{
6262
$this->fromTable = $fromTable;
6363
$this->joinedTables = $joinedTables;
@@ -73,12 +73,12 @@ public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializa
7373
new InstrReturnTypeExtension(),
7474
new StrCaseReturnTypeExtension(),
7575
new ReplaceReturnTypeExtension(),
76-
new AvgReturnTypeExtension(),
77-
new SumReturnTypeExtension(),
76+
new AvgReturnTypeExtension($hasGroupBy),
77+
new SumReturnTypeExtension($hasGroupBy),
7878
new IsNullReturnTypeExtension(),
7979
new AbsReturnTypeExtension(),
8080
new RoundReturnTypeExtension(),
81-
new MinMaxReturnTypeExtension(),
81+
new MinMaxReturnTypeExtension($hasGroupBy),
8282
];
8383
}
8484

src/SqlAst/SumReturnTypeExtension.php

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414

1515
final class SumReturnTypeExtension implements QueryFunctionReturnTypeExtension
1616
{
17+
/**
18+
* @var bool
19+
*/
20+
private $hasGroupBy;
21+
22+
public function __construct(bool $hasGroupBy)
23+
{
24+
$this->hasGroupBy = $hasGroupBy;
25+
}
26+
1727
public function isFunctionSupported(FunctionCall $expression): bool
1828
{
1929
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::SUM], true);
@@ -43,10 +53,9 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
4353
$result = $argType;
4454
}
4555

46-
if ($containsNull) {
56+
if ($containsNull || ! $this->hasGroupBy) {
4757
$result = TypeCombinator::addNull($result);
4858
}
49-
5059
return $result;
5160
}
5261
}

0 commit comments

Comments
 (0)