Skip to content

Commit dfa0a9f

Browse files
committed
Simulate template class instantiation
1 parent debe239 commit dfa0a9f

File tree

8 files changed

+131
-4
lines changed

8 files changed

+131
-4
lines changed

src/parser/cxx/ast_rewriter.cc

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cxx/control.h>
2828
#include <cxx/decl.h>
2929
#include <cxx/decl_specs.h>
30+
#include <cxx/scope.h>
3031
#include <cxx/symbols.h>
3132
#include <cxx/translation_unit.h>
3233
#include <cxx/type_checker.h>
@@ -1055,9 +1056,24 @@ auto ASTRewriter::operator()(InitDeclaratorAST* ast, const DeclSpecs& declSpecs)
10551056
auto copy = make_node<InitDeclaratorAST>(arena());
10561057

10571058
copy->declarator = operator()(ast->declarator);
1059+
1060+
auto decl = Decl{declSpecs, copy->declarator};
1061+
1062+
auto type = getDeclaratorType(translationUnit(), copy->declarator,
1063+
declSpecs.getType());
1064+
1065+
// ### fix scope
1066+
if (binder_.scope() && binder_.scope()->isClassScope()) {
1067+
auto symbol = binder_.declareMemberSymbol(ast->declarator, decl);
1068+
copy->symbol = symbol;
1069+
} else {
1070+
// ### TODO
1071+
copy->symbol = ast->symbol;
1072+
}
1073+
10581074
copy->requiresClause = operator()(ast->requiresClause);
10591075
copy->initializer = operator()(ast->initializer);
1060-
copy->symbol = ast->symbol;
1076+
// copy->symbol = ast->symbol; // TODO remove, done above
10611077

10621078
return copy;
10631079
}
@@ -3751,6 +3767,18 @@ auto ASTRewriter::SpecifierVisitor::operator()(ClassSpecifierAST* ast)
37513767
copy->finalLoc = ast->finalLoc;
37523768
copy->colonLoc = ast->colonLoc;
37533769

3770+
// ### TODO: use Binder::bind()
3771+
auto _ = Binder::ScopeGuard{binder()};
3772+
auto location = ast->symbol->location();
3773+
auto className = ast->symbol->name();
3774+
auto classSymbol = control()->newClassSymbol(binder()->scope(), location);
3775+
classSymbol->setName(className);
3776+
classSymbol->setIsUnion(ast->symbol->isUnion());
3777+
classSymbol->setFinal(ast->isFinal);
3778+
binder()->setScope(classSymbol);
3779+
3780+
copy->symbol = classSymbol;
3781+
37543782
for (auto baseSpecifierList = &copy->baseSpecifierList;
37553783
auto node : ListView{ast->baseSpecifierList}) {
37563784
auto value = rewrite(node);
@@ -3769,7 +3797,7 @@ auto ASTRewriter::SpecifierVisitor::operator()(ClassSpecifierAST* ast)
37693797

37703798
copy->rbraceLoc = ast->rbraceLoc;
37713799
copy->classKey = ast->classKey;
3772-
copy->symbol = ast->symbol;
3800+
// copy->symbol = ast->symbol; // TODO: remove done by the binder
37733801
copy->isFinal = ast->isFinal;
37743802

37753803
return copy;

src/parser/cxx/binder.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ void Binder::bind(ElaboratedTypeSpecifierAST* ast, DeclSpecs& declSpecs) {
161161
classSymbol->setIsUnion(isUnion);
162162
classSymbol->setName(className);
163163
classSymbol->setTemplateParameters(currentTemplateParameters());
164+
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
164165
declaringScope()->addSymbol(classSymbol);
166+
167+
classSymbol->setDeclaration(ast);
165168
}
166169

167170
ast->symbol = classSymbol;
@@ -234,6 +237,12 @@ void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
234237
}
235238
}
236239

240+
classSymbol->setDeclaration(ast);
241+
242+
if (declSpecs.templateHead) {
243+
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
244+
}
245+
237246
classSymbol->setFinal(ast->isFinal);
238247

239248
ast->symbol = classSymbol;

src/parser/cxx/decl_specs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class DeclSpecs {
4949

5050
ASTRewriter* rewriter = nullptr;
5151
TranslationUnit* unit = nullptr;
52+
TemplateDeclarationAST* templateHead = nullptr;
5253
const Type* type = nullptr;
5354
SpecifierAST* typeSpecifier = nullptr;
5455

src/parser/cxx/parser.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4024,6 +4024,7 @@ auto Parser::parse_simple_declaration(DeclarationAST*& yyast,
40244024
if (parse_notypespec_function_definition(yyast, attributes, ctx)) return true;
40254025

40264026
DeclSpecs specs{unit};
4027+
specs.templateHead = templateHead;
40274028
List<SpecifierAST*>* declSpecifierList = nullptr;
40284029

40294030
auto lookat_decl_specifiers = [&] {

src/parser/cxx/symbols.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,21 @@ auto ClassSymbol::hasBaseClass(
215215
return false;
216216
}
217217

218+
auto ClassSymbol::declaration() const -> SpecifierAST* { return specifier_; }
219+
220+
void ClassSymbol::setDeclaration(SpecifierAST* specifier) {
221+
specifier_ = specifier;
222+
}
223+
224+
auto ClassSymbol::templateDeclaration() const -> TemplateDeclarationAST* {
225+
return templateDeclaration_;
226+
}
227+
228+
void ClassSymbol::setTemplateDeclaration(
229+
TemplateDeclarationAST* templateDeclaration) {
230+
templateDeclaration_ = templateDeclaration;
231+
}
232+
218233
auto ClassSymbol::templateParameters() const -> TemplateParametersSymbol* {
219234
return templateInfo_ ? templateInfo_->templateParameters() : nullptr;
220235
}

src/parser/cxx/symbols.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class Symbol {
136136
[[nodiscard]] auto next() const -> Symbol*;
137137

138138
#define PROCESS_SYMBOL(S) \
139-
[[nodiscard]] auto is##S() const->bool { return kind_ == SymbolKind::k##S; }
139+
[[nodiscard]] auto is##S() const -> bool { return kind_ == SymbolKind::k##S; }
140140
CXX_FOR_EACH_SYMBOL(PROCESS_SYMBOL)
141141
#undef PROCESS_SYMBOL
142142

@@ -263,6 +263,12 @@ class ClassSymbol final : public ScopedSymbol {
263263
[[nodiscard]] auto flags() const -> std::uint32_t;
264264
void setFlags(std::uint32_t flags);
265265

266+
[[nodiscard]] auto declaration() const -> SpecifierAST*;
267+
void setDeclaration(SpecifierAST* ast);
268+
269+
[[nodiscard]] auto templateDeclaration() const -> TemplateDeclarationAST*;
270+
void setTemplateDeclaration(TemplateDeclarationAST* templateDeclaration);
271+
266272
[[nodiscard]] auto templateParameters() const -> TemplateParametersSymbol*;
267273
void setTemplateParameters(TemplateParametersSymbol* templateParameters);
268274

@@ -311,6 +317,8 @@ class ClassSymbol final : public ScopedSymbol {
311317
std::vector<BaseClassSymbol*> baseClasses_;
312318
std::vector<FunctionSymbol*> constructors_;
313319
std::unique_ptr<TemplateInfo<ClassSymbol>> templateInfo_;
320+
SpecifierAST* specifier_ = nullptr;
321+
TemplateDeclarationAST* templateDeclaration_ = nullptr;
314322
ClassSymbol* templateClass_ = nullptr;
315323
std::size_t templateSepcializationIndex_ = 0;
316324
int sizeInBytes_ = 0;
@@ -784,7 +792,7 @@ auto visit(Visitor&& visitor, Symbol* symbol) {
784792
}
785793

786794
#define PROCESS_SYMBOL(S) \
787-
inline auto is##S##Symbol(Symbol* symbol)->bool { \
795+
inline auto is##S##Symbol(Symbol* symbol) -> bool { \
788796
return symbol && symbol->kind() == SymbolKind::k##S; \
789797
}
790798

tests/api_tests/test_rewriter.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,4 +255,62 @@ const auto N = S<0, 1, 2>;
255255
const auto value = interp.evaluate(eq->expression);
256256
ASSERT_TRUE(value.has_value());
257257
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 0 * 0 + 1 * 1 + 2 * 2);
258+
}
259+
260+
TEST(Rewriter, Class) {
261+
auto source = R"(
262+
template <typename T1, typename T2>
263+
struct Pair {
264+
T1 first;
265+
T2 second;
266+
auto operator=(const Pair& other) -> Pair&;
267+
};
268+
269+
using Pair1 = Pair<int, float*>;
270+
271+
)"_cxx_no_templates;
272+
273+
auto control = source.control();
274+
275+
auto pair = source.getAs<ClassSymbol>("Pair");
276+
ASSERT_TRUE(pair != nullptr);
277+
278+
ASSERT_TRUE(pair->declaration() != nullptr);
279+
280+
auto classDecl = ast_cast<ClassSpecifierAST>(pair->declaration());
281+
ASSERT_TRUE(classDecl != nullptr);
282+
283+
auto templateDecl = pair->templateDeclaration();
284+
ASSERT_TRUE(templateDecl != nullptr);
285+
286+
auto pair1 = source.getAs<TypeAliasSymbol>("Pair1");
287+
ASSERT_TRUE(pair1 != nullptr);
288+
289+
auto pair1Type = type_cast<UnresolvedNameType>(pair1->type());
290+
ASSERT_TRUE(pair1Type != nullptr);
291+
292+
auto templateId = ast_cast<SimpleTemplateIdAST>(pair1Type->unqualifiedId());
293+
ASSERT_TRUE(templateId != nullptr);
294+
295+
auto templateArguments = make_substitution(&source.unit, templateDecl,
296+
templateId->templateArgumentList);
297+
298+
auto instance = substitute(source, classDecl, templateArguments);
299+
ASSERT_TRUE(instance != nullptr);
300+
301+
auto classDeclInstance = ast_cast<ClassSpecifierAST>(instance);
302+
ASSERT_TRUE(classDeclInstance != nullptr);
303+
304+
auto classInstance = classDeclInstance->symbol;
305+
306+
ASSERT_TRUE(classInstance != nullptr);
307+
308+
std::ostringstream os;
309+
dump(os, classDeclInstance->symbol);
310+
311+
ASSERT_EQ(os.str(), R"(class Pair
312+
field int first
313+
field float* second
314+
function ::Pair& operator =(const ::Pair&)
315+
)");
258316
}

tests/api_tests/test_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <cxx/ast.h>
2424
#include <cxx/control.h>
25+
#include <cxx/memory_layout.h>
2526
#include <cxx/names.h>
2627
#include <cxx/scope.h>
2728
#include <cxx/symbols.h>
@@ -40,10 +41,16 @@ inline auto dump_symbol(Symbol* symbol) -> std::string {
4041
}
4142

4243
struct Source {
44+
std::unique_ptr<MemoryLayout> memoryLayout;
4345
DiagnosticsClient diagnosticsClient;
4446
TranslationUnit unit{&diagnosticsClient};
4547

4648
explicit Source(std::string_view source, bool templateInstantiation = true) {
49+
// default to wasm32 memory layout
50+
auto memoryLayout = std::make_unique<MemoryLayout>(32);
51+
52+
unit.control()->setMemoryLayout(memoryLayout.get());
53+
4754
unit.setSource(std::string(source), "<test>");
4855

4956
unit.parse({

0 commit comments

Comments
 (0)