Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion duckdb
37 changes: 27 additions & 10 deletions src/expr/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ vortex::dtype::DType *into_vortex_dtype(Arena &arena, const LogicalType &type_,
dtype->mutable_primitive()->set_nullable(nullable);
dtype->mutable_primitive()->set_type(vortex::dtype::F64);
return dtype;
case LogicalTypeId::DECIMAL: {
dtype->mutable_decimal()->set_nullable(nullable);
auto decimal = dtype->mutable_decimal();
decimal->set_precision(duckdb::DecimalType::GetWidth(type_));
decimal->set_scale(duckdb::DecimalType::GetScale(type_));
return dtype;
}
case LogicalTypeId::CHAR:
case LogicalTypeId::VARCHAR:
dtype->mutable_utf8()->set_nullable(nullable);
Expand Down Expand Up @@ -195,7 +202,7 @@ vortex::scalar::Scalar *into_null_scalar(Arena &arena, LogicalType &logical_type

vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, bool nullable) {
auto scalar = Arena::Create<vortex::scalar::Scalar>(&arena);
auto dtype = into_vortex_dtype(arena, value.type().id(), nullable);
auto dtype = into_vortex_dtype(arena, value.type(), nullable);
scalar->set_allocated_dtype(dtype);

switch (value.type().id()) {
Expand All @@ -209,25 +216,25 @@ vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, boo
return scalar;
}
case LogicalTypeId::TINYINT:
scalar->mutable_value()->set_int8_value(value.GetValue<int8_t>());
scalar->mutable_value()->set_int64_value(value.GetValue<int8_t>());
return scalar;
case LogicalTypeId::SMALLINT:
scalar->mutable_value()->set_int16_value(value.GetValue<int16_t>());
scalar->mutable_value()->set_int64_value(value.GetValue<int16_t>());
return scalar;
case LogicalTypeId::INTEGER:
scalar->mutable_value()->set_int32_value(value.GetValue<int32_t>());
scalar->mutable_value()->set_int64_value(value.GetValue<int32_t>());
return scalar;
case LogicalTypeId::BIGINT:
scalar->mutable_value()->set_int64_value(value.GetValue<int64_t>());
return scalar;
case LogicalTypeId::UTINYINT:
scalar->mutable_value()->set_uint8_value(value.GetValue<uint8_t>());
scalar->mutable_value()->set_uint64_value(value.GetValue<uint8_t>());
return scalar;
case LogicalTypeId::USMALLINT:
scalar->mutable_value()->set_uint16_value(value.GetValue<uint16_t>());
scalar->mutable_value()->set_uint64_value(value.GetValue<uint16_t>());
return scalar;
case LogicalTypeId::UINTEGER:
scalar->mutable_value()->set_uint32_value(value.GetValue<uint32_t>());
scalar->mutable_value()->set_uint64_value(value.GetValue<uint32_t>());
return scalar;
case LogicalTypeId::UBIGINT:
scalar->mutable_value()->set_uint64_value(value.GetValue<uint64_t>());
Expand All @@ -238,14 +245,24 @@ vortex::scalar::Scalar *into_vortex_scalar(Arena &arena, const Value &value, boo
case LogicalTypeId::DOUBLE:
scalar->mutable_value()->set_f64_value(value.GetValue<double_t>());
return scalar;
case LogicalTypeId::DECIMAL: {
auto huge = value.GetValue<duckdb::hugeint_t>();
uint32_t out[4];
out[0] = static_cast<uint32_t>(huge);
out[1] = static_cast<uint32_t>(huge >> 32);
out[2] = static_cast<uint32_t>(huge >> 64);
out[3] = static_cast<uint32_t>(huge >> 96);
scalar->mutable_value()->set_bytes_value(std::string(reinterpret_cast<char *>(out), 8));
return scalar;
}
case LogicalTypeId::VARCHAR:
scalar->mutable_value()->set_string_value(value.GetValue<string>());
return scalar;
case LogicalTypeId::DATE:
scalar->mutable_value()->set_int32_value(value.GetValue<int32_t>());
scalar->mutable_value()->set_int64_value(value.GetValue<int32_t>());
return scalar;
case LogicalTypeId::TIME:
scalar->mutable_value()->set_int32_value(value.GetValue<int32_t>());
scalar->mutable_value()->set_int64_value(value.GetValue<int32_t>());
return scalar;
case LogicalTypeId::TIMESTAMP_SEC:
scalar->mutable_value()->set_int64_value(value.GetValue<int64_t>());
Expand Down Expand Up @@ -360,7 +377,7 @@ vortex::expr::Expr *expression_into_vortex_expr(Arena &arena, const duckdb::Expr
}
case duckdb::ExpressionClass::BOUND_CONSTANT: {
auto &dconstant = dexpr.Cast<duckdb::BoundConstantExpression>();
set_literal(arena, Value(dconstant.value), true, expr);
set_literal(arena, dconstant.value, true, expr);
return expr;
}
case duckdb::ExpressionClass::BOUND_COMPARISON: {
Expand Down
6 changes: 2 additions & 4 deletions src/include/vortex_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ struct VortexFileReader {
static duckdb::unique_ptr<VortexFileReader> Open(const vx_file_open_options *options) {
vx_error *error;
auto file = vx_file_open_reader(options, &error);
if (file == nullptr) {
HandleError(error);
}
HandleError(error);
return duckdb::make_uniq<VortexFileReader>(file);
}

Expand Down Expand Up @@ -68,10 +66,10 @@ struct VortexArrayStream {
duckdb::unique_ptr<VortexArray> NextArray() const {
vx_error *error;
auto array = vx_array_stream_next(array_stream, &error);
HandleError(error);
if (array == nullptr) {
return nullptr;
}
HandleError(error);
return duckdb::make_uniq<VortexArray>(array);
}

Expand Down
59 changes: 46 additions & 13 deletions src/vortex_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

#include "vortex_write.hpp"
#include "vortex_common.hpp"
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"

#include "duckdb/common/exception.hpp"
#include "duckdb/common/multi_file_reader.hpp"
#include "duckdb/main/extension_util.hpp"
#include "duckdb/function/copy_function.hpp"
#include "duckdb/parser/constraints/not_null_constraint.hpp"

namespace duckdb {

struct VortexWriteBindData : public TableFunctionData {
//! True is the column is nullable
vector<unsigned char> column_nullable;

vector<LogicalType> sql_types;
vector<string> column_names;
};
Expand All @@ -35,16 +40,50 @@ void VortexWriteSink(ExecutionContext &context, FunctionData &bind_data, GlobalF
input.data[i].Flatten(input.size());
}

auto new_array =
vx_array_append_duckdb_chunk(global_state.array->array, reinterpret_cast<duckdb_data_chunk>(&input));
auto new_array = vx_array_append_duckdb_chunk(
global_state.array->array, reinterpret_cast<duckdb_data_chunk>(&input), bind.column_nullable.data());
global_state.array = make_uniq<VortexArray>(new_array);
}

std::vector<idx_t> TableNullability(ClientContext &context, const string &catalog_name, const string &schema,
const string &table) {
auto &catalog = Catalog::GetCatalog(context, catalog_name);

QueryErrorContext error_context;
// Main is the default schema
auto schema_name = schema != "" ? schema : "main";

auto entry = catalog.GetEntry(context, CatalogType::TABLE_ENTRY, schema_name, table, OnEntryNotFound::RETURN_NULL,
error_context);
auto vec = std::vector<idx_t>();
if (!entry) {
// If there is no entry, it is okay to return all nullable columns.
return vec;
}

auto &table_entry = entry->Cast<TableCatalogEntry>();
for (auto &constraint : table_entry.GetConstraints()) {
if (constraint->type == ConstraintType::NOT_NULL) {
auto &null_constraint = constraint->Cast<NotNullConstraint>();
vec.push_back(null_constraint.index.index);
}
}
return vec;
}

void RegisterVortexWriteFunction(DatabaseInstance &instance) {
CopyFunction function("vortex");
function.copy_to_bind = [](ClientContext &context, CopyFunctionBindInput &input, const vector<string> &names,
const vector<LogicalType> &sql_types) -> unique_ptr<FunctionData> {
auto result = make_uniq<VortexWriteBindData>();

auto not_null = TableNullability(context, input.info.catalog, input.info.schema, input.info.table);

result->column_nullable = std::vector<unsigned char>(names.size(), true);
for (auto not_null_idx : not_null) {
result->column_nullable[not_null_idx] = false;
}

result->sql_types = sql_types;
result->column_names = names;
return std::move(result);
Expand All @@ -64,9 +103,10 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) {
for (auto &col_type : bind.sql_types) {
column_types.push_back(reinterpret_cast<duckdb_logical_type>(&col_type));
}

vx_error *error = nullptr;
auto array = vx_array_create_empty_from_duckdb_table(column_types.data(), column_names.data(),
column_names.size(), &error);
auto array = vx_array_create_empty_from_duckdb_table(column_types.data(), bind.column_nullable.data(),
column_names.data(), column_names.size(), &error);
HandleError(error);

gstate->array = make_uniq<VortexArray>(array);
Expand All @@ -79,16 +119,9 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) {
function.copy_to_sink = VortexWriteSink;
function.copy_to_finalize = [](ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) {
auto &global_state = gstate.Cast<VortexWriteGlobalData>();
auto opts = vx_file_create_options();
opts.path = global_state.file_name.c_str();
vx_error *error;
auto file = vx_file_create(&opts, &error);
if (file == nullptr) {
HandleError(error);
}
vx_file_write_array(file, global_state.array->array, &error);
vx_file_write_array(global_state.file_name.c_str(), global_state.array->array, &error);
HandleError(error);
vx_file_writer_free(file);
};
function.execution_mode = [](bool preserve_insertion_order,
bool supports_batch_index) -> CopyFunctionExecutionMode {
Expand All @@ -99,4 +132,4 @@ void RegisterVortexWriteFunction(DatabaseInstance &instance) {
ExtensionUtil::RegisterFunction(instance, function);
}

} // namespace duckdb
} // namespace duckdb
2 changes: 1 addition & 1 deletion vortex
Loading