Skip to content

Commit a2ddf4d

Browse files
committed
Improve substituion of non-type template parameters
1 parent b2cd1c2 commit a2ddf4d

File tree

7 files changed

+150
-49
lines changed

7 files changed

+150
-49
lines changed

src/parser/cxx/ast_interpreter.cc

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,10 @@ auto ASTInterpreter::ExpressionVisitor::operator()(IdExpressionAST* ast)
16341634
return enumerator->value();
16351635
}
16361636

1637+
if (auto var = symbol_cast<VariableSymbol>(ast->symbol)) {
1638+
return var->constValue();
1639+
}
1640+
16371641
return std::nullopt;
16381642
}
16391643

@@ -1989,33 +1993,49 @@ auto ASTInterpreter::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
19891993

19901994
case TokenKind::T_STAR:
19911995
if (control()->is_floating_point(ast->type))
1992-
return std::visit(ArithmeticCast<double>{}, *left) +
1996+
return std::visit(ArithmeticCast<double>{}, *left) *
19931997
std::visit(ArithmeticCast<double>{}, *right);
19941998
else if (control()->is_unsigned(ast->type))
1995-
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) +
1999+
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) *
19962000
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
19972001
else
1998-
return std::visit(ArithmeticCast<std::int64_t>{}, *left) +
2002+
return std::visit(ArithmeticCast<std::int64_t>{}, *left) *
19992003
std::visit(ArithmeticCast<std::int64_t>{}, *right);
20002004

2001-
case TokenKind::T_SLASH:
2002-
if (control()->is_floating_point(ast->type))
2003-
return std::visit(ArithmeticCast<double>{}, *left) +
2004-
std::visit(ArithmeticCast<double>{}, *right);
2005-
else if (control()->is_unsigned(ast->type))
2006-
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) +
2007-
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
2008-
else
2009-
return std::visit(ArithmeticCast<std::int64_t>{}, *left) +
2010-
std::visit(ArithmeticCast<std::int64_t>{}, *right);
2005+
case TokenKind::T_SLASH: {
2006+
if (control()->is_floating_point(ast->type)) {
2007+
auto l = std::visit(ArithmeticCast<double>{}, *left);
2008+
auto r = std::visit(ArithmeticCast<double>{}, *right);
2009+
if (r == 0.0) return std::nullopt;
2010+
return l / r;
2011+
}
20112012

2012-
case TokenKind::T_PERCENT:
2013-
if (control()->is_unsigned(ast->type))
2014-
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) %
2015-
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
2016-
else
2017-
return std::visit(ArithmeticCast<std::int64_t>{}, *left) %
2018-
std::visit(ArithmeticCast<std::int64_t>{}, *right);
2013+
if (control()->is_unsigned(ast->type)) {
2014+
auto l = std::visit(ArithmeticCast<std::uint64_t>{}, *left);
2015+
auto r = std::visit(ArithmeticCast<std::uint64_t>{}, *right);
2016+
if (r == 0) return std::nullopt;
2017+
return l / r;
2018+
}
2019+
2020+
auto l = std::visit(ArithmeticCast<std::int64_t>{}, *left);
2021+
auto r = std::visit(ArithmeticCast<std::int64_t>{}, *right);
2022+
if (r == 0) return std::nullopt;
2023+
return l / r;
2024+
}
2025+
2026+
case TokenKind::T_PERCENT: {
2027+
if (control()->is_unsigned(ast->type)) {
2028+
auto l = std::visit(ArithmeticCast<std::uint64_t>{}, *left);
2029+
auto r = std::visit(ArithmeticCast<std::uint64_t>{}, *right);
2030+
if (r == 0) return std::nullopt;
2031+
return l % r;
2032+
}
2033+
2034+
auto l = std::visit(ArithmeticCast<std::int64_t>{}, *left);
2035+
auto r = std::visit(ArithmeticCast<std::int64_t>{}, *right);
2036+
if (r == 0) return std::nullopt;
2037+
return l % r;
2038+
}
20192039

20202040
case TokenKind::T_PLUS:
20212041
if (control()->is_floating_point(ast->type))

src/parser/cxx/ast_rewriter.cc

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,35 +2343,28 @@ auto ASTRewriter::ExpressionVisitor::operator()(NestedExpressionAST* ast)
23432343

23442344
auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23452345
-> ExpressionAST* {
2346-
if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
2347-
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
2348-
auto initializerPtr =
2349-
std::get_if<ExpressionAST*>(&rewrite.templateArguments_[x->index()]);
2350-
if (!initializerPtr) {
2351-
cxx_runtime_error("expected initializer for non-type template parameter");
2352-
}
2353-
2354-
auto initializer = rewrite(*initializerPtr);
2355-
2356-
if (auto eq = ast_cast<EqualInitializerAST>(initializer)) {
2357-
return eq->expression;
2358-
}
2359-
2360-
if (auto bracedInit = ast_cast<BracedInitListAST>(initializer)) {
2361-
if (bracedInit->expressionList && !bracedInit->expressionList->next) {
2362-
return bracedInit->expressionList->value;
2363-
}
2364-
}
2365-
}
2366-
23672346
auto copy = make_node<IdExpressionAST>(arena());
23682347

23692348
copy->valueCategory = ast->valueCategory;
23702349
copy->type = ast->type;
23712350
copy->nestedNameSpecifier = rewrite(ast->nestedNameSpecifier);
23722351
copy->templateLoc = ast->templateLoc;
23732352
copy->unqualifiedId = rewrite(ast->unqualifiedId);
2353+
23742354
copy->symbol = ast->symbol;
2355+
2356+
if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
2357+
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
2358+
auto initializerPtr =
2359+
std::get_if<Symbol*>(&rewrite.templateArguments_[x->index()]);
2360+
if (!initializerPtr) {
2361+
cxx_runtime_error("expected initializer for non-type template parameter");
2362+
}
2363+
2364+
copy->symbol = *initializerPtr;
2365+
copy->type = copy->symbol->type();
2366+
}
2367+
23752368
copy->isTemplateIntroduced = ast->isTemplateIntroduced;
23762369

23772370
return copy;

src/parser/cxx/name_printer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct NamePrinter {
3636

3737
auto operator()(const ConstValue& value) const -> std::string { return {}; }
3838

39+
auto operator()(const Symbol* symbol) const -> std::string { return {}; }
40+
3941
auto operator()(ExpressionAST* value) const -> std::string { return {}; }
4042

4143
} template_argument_to_string;

src/parser/cxx/names_fwd.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ class Name;
4747
CXX_FOR_EACH_NAME(PROCESS_NAME)
4848
#undef PROCESS_NAME
4949

50-
using TemplateArgument = std::variant<const Type*, ConstValue, ExpressionAST*>;
50+
class Symbol;
51+
52+
using TemplateArgument =
53+
std::variant<const Type*, Symbol*, ConstValue, ExpressionAST*>;
5154

5255
enum class IdentifierInfoKind {
5356
kTypeTrait,

src/parser/cxx/symbols.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,14 @@ void VariableSymbol::setInitializer(ExpressionAST* initializer) {
563563
initializer_ = initializer;
564564
}
565565

566+
auto VariableSymbol::constValue() const -> const std::optional<ConstValue>& {
567+
return constValue_;
568+
}
569+
570+
void VariableSymbol::setConstValue(std::optional<ConstValue> value) {
571+
constValue_ = std::move(value);
572+
}
573+
566574
FieldSymbol::FieldSymbol(Scope* enclosingScope)
567575
: Symbol(Kind, enclosingScope) {}
568576

src/parser/cxx/symbols.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,10 +547,14 @@ class VariableSymbol final : public Symbol {
547547
[[nodiscard]] auto initializer() const -> ExpressionAST*;
548548
void setInitializer(ExpressionAST*);
549549

550+
[[nodiscard]] auto constValue() const -> const std::optional<ConstValue>&;
551+
void setConstValue(std::optional<ConstValue> value);
552+
550553
private:
551554
TemplateParametersSymbol* templateParameters_ = nullptr;
552555
TemplateDeclarationAST* templateDeclaration_ = nullptr;
553556
ExpressionAST* initializer_ = nullptr;
557+
std::optional<ConstValue> constValue_;
554558

555559
union {
556560
std::uint32_t flags_{};

tests/api_tests/test_rewriter.cc

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,39 +89,110 @@ using Func = void(T, const U&);
8989
TEST(Rewriter, Var) {
9090
auto source = R"(
9191
template <int i>
92-
const int c = i + 321;
92+
const int c = i + 321 + i;
9393
94-
constexpr int x = 123;
94+
constexpr int x = 123 * 2;
9595
96+
constexpr int y = c<123 * 2>;
9697
)"_cxx;
9798

99+
auto interp = ASTInterpreter{&source.unit};
100+
98101
auto control = source.control();
99102

100103
auto c = source.getAs<VariableSymbol>("c");
101104
ASSERT_TRUE(c != nullptr);
102105
auto templateDeclaration = c->templateDeclaration();
103106
ASSERT_TRUE(templateDeclaration != nullptr);
104107

108+
// extract the expression 123 * 2 from the AST
105109
auto x = source.getAs<VariableSymbol>("x");
106110
ASSERT_TRUE(x != nullptr);
107-
auto xinit = x->initializer();
111+
auto xinit = ast_cast<EqualInitializerAST>(x->initializer())->expression;
108112
ASSERT_TRUE(xinit != nullptr);
109113

114+
// synthesize const auto i = 123 * 2;
115+
116+
// ### need to set scope and location
117+
auto templArg = control->newVariableSymbol(nullptr, {});
118+
templArg->setInitializer(xinit);
119+
templArg->setType(control->add_const(x->type()));
120+
templArg->setConstValue(interp.evaluate(xinit));
121+
ASSERT_TRUE(templArg->constValue().has_value());
122+
110123
auto instance = subst(
111124
source, getTemplateBodyAs<SimpleDeclarationAST>(templateDeclaration),
112-
{xinit});
125+
{templArg});
113126

114127
auto decl = instance->initDeclaratorList->value;
115128
ASSERT_TRUE(decl != nullptr);
116129

117130
auto init = ast_cast<EqualInitializerAST>(decl->initializer);
118131
ASSERT_TRUE(init);
119132

120-
ASTInterpreter interp{&source.unit};
121-
122133
auto value = interp.evaluate(init->expression);
123134

124135
ASSERT_TRUE(value.has_value());
125136

126-
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 + 321);
137+
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 * 2 + 321 + 123 * 2);
138+
}
139+
140+
// simulate a template-id instantiation
141+
TEST(Rewriter, TemplateId) {
142+
auto source = R"(
143+
template <int i>
144+
const int c = i + 321 + i;
145+
146+
constexpr int y = c<123 * 2>;
147+
)"_cxx;
148+
149+
auto interp = ASTInterpreter{&source.unit};
150+
151+
auto control = source.control();
152+
153+
auto y = source.getAs<VariableSymbol>("y");
154+
ASSERT_TRUE(y != nullptr);
155+
auto yinit = ast_cast<EqualInitializerAST>(y->initializer())->expression;
156+
ASSERT_TRUE(yinit != nullptr);
157+
158+
auto idExpr = ast_cast<IdExpressionAST>(yinit);
159+
ASSERT_TRUE(idExpr != nullptr);
160+
161+
ASSERT_TRUE(idExpr->symbol);
162+
163+
auto templateId = ast_cast<SimpleTemplateIdAST>(idExpr->unqualifiedId);
164+
ASSERT_TRUE(templateId != nullptr);
165+
166+
// get the primary template declaration
167+
auto templateSym =
168+
symbol_cast<VariableSymbol>(templateId->primaryTemplateSymbol);
169+
ASSERT_TRUE(templateSym != nullptr);
170+
auto templateDecl = getTemplateBodyAs<SimpleDeclarationAST>(
171+
templateSym->templateDeclaration());
172+
ASSERT_TRUE(templateDecl != nullptr);
173+
174+
std::vector<TemplateArgument> templateArguments;
175+
for (auto arg : ListView{templateId->templateArgumentList}) {
176+
if (auto exprArg = ast_cast<ExpressionTemplateArgumentAST>(arg)) {
177+
auto expr = exprArg->expression;
178+
// ### need to set scope and location
179+
auto templArg = control->newVariableSymbol(nullptr, {});
180+
templArg->setInitializer(expr);
181+
templArg->setType(control->add_const(expr->type));
182+
templArg->setConstValue(interp.evaluate(expr));
183+
ASSERT_TRUE(templArg->constValue().has_value());
184+
templateArguments.push_back(templArg);
185+
}
186+
}
187+
188+
auto instance = subst(source, templateDecl, templateArguments);
189+
ASSERT_TRUE(instance != nullptr);
190+
191+
auto decl = instance->initDeclaratorList->value;
192+
ASSERT_TRUE(decl != nullptr);
193+
auto init = ast_cast<EqualInitializerAST>(decl->initializer);
194+
ASSERT_TRUE(init != nullptr);
195+
auto value = interp.evaluate(init->expression);
196+
ASSERT_TRUE(value.has_value());
197+
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 * 2 + 321 + 123 * 2);
127198
}

0 commit comments

Comments
 (0)