Skip to content

Commit 9ea7693

Browse files
committed
Rewrite the type parameters
1 parent 14c75ad commit 9ea7693

File tree

11 files changed

+259
-67
lines changed

11 files changed

+259
-67
lines changed

packages/cxx-gen-ast/src/new_ast_rewriter_cc.ts

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,43 @@ export function new_ast_rewriter_cc({
9595
`);
9696
}
9797

98+
let typeAttr: Member | undefined;
99+
98100
members.forEach((m) => {
99101
if (m === blockSymbol) return;
102+
if (m === typeAttr) return;
100103

101104
switch (m.kind) {
102105
case "node": {
103106
if (isBase(m.type)) {
104-
emit(` copy->${m.name} = ${visitor}(ast->${m.name});`);
107+
emit(`copy->${m.name} = ${visitor}(ast->${m.name});`);
108+
109+
switch (m.type) {
110+
case "DeclaratorAST":
111+
const specsAttr =
112+
members.find((m) => m.name == "typeSpecifierList")?.name ??
113+
members.find((m) => m.name == "declSpecifierList")?.name;
114+
if (specsAttr) {
115+
emit();
116+
emit(
117+
`auto ${m.name}Type = getDeclaratorType(translationUnit(), copy->${m.name}, ${specsAttr}Ctx.getType());`
118+
);
119+
120+
typeAttr = members.find(
121+
(m) => m.kind === "attribute" && m.name === "type"
122+
);
123+
124+
if (typeAttr) {
125+
emit(`copy->${typeAttr.name} = ${m.name}Type;`);
126+
}
127+
}
128+
break;
129+
default:
130+
break;
131+
} // switch
105132
} else {
106133
emit(
107-
` copy->${m.name} = ast_cast<${m.type}>(${visitor}(ast->${m.name}));`
134+
`copy->${m.name} = ast_cast<${m.type}>(${visitor}(ast->${m.name}));`
108135
);
109136
}
110137
break;

src/frontend/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

20-
aux_source_directory(cxx SOURCES)
2120

22-
add_executable(cxx ${SOURCES})
21+
add_executable(cxx
22+
cxx/ast_printer.cc
23+
cxx/frontend.cc
24+
cxx/verify_diagnostics_client.cc
25+
)
2326

2427
target_link_libraries(cxx PRIVATE cxx-lsp)
2528

src/parser/cxx/ast_rewriter.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,10 @@ auto ASTRewriter::operator()(TypeIdAST* ast) -> TypeIdAST* {
10911091
}
10921092

10931093
copy->declarator = operator()(ast->declarator);
1094-
copy->type = ast->type;
1094+
1095+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
1096+
typeSpecifierListCtx.getType());
1097+
copy->type = declaratorType;
10951098

10961099
return copy;
10971100
}
@@ -1570,6 +1573,9 @@ auto ASTRewriter::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
15701573
}
15711574

15721575
copy->declarator = rewrite(ast->declarator);
1576+
1577+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
1578+
declSpecifierListCtx.getType());
15731579
copy->requiresClause = rewrite(ast->requiresClause);
15741580
copy->functionBody = rewrite(ast->functionBody);
15751581
copy->symbol = ast->symbol;
@@ -1805,9 +1811,12 @@ auto ASTRewriter::DeclarationVisitor::operator()(ParameterDeclarationAST* ast)
18051811
}
18061812

18071813
copy->declarator = rewrite(ast->declarator);
1814+
1815+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
1816+
typeSpecifierListCtx.getType());
1817+
copy->type = declaratorType;
18081818
copy->equalLoc = ast->equalLoc;
18091819
copy->expression = rewrite(ast->expression);
1810-
copy->type = ast->type;
18111820
copy->identifier = ast->identifier;
18121821
copy->isThisIntroduced = ast->isThisIntroduced;
18131822
copy->isPack = ast->isPack;
@@ -2886,6 +2895,9 @@ auto ASTRewriter::ExpressionVisitor::operator()(NewExpressionAST* ast)
28862895
}
28872896

28882897
copy->declarator = rewrite(ast->declarator);
2898+
2899+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
2900+
typeSpecifierListCtx.getType());
28892901
copy->rparenLoc = ast->rparenLoc;
28902902
copy->newInitalizer = rewrite(ast->newInitalizer);
28912903

@@ -3072,6 +3084,9 @@ auto ASTRewriter::ExpressionVisitor::operator()(ConditionExpressionAST* ast)
30723084
}
30733085

30743086
copy->declarator = rewrite(ast->declarator);
3087+
3088+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
3089+
declSpecifierListCtx.getType());
30753090
copy->initializer = rewrite(ast->initializer);
30763091
copy->symbol = ast->symbol;
30773092

@@ -4309,6 +4324,9 @@ auto ASTRewriter::ExceptionDeclarationVisitor::operator()(
43094324

43104325
copy->declarator = rewrite(ast->declarator);
43114326

4327+
auto declaratorType = getDeclaratorType(translationUnit(), copy->declarator,
4328+
typeSpecifierListCtx.getType());
4329+
43124330
return copy;
43134331
}
43144332

src/parser/cxx/decl_specs.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,21 @@ void DeclSpecs::Visitor::operator()(ComplexTypeSpecifierAST* ast) {
256256
void DeclSpecs::Visitor::operator()(NamedTypeSpecifierAST* ast) {
257257
specs.typeSpecifier = ast;
258258

259+
if (specs.rewriter) {
260+
auto typeParameter = symbol_cast<TypeParameterSymbol>(ast->symbol);
261+
const auto& args = specs.rewriter->templateArguments();
262+
263+
if (typeParameter && typeParameter->depth() == 0 &&
264+
typeParameter->index() < args.size()) {
265+
auto index = typeParameter->index();
266+
267+
if (auto ty = std::get_if<const Type*>(&args[index])) {
268+
specs.type = *ty;
269+
return;
270+
}
271+
}
272+
}
273+
259274
if (ast->symbol)
260275
specs.type = ast->symbol->type();
261276
else

src/parser/cxx/parser.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3813,6 +3813,8 @@ auto Parser::parse_alias_declaration(DeclarationAST*& yyast) -> bool {
38133813
ast->semicolonLoc = semicolonLoc;
38143814
ast->symbol = symbol;
38153815

3816+
ast->symbol->setDeclaration(ast);
3817+
38163818
return true;
38173819
}
38183820

src/parser/cxx/symbols.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,14 @@ void TypeAliasSymbol::setTemplateParameters(
493493
templateParameters_ = templateParameters;
494494
}
495495

496+
auto TypeAliasSymbol::declaration() const -> AliasDeclarationAST* {
497+
return declaration_;
498+
}
499+
500+
void TypeAliasSymbol::setDeclaration(AliasDeclarationAST* declaration) {
501+
declaration_ = declaration;
502+
}
503+
496504
VariableSymbol::VariableSymbol(Scope* enclosingScope)
497505
: Symbol(Kind, enclosingScope) {}
498506

src/parser/cxx/symbols.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,12 @@ class TypeAliasSymbol final : public Symbol {
505505
[[nodiscard]] auto templateParameters() const -> TemplateParametersSymbol*;
506506
void setTemplateParameters(TemplateParametersSymbol* templateParameters);
507507

508+
[[nodiscard]] auto declaration() const -> AliasDeclarationAST*;
509+
void setDeclaration(AliasDeclarationAST* declaration);
510+
508511
private:
509512
TemplateParametersSymbol* templateParameters_ = nullptr;
513+
AliasDeclarationAST* declaration_ = nullptr;
510514
};
511515

512516
class VariableSymbol final : public Symbol {

tests/api_tests/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ if (CMAKE_SYSTEM_NAME STREQUAL "WASI")
2222
return()
2323
endif()
2424

25-
aux_source_directory(. SOURCES)
26-
add_executable(test_api ${SOURCES})
25+
add_executable(test_api
26+
test_control.cc
27+
test_external_names.cc
28+
test_rewriter.cc
29+
test_substitution.cc
30+
test_type_printer.cc
31+
)
2732

2833
target_link_libraries(test_api
2934
GTest::gtest_main

tests/api_tests/test_rewriter.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) 2025 Roberto Raggi <[email protected]>
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a copy
4+
// of this software and associated documentation files (the "Software"), to deal
5+
// in the Software without restriction, including without limitation the rights
6+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
// copies of the Software, and to permit persons to whom the Software is
8+
// furnished to do so, subject to the following conditions:
9+
//
10+
// The above copyright notice and this permission notice shall be included in
11+
// all copies or substantial portions of the Software.
12+
//
13+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19+
// SOFTWARE.
20+
21+
#include <cxx/ast.h>
22+
#include <cxx/ast_rewriter.h>
23+
#include <cxx/control.h>
24+
#include <cxx/names.h>
25+
#include <cxx/scope.h>
26+
#include <cxx/symbol_instantiation.h>
27+
#include <cxx/symbols.h>
28+
#include <cxx/translation_unit.h>
29+
#include <cxx/type_checker.h>
30+
#include <cxx/types.h>
31+
#include <gtest/gtest.h>
32+
33+
#include <format>
34+
#include <iostream>
35+
#include <sstream>
36+
37+
#include "test_utils.h"
38+
39+
using namespace cxx;
40+
41+
template <typename Node>
42+
auto subst(Source& source, Node* ast, std::vector<TemplateArgument> args) {
43+
auto control = source.control();
44+
TypeChecker typeChecker(&source.unit);
45+
ASTRewriter rewrite{&typeChecker, args};
46+
return ast_cast<Node>(rewrite(ast));
47+
};
48+
49+
TEST(Rewriter, TypeAlias) {
50+
auto source = R"(
51+
template <typename T>
52+
using Ptr = const T*;
53+
54+
template <typename T, typename U>
55+
using Func = void(T, U);
56+
)"_cxx;
57+
58+
auto control = source.control();
59+
60+
auto ptrTypeAlias =
61+
subst(source, source.getAs<TypeAliasSymbol>("Ptr")->declaration(),
62+
{control->getIntType()});
63+
64+
ASSERT_EQ(to_string(ptrTypeAlias->typeId->type), "const int*");
65+
66+
auto funcTypeAlias =
67+
subst(source, source.getAs<TypeAliasSymbol>("Func")->declaration(),
68+
{control->getIntType(), control->getFloatType()});
69+
70+
ASSERT_EQ(to_string(funcTypeAlias->typeId->type), "void (int, float)");
71+
}

tests/api_tests/test_substitution.cc

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -32,65 +32,7 @@
3232
#include <iostream>
3333
#include <sstream>
3434

35-
namespace cxx {
36-
37-
auto dump_symbol(Symbol* symbol) -> std::string {
38-
std::ostringstream out;
39-
out << symbol;
40-
return out.str();
41-
}
42-
43-
struct Source {
44-
DiagnosticsClient diagnosticsClient;
45-
TranslationUnit unit{&diagnosticsClient};
46-
47-
explicit Source(std::string_view source) {
48-
unit.setSource(std::string(source), "<test>");
49-
unit.parse({
50-
.checkTypes = true,
51-
.fuzzyTemplateResolution = false,
52-
.reflect = true,
53-
});
54-
}
55-
56-
auto control() -> Control* { return unit.control(); }
57-
auto ast() -> UnitAST* { return unit.ast(); }
58-
auto scope() -> Scope* { return unit.globalScope(); }
59-
60-
auto get(std::string_view name) -> Symbol* {
61-
Symbol* symbol = nullptr;
62-
auto id = unit.control()->getIdentifier(name);
63-
for (auto candidate : scope()->find(id)) {
64-
if (symbol) return nullptr;
65-
symbol = candidate;
66-
}
67-
return symbol;
68-
}
69-
70-
auto instantiate(std::string_view name,
71-
const std::vector<TemplateArgument>& arguments) -> Symbol* {
72-
auto symbol = get(name);
73-
return control()->instantiate(&unit, symbol, arguments);
74-
}
75-
};
76-
77-
auto operator""_cxx(const char* source, std::size_t size) -> Source {
78-
return Source{std::string_view{source, size}};
79-
}
80-
81-
struct LookupMember {
82-
Source& source;
83-
84-
auto operator()(Scope* scope, std::string_view name) -> Symbol* {
85-
auto id = source.control()->getIdentifier(name);
86-
for (auto candidate : scope->find(id)) {
87-
return candidate;
88-
}
89-
return nullptr;
90-
}
91-
};
92-
93-
} // namespace cxx
35+
#include "test_utils.h"
9436

9537
using namespace cxx;
9638

0 commit comments

Comments
 (0)