diff --git a/src/parser/cxx/ast_interpreter.cc b/src/parser/cxx/ast_interpreter.cc index 34fdae59..002dfcbb 100644 --- a/src/parser/cxx/ast_interpreter.cc +++ b/src/parser/cxx/ast_interpreter.cc @@ -34,6 +34,34 @@ namespace cxx { namespace { +struct ToInt { + auto operator()(bool v) const -> std::optional { + return v ? 1 : 0; + } + + auto operator()(std::intmax_t v) const -> std::optional { + return v; + } + + auto operator()(auto x) const -> std::optional { + return std::nullopt; // Unsupported type for int conversion + } +}; + +struct ToUInt { + auto operator()(bool v) const -> std::optional { + return v ? 1 : 0; + } + + auto operator()(std::intmax_t v) const -> std::optional { + return std::bit_cast(v); + } + + auto operator()(auto x) const -> std::optional { + return std::nullopt; + } +}; + template struct ArithmeticCast { auto operator()(const StringLiteral*) const -> T { @@ -285,6 +313,10 @@ struct ASTInterpreter::ExpressionVisitor { [[nodiscard]] auto control() -> Control* { return accept.control(); } + [[nodiscard]] auto memoryLayout() -> MemoryLayout* { + return control()->memoryLayout(); + } + [[nodiscard]] auto evaluate(ExpressionAST* ast) -> ExpressionResult { return accept(ast); } @@ -297,10 +329,26 @@ struct ASTInterpreter::ExpressionVisitor { return accept.toInt(value).value_or(0); } + [[nodiscard]] auto toInt32(const ConstValue& value) -> std::int32_t { + return static_cast(toInt(value)); + } + + [[nodiscard]] auto toInt64(const ConstValue& value) -> std::int64_t { + return static_cast(toInt(value)); + } + [[nodiscard]] auto toUInt(const ConstValue& value) -> std::uintmax_t { return accept.toUInt(value).value_or(0); } + [[nodiscard]] auto toUInt32(const ConstValue& value) -> std::uint32_t { + return static_cast(toUInt(value)); + } + + [[nodiscard]] auto toUInt64(const ConstValue& value) -> std::uint64_t { + return static_cast(toUInt(value)); + } + [[nodiscard]] auto toFloat(const ConstValue& value) -> float { return accept.toFloat(value).value_or(0.0f); } @@ -309,99 +357,162 @@ struct ASTInterpreter::ExpressionVisitor { return accept.toDouble(value).value_or(0.0); } + [[nodiscard]] auto toValue(std::uintmax_t value) -> ConstValue { + return ConstValue(std::bit_cast(value)); + } + auto star_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) { + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) { return toDouble(*left) * toDouble(*right); } - if (control()->is_unsigned(ast->type)) { - return toUInt(*left) * toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toValue(toUInt32(*left) * toUInt32(*right)); + return toValue(toUInt64(*left) * toUInt64(*right)); } - return toInt(*left) * toInt(*right); + if (sz <= 4) return toValue(toInt32(*left) * toInt32(*right)); + return toValue(toInt64(*left) * toInt64(*right)); } auto slash_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) { + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) { auto l = toDouble(*left); auto r = toDouble(*right); if (r == 0.0) return std::nullopt; return l / r; } - if (control()->is_unsigned(ast->type)) { - auto l = toUInt(*left); - auto r = toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) { + auto l = toUInt32(*left); + auto r = toUInt32(*right); + if (r == 0) return std::nullopt; + return toValue(l / r); + } + + auto l = toUInt64(*left); + auto r = toUInt64(*right); if (r == 0) return std::nullopt; - return l / r; + return toValue(l / r); } - auto l = toInt(*left); - auto r = toInt(*right); + if (sz <= 4) { + auto l = toInt32(*left); + auto r = toInt32(*right); + if (r == 0) return std::nullopt; + return toValue(l / r); + } + + auto l = toInt64(*left); + auto r = toInt64(*right); if (r == 0) return std::nullopt; - return l / r; + return toValue(l / r); } auto percent_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_unsigned(ast->type)) { - auto l = toUInt(*left); - auto r = toUInt(*right); + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_unsigned(type)) { + if (sz <= 4) { + auto l = toUInt32(*left); + auto r = toUInt32(*right); + if (r == 0) return std::nullopt; + return toValue(l % r); + } + + auto l = toUInt64(*left); + auto r = toUInt64(*right); if (r == 0) return std::nullopt; - return l % r; + return toValue(l % r); } - auto l = toInt(*left); - auto r = toInt(*right); + if (sz <= 4) { + auto l = toInt32(*left); + auto r = toInt32(*right); + if (r == 0) return std::nullopt; + return toValue(l % r); + } + + auto l = toInt64(*left); + auto r = toInt64(*right); if (r == 0) return std::nullopt; - return l % r; + return toValue(l % r); } auto plus_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) { + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) { return toDouble(*left) + toDouble(*right); } - if (control()->is_unsigned(ast->type)) { - return toUInt(*left) + toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toValue(toUInt32(*left) + toUInt32(*right)); + return toValue(toUInt64(*left) + toUInt64(*right)); } - return toInt(*left) + toInt(*right); + if (sz <= 4) return toValue(toInt32(*left) + toInt32(*right)); + return toValue(toInt64(*left) + toInt64(*right)); } auto minus_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) { + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) { return toDouble(*left) - toDouble(*right); } - if (control()->is_unsigned(ast->type)) { - return toUInt(*left) - toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toValue(toUInt32(*left) - toUInt32(*right)); + return toValue(toUInt64(*left) - toUInt64(*right)); } - return toInt(*left) - toInt(*right); + if (sz <= 4) return toValue(toInt32(*left) - toInt32(*right)); + return toValue(toInt64(*left) - toInt64(*right)); } auto less_less_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_unsigned(ast->type)) { - return toUInt(*left) << toUInt(*right); + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_unsigned(type)) { + if (sz <= 4) return toValue(toUInt32(*left) << toUInt32(*right)); + return toValue(toUInt64(*left) << toUInt64(*right)); } - return toInt(*left) << toInt(*right); + if (sz <= 4) return toValue(toInt32(*left) << toInt32(*right)); + return toValue(toInt64(*left) << toInt64(*right)); } auto greater_greater_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_unsigned(ast->type)) { - return toUInt(*left) >> toUInt(*right); + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_unsigned(type)) { + if (sz <= 4) return toValue(toUInt32(*left) >> toUInt32(*right)); + return toValue(toUInt64(*left) >> toUInt64(*right)); } - return toInt(*left) >> toInt(*right); + if (sz <= 4) return toValue(toInt32(*left) >> toInt32(*right)); + return toValue(toInt64(*left) >> toInt64(*right)); } auto less_equal_greater_op(BinaryExpressionAST* ast, @@ -414,80 +525,121 @@ struct ASTInterpreter::ExpressionVisitor { return 0; }; - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return convert(toDouble(*left) <=> toDouble(*right)); - if (control()->is_unsigned(ast->type)) - return convert(toUInt(*left) <=> toUInt(*right)); + if (control()->is_unsigned(type)) { + if (sz <= 4) return convert(toUInt32(*left) <=> toUInt32(*right)); + return convert(toUInt64(*left) <=> toUInt64(*right)); + } - return convert(toInt(*left) <=> toInt(*right)); + if (sz <= 4) return convert(toInt32(*left) <=> toInt32(*right)); + return convert(toInt64(*left) <=> toInt64(*right)); } auto less_equal_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) <= toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) <= toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) <= toUInt(*right); + return toUInt64(*left) <= toUInt64(*right); + } - return toInt(*left) <= toInt(*right); + if (sz <= 4) return toInt(*left) <= toInt(*right); + return toInt64(*left) <= toInt64(*right); } auto greater_equal_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) >= toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) >= toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) >= toUInt(*right); + return toUInt64(*left) >= toUInt64(*right); + } - else - return toInt(*left) >= toInt(*right); + if (sz <= 4) return toInt(*left) >= toInt(*right); + return toInt64(*left) >= toInt64(*right); } auto less_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) < toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) < toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) < toUInt(*right); + return toUInt64(*left) < toUInt64(*right); + } - return toInt(*left) < toInt(*right); + if (sz <= 4) return toInt(*left) < toInt(*right); + return toInt64(*left) < toInt64(*right); } auto greater_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) > toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) > toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) > toUInt(*right); + return toUInt64(*left) > toUInt64(*right); + } - return toInt(*left) > toInt(*right); + if (sz <= 4) return toInt(*left) > toInt(*right); + return toInt64(*left) > toInt64(*right); } auto equal_equal_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) == toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) == toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) == toUInt(*right); + return toUInt64(*left) == toUInt64(*right); + } - return toInt(*left) == toInt(*right); + if (sz <= 4) return toInt(*left) == toInt(*right); + return toInt64(*left) == toInt64(*right); } auto exclaim_equal_op(BinaryExpressionAST* ast, const ExpressionResult& left, const ExpressionResult& right) -> ExpressionResult { - if (control()->is_floating_point(ast->type)) + const auto type = ast->leftExpression->type; + const auto sz = memoryLayout()->sizeOf(type); + + if (control()->is_floating_point(type)) return toDouble(*left) != toDouble(*right); - if (control()->is_unsigned(ast->type)) - return toUInt(*left) != toUInt(*right); + if (control()->is_unsigned(type)) { + if (sz <= 4) return toUInt(*left) != toUInt(*right); + return toUInt64(*left) != toUInt64(*right); + } - return toInt(*left) != toInt(*right); + if (sz <= 4) return toInt(*left) != toInt(*right); + return toInt64(*left) != toInt64(*right); } auto amp_op(BinaryExpressionAST* ast, const ExpressionResult& left, @@ -1857,7 +2009,8 @@ auto ASTInterpreter::ExpressionVisitor::operator()( auto ASTInterpreter::ExpressionVisitor::operator()(IntLiteralExpressionAST* ast) -> ExpressionResult { - return ConstValue(ast->literal->integerValue()); + const auto value = static_cast(ast->literal->integerValue()); + return ExpressionResult{std::bit_cast(value)}; } auto ASTInterpreter::ExpressionVisitor::operator()( @@ -1867,7 +2020,7 @@ auto ASTInterpreter::ExpressionVisitor::operator()( auto ASTInterpreter::ExpressionVisitor::operator()( NullptrLiteralExpressionAST* ast) -> ExpressionResult { - return ConstValue(std::uintmax_t(0)); + return ConstValue{std::intmax_t(0)}; } auto ASTInterpreter::ExpressionVisitor::operator()( @@ -2176,19 +2329,19 @@ auto ASTInterpreter::ExpressionVisitor::operator()(AwaitExpressionAST* ast) auto ASTInterpreter::ExpressionVisitor::operator()(SizeofExpressionAST* ast) -> ExpressionResult { if (!ast->expression || !ast->expression->type) return std::nullopt; - if (!control()->memoryLayout()) return std::nullopt; - auto size = control()->memoryLayout()->sizeOf(ast->expression->type); + auto size = memoryLayout()->sizeOf(ast->expression->type); if (!size.has_value()) return std::nullopt; - return std::uintmax_t(*size); + return ExpressionResult( + std::bit_cast(static_cast(*size))); } auto ASTInterpreter::ExpressionVisitor::operator()(SizeofTypeExpressionAST* ast) -> ExpressionResult { if (!ast->typeId || !ast->typeId->type) return std::nullopt; - if (!control()->memoryLayout()) return std::nullopt; - auto size = control()->memoryLayout()->sizeOf(ast->typeId->type); + auto size = memoryLayout()->sizeOf(ast->typeId->type); if (!size.has_value()) return std::nullopt; - return std::uintmax_t(*size); + return ExpressionResult( + std::bit_cast(static_cast(*size))); } auto ASTInterpreter::ExpressionVisitor::operator()(SizeofPackExpressionAST* ast) @@ -2199,10 +2352,10 @@ auto ASTInterpreter::ExpressionVisitor::operator()(SizeofPackExpressionAST* ast) auto ASTInterpreter::ExpressionVisitor::operator()( AlignofTypeExpressionAST* ast) -> ExpressionResult { if (!ast->typeId || !ast->typeId->type) return std::nullopt; - if (!control()->memoryLayout()) return std::nullopt; - auto size = control()->memoryLayout()->alignmentOf(ast->typeId->type); + auto size = memoryLayout()->alignmentOf(ast->typeId->type); if (!size.has_value()) return std::nullopt; - return std::uintmax_t(*size); + return ExpressionResult( + std::bit_cast(static_cast(*size))); } auto ASTInterpreter::ExpressionVisitor::operator()(AlignofExpressionAST* ast) @@ -2210,10 +2363,10 @@ auto ASTInterpreter::ExpressionVisitor::operator()(AlignofExpressionAST* ast) auto expressionResult = accept(ast->expression); if (!ast->expression || !ast->expression->type) return std::nullopt; - if (!control()->memoryLayout()) return std::nullopt; - auto size = control()->memoryLayout()->alignmentOf(ast->expression->type); + auto size = memoryLayout()->alignmentOf(ast->expression->type); if (!size.has_value()) return std::nullopt; - return std::uintmax_t(*size); + return ExpressionResult( + std::bit_cast(static_cast(*size))); } auto ASTInterpreter::ExpressionVisitor::operator()(NoexceptExpressionAST* ast) @@ -2289,7 +2442,7 @@ auto ASTInterpreter::ExpressionVisitor::operator()( if (control()->is_unsigned(ast->type)) { auto result = accept.toUInt(*value); if (!result.has_value()) return std::nullopt; - return result.value(); + return ConstValue{std::bit_cast(result.value())}; } auto result = accept.toInt(*value); @@ -3402,12 +3555,12 @@ auto ASTInterpreter::toBool(const ConstValue& value) -> std::optional { auto ASTInterpreter::toInt(const ConstValue& value) -> std::optional { - return std::visit(ArithmeticCast{}, value); + return std::visit(ToInt{}, value); } auto ASTInterpreter::toUInt(const ConstValue& value) -> std::optional { - return std::visit(ArithmeticCast{}, value); + return std::visit(ToUInt{}, value); } auto ASTInterpreter::toFloat(const ConstValue& value) -> std::optional { diff --git a/src/parser/cxx/const_value.h b/src/parser/cxx/const_value.h index 4e8ae025..b6ebf6dc 100644 --- a/src/parser/cxx/const_value.h +++ b/src/parser/cxx/const_value.h @@ -28,8 +28,7 @@ namespace cxx { -using ConstValue = - std::variant; +using ConstValue = std::variant; } // namespace cxx \ No newline at end of file diff --git a/src/parser/cxx/names.cc b/src/parser/cxx/names.cc index 735177a3..1f347b7f 100644 --- a/src/parser/cxx/names.cc +++ b/src/parser/cxx/names.cc @@ -40,9 +40,6 @@ struct ConstValueHash { auto operator()(std::intmax_t value) const -> std::size_t { return std::hash{}(value); } - auto operator()(std::uintmax_t value) const -> std::size_t { - return std::hash{}(value); - } auto operator()(float value) const -> std::size_t { return std::hash{}(value); } diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 11491213..25aa1b23 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -6628,7 +6628,7 @@ void Parser::parse_enumerator_list(List*& yyast, if (control_->is_unsigned(type)) { // increment the last value as unsigned if (auto v = interp.toUInt(lastValue.value())) { - lastValue = v.value() + 1; + lastValue = std::bit_cast(v.value() + 1); } else { lastValue = std::nullopt; } diff --git a/src/parser/cxx/symbol_printer.cc b/src/parser/cxx/symbol_printer.cc index bf086c6a..2f4f95c7 100644 --- a/src/parser/cxx/symbol_printer.cc +++ b/src/parser/cxx/symbol_printer.cc @@ -41,10 +41,6 @@ struct GetEnumeratorValue { return std::to_string(value); } - auto operator()(std::uintmax_t value) const -> std::string { - return std::to_string(value); - } - auto operator()(auto x) const -> std::string { return {}; } }; diff --git a/tests/unit_tests/sema/constant_expression_01.cc b/tests/unit_tests/sema/constant_expression_01.cc index ae766447..14582a70 100644 --- a/tests/unit_tests/sema/constant_expression_01.cc +++ b/tests/unit_tests/sema/constant_expression_01.cc @@ -15,3 +15,5 @@ static_assert(1 & 1); static_assert(0b1 | 0b10 == 0b11); static_assert(1 ^ 1 == 0); static_assert(1 + 1 != 2 - 1); + +static_assert(0x8000'0000 * 2 == 0); \ No newline at end of file