diff --git a/duckdb b/duckdb index bf7c548..7c0cc5d 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit bf7c5488f6eacfef3f719f1979aa9c01b0573185 +Subproject commit 7c0cc5d4943dd4fe2176a43818f7dfcc9a541b91 diff --git a/src/expr/expr.cpp b/src/expr/expr.cpp index 51d74d4..eeb2917 100644 --- a/src/expr/expr.cpp +++ b/src/expr/expr.cpp @@ -144,6 +144,13 @@ vortex::dtype::DType *into_vortex_dtype(Arena &arena, const LogicalType &type_, dtype->mutable_primitive()->set_nullable(nullable); dtype->mutable_primitive()->set_type(vortex::dtype::F64); return dtype; + case LogicalTypeId::DECIMAL: { + dtype->mutable_decimal()->set_nullable(nullable); + auto decimal = dtype->mutable_decimal(); + decimal->set_precision(duckdb::DecimalType::GetWidth(type_)); + decimal->set_scale(duckdb::DecimalType::GetScale(type_)); + return dtype; + } case LogicalTypeId::CHAR: case LogicalTypeId::VARCHAR: dtype->mutable_utf8()->set_nullable(nullable); @@ -195,7 +202,7 @@ vortex::scalar::Scalar *into_null_scalar(Arena &arena, LogicalType &logical_type vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, bool nullable) { auto scalar = Arena::Create(&arena); - auto dtype = into_vortex_dtype(arena, value.type().id(), nullable); + auto dtype = into_vortex_dtype(arena, value.type(), nullable); scalar->set_allocated_dtype(dtype); switch (value.type().id()) { @@ -209,25 +216,25 @@ vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, boo return scalar; } case LogicalTypeId::TINYINT: - scalar->mutable_value()->set_int8_value(value.GetValue()); + scalar->mutable_value()->set_int64_value(value.GetValue()); return scalar; case LogicalTypeId::SMALLINT: - scalar->mutable_value()->set_int16_value(value.GetValue()); + scalar->mutable_value()->set_int64_value(value.GetValue()); return scalar; case LogicalTypeId::INTEGER: - scalar->mutable_value()->set_int32_value(value.GetValue()); + scalar->mutable_value()->set_int64_value(value.GetValue()); return 
scalar; case LogicalTypeId::BIGINT: scalar->mutable_value()->set_int64_value(value.GetValue()); return scalar; case LogicalTypeId::UTINYINT: - scalar->mutable_value()->set_uint8_value(value.GetValue()); + scalar->mutable_value()->set_uint64_value(value.GetValue()); return scalar; case LogicalTypeId::USMALLINT: - scalar->mutable_value()->set_uint16_value(value.GetValue()); + scalar->mutable_value()->set_uint64_value(value.GetValue()); return scalar; case LogicalTypeId::UINTEGER: - scalar->mutable_value()->set_uint32_value(value.GetValue()); + scalar->mutable_value()->set_uint64_value(value.GetValue()); return scalar; case LogicalTypeId::UBIGINT: scalar->mutable_value()->set_uint64_value(value.GetValue()); @@ -238,14 +245,24 @@ vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, boo case LogicalTypeId::DOUBLE: scalar->mutable_value()->set_f64_value(value.GetValue()); return scalar; + case LogicalTypeId::DECIMAL: { + auto huge = value.GetValue(); + uint32_t out[4]; + out[0] = static_cast(huge); + out[1] = static_cast(huge >> 32); + out[2] = static_cast(huge >> 64); + out[3] = static_cast(huge >> 96); + scalar->mutable_value()->set_bytes_value(std::string(reinterpret_cast(out), 8)); + return scalar; + } case LogicalTypeId::VARCHAR: scalar->mutable_value()->set_string_value(value.GetValue()); return scalar; case LogicalTypeId::DATE: - scalar->mutable_value()->set_int32_value(value.GetValue()); + scalar->mutable_value()->set_int64_value(value.GetValue()); return scalar; case LogicalTypeId::TIME: - scalar->mutable_value()->set_int32_value(value.GetValue()); + scalar->mutable_value()->set_int64_value(value.GetValue()); return scalar; case LogicalTypeId::TIMESTAMP_SEC: scalar->mutable_value()->set_int64_value(value.GetValue()); @@ -360,7 +377,7 @@ vortex::expr::Expr *expression_into_vortex_expr(Arena &arena, const duckdb::Expr } case duckdb::ExpressionClass::BOUND_CONSTANT: { auto &dconstant = dexpr.Cast(); - set_literal(arena, 
Value(dconstant.value), true, expr); + set_literal(arena, dconstant.value, true, expr); return expr; } case duckdb::ExpressionClass::BOUND_COMPARISON: { diff --git a/src/include/vortex_common.hpp b/src/include/vortex_common.hpp index 0dd254d..d817401 100644 --- a/src/include/vortex_common.hpp +++ b/src/include/vortex_common.hpp @@ -28,9 +28,7 @@ struct VortexFileReader { static duckdb::unique_ptr Open(const vx_file_open_options *options) { vx_error *error; auto file = vx_file_open_reader(options, &error); - if (file == nullptr) { - HandleError(error); - } + HandleError(error); return duckdb::make_uniq(file); } @@ -68,10 +66,10 @@ struct VortexArrayStream { duckdb::unique_ptr NextArray() const { vx_error *error; auto array = vx_array_stream_next(array_stream, &error); + HandleError(error); if (array == nullptr) { return nullptr; } - HandleError(error); return duckdb::make_uniq(array); } diff --git a/src/vortex_write.cpp b/src/vortex_write.cpp index 7ca1d22..e43358e 100644 --- a/src/vortex_write.cpp +++ b/src/vortex_write.cpp @@ -2,15 +2,20 @@ #include "vortex_write.hpp" #include "vortex_common.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/multi_file_reader.hpp" #include "duckdb/main/extension_util.hpp" #include "duckdb/function/copy_function.hpp" +#include "duckdb/parser/constraints/not_null_constraint.hpp" namespace duckdb { struct VortexWriteBindData : public TableFunctionData { + //! 
True if the column is nullable + vector column_nullable; + vector sql_types; vector column_names; }; @@ -35,16 +40,50 @@ void VortexWriteSink(ExecutionContext &context, FunctionData &bind_data, GlobalF input.data[i].Flatten(input.size()); } - auto new_array = - vx_array_append_duckdb_chunk(global_state.array->array, reinterpret_cast(&input)); + auto new_array = vx_array_append_duckdb_chunk( + global_state.array->array, reinterpret_cast(&input), bind.column_nullable.data()); global_state.array = make_uniq(new_array); } +std::vector TableNullability(ClientContext &context, const string &catalog_name, const string &schema, + const string &table) { + auto &catalog = Catalog::GetCatalog(context, catalog_name); + + QueryErrorContext error_context; + // Main is the default schema + auto schema_name = schema != "" ? schema : "main"; + + auto entry = catalog.GetEntry(context, CatalogType::TABLE_ENTRY, schema_name, table, OnEntryNotFound::RETURN_NULL, + error_context); + auto vec = std::vector(); + if (!entry) { + // If there is no entry, it is okay to return all nullable columns. 
+ return vec; + } + + auto &table_entry = entry->Cast(); + for (auto &constraint : table_entry.GetConstraints()) { + if (constraint->type == ConstraintType::NOT_NULL) { + auto &null_constraint = constraint->Cast(); + vec.push_back(null_constraint.index.index); + } + } + return vec; +} + void RegisterVortexWriteFunction(DatabaseInstance &instance) { CopyFunction function("vortex"); function.copy_to_bind = [](ClientContext &context, CopyFunctionBindInput &input, const vector &names, const vector &sql_types) -> unique_ptr { auto result = make_uniq(); + + auto not_null = TableNullability(context, input.info.catalog, input.info.schema, input.info.table); + + result->column_nullable = std::vector(names.size(), true); + for (auto not_null_idx : not_null) { + result->column_nullable[not_null_idx] = false; + } + result->sql_types = sql_types; result->column_names = names; return std::move(result); @@ -64,9 +103,10 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) { for (auto &col_type : bind.sql_types) { column_types.push_back(reinterpret_cast(&col_type)); } + vx_error *error = nullptr; - auto array = vx_array_create_empty_from_duckdb_table(column_types.data(), column_names.data(), - column_names.size(), &error); + auto array = vx_array_create_empty_from_duckdb_table(column_types.data(), bind.column_nullable.data(), + column_names.data(), column_names.size(), &error); HandleError(error); gstate->array = make_uniq(array); @@ -79,16 +119,9 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) { function.copy_to_sink = VortexWriteSink; function.copy_to_finalize = [](ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) { auto &global_state = gstate.Cast(); - auto opts = vx_file_create_options(); - opts.path = global_state.file_name.c_str(); vx_error *error; - auto file = vx_file_create(&opts, &error); - if (file == nullptr) { - HandleError(error); - } - vx_file_write_array(file, global_state.array->array, &error); + 
vx_file_write_array(global_state.file_name.c_str(), global_state.array->array, &error); HandleError(error); - vx_file_writer_free(file); }; function.execution_mode = [](bool preserve_insertion_order, bool supports_batch_index) -> CopyFunctionExecutionMode { @@ -99,4 +132,4 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) { ExtensionUtil::RegisterFunction(instance, function); } -} // namespace duckdb \ No newline at end of file +} // namespace duckdb diff --git a/vortex b/vortex index 3165cfd..fc3196c 160000 --- a/vortex +++ b/vortex @@ -1 +1 @@ -Subproject commit 3165cfd612236d367dfd086c18937cbd6c35273a +Subproject commit fc3196ccd45a8a3451a45b1ec44d923e40c497e7