Skip to content

Commit 5d59785

Browse files
committed
YT-25045: Add support for Arrow complex types
commit_hash:5758a66421fefa3d48e88762c0e6a4bf87b9d5aa
1 parent ea8dd0f commit 5d59785

File tree

6 files changed

+1286
-91
lines changed

6 files changed

+1286
-91
lines changed

yt/yt/client/formats/config.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ void TArrowFormatConfig::Register(TRegistrar registrar)
364364
{
365365
registrar.Parameter("enable_tz_index", &TThis::EnableTzIndex)
366366
.Default(false);
367+
registrar.Parameter("enable_complex_types", &TThis::EnableComplexTypes)
368+
.Default(false);
367369
}
368370

369371
////////////////////////////////////////////////////////////////////////////////

yt/yt/client/formats/config.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ struct TArrowFormatConfig
425425
//! Return the timezone as index.
426426
bool EnableTzIndex;
427427

428+
//! Write YSON-encoded complex types as Arrow types.
429+
bool EnableComplexTypes;
430+
428431
REGISTER_YSON_STRUCT(TArrowFormatConfig);
429432

430433
static void Register(TRegistrar registrar);

yt/yt/client/table_client/columnar.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -100,32 +100,37 @@ void CopyBitmapRangeToBitmapImpl(
100100
} else {
101101
::memcpy(beginByteOutput, beginByteInput, byteCount);
102102
}
103-
return;
104-
}
105-
106-
auto buildOutputQWord = [&] (ui64 qword1, ui64 qword2) {
107-
qword1 >>= qwordShift;
108-
qword2 &= (1ULL << qwordShift) - 1;
109-
qword2 <<= qwordCoshift;
110-
return MaybeNegateValue<Negate>(qword1 | qword2);
111-
};
103+
} else {
104+
auto buildOutputQWord = [&] (ui64 qword1, ui64 qword2) {
105+
qword1 >>= qwordShift;
106+
qword2 &= (1ULL << qwordShift) - 1;
107+
qword2 <<= qwordCoshift;
108+
return MaybeNegateValue<Negate>(qword1 | qword2);
109+
};
110+
111+
// Head
112+
while (currentQwordInput < endQwordInput - 1) {
113+
auto qword1 = currentQwordInput[0];
114+
auto qword2 = currentQwordInput[1];
115+
*currentQwordOutput = buildOutputQWord(qword1, qword2);
116+
++currentQwordInput;
117+
++currentQwordOutput;
118+
}
112119

113-
// Head
114-
while (currentQwordInput < endQwordInput - 1) {
115-
auto qword1 = currentQwordInput[0];
116-
auto qword2 = currentQwordInput[1];
117-
*currentQwordOutput = buildOutputQWord(qword1, qword2);
118-
++currentQwordInput;
119-
++currentQwordOutput;
120+
// Tail
121+
while (currentQwordInput <= endQwordInput) {
122+
auto qword1 = SafeReadQword(currentQwordInput, bitmap.End());
123+
auto qword2 = SafeReadQword(currentQwordInput + 1, bitmap.End());
124+
SafeWriteQword(currentQwordOutput, dst.End(), buildOutputQWord(qword1, qword2));
125+
++currentQwordInput;
126+
++currentQwordOutput;
127+
}
120128
}
121129

122-
// Tail
123-
while (currentQwordInput <= endQwordInput) {
124-
auto qword1 = SafeReadQword(currentQwordInput, bitmap.End());
125-
auto qword2 = SafeReadQword(currentQwordInput + 1, bitmap.End());
126-
SafeWriteQword(currentQwordOutput, dst.End(), buildOutputQWord(qword1, qword2));
127-
++currentQwordInput;
128-
++currentQwordOutput;
130+
// Unset all unused bits in the last byte.
131+
if (byteCount > 0 && (bitCount & 7) != 0) {
132+
auto* lastOutputByte = reinterpret_cast<ui8*>(dst.Begin()) + byteCount - 1;
133+
*lastOutputByte &= MaskLowerBits(bitCount & 7);
129134
}
130135
}
131136

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <yt/yt/core/misc/common.h>
4+
5+
namespace NYT::NFormats {

////////////////////////////////////////////////////////////////////////////////

// These constants are defined in a header, so they are declared |inline|:
// the whole program shares a single object per constant instead of every
// including translation unit constructing its own private copy.

//! Key of the Arrow field metadata entry that carries a YT-specific type hint.
inline const std::string YtTypeMetadataKey = "YtType";

//! Hint value marking an Arrow struct that stands for a YT struct without fields.
inline const std::string YtTypeMetadataValueEmptyStruct = "emptyStruct";

//! Hint value marking an Arrow struct that wraps the element of a YT optional.
inline const std::string YtTypeMetadataValueNestedOptional = "nestedOptional";

////////////////////////////////////////////////////////////////////////////////

} // namespace NYT::NFormats

yt/yt/library/formats/arrow_parser.cpp

Lines changed: 122 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "arrow_parser.h"
2+
#include "arrow_metadata_constants.h"
23

34
#include <yt/yt/client/formats/parser.h>
45

@@ -635,6 +636,31 @@ i64 CheckAndTransformTimestamp(i64 arrowValue, arrow20::TimeUnit::type timeUnit,
635636

636637
////////////////////////////////////////////////////////////////////////////////
637638

639+
std::optional<std::string> GetYtTypeFromMetadata(const std::shared_ptr<arrow20::Field>& schemaField)
640+
{
641+
auto columnMetadata = schemaField->metadata();
642+
if (!columnMetadata) {
643+
return std::nullopt;
644+
}
645+
auto valueResult = columnMetadata->Get(YtTypeMetadataKey);
646+
if (valueResult.ok()) {
647+
return *valueResult;
648+
}
649+
return std::nullopt;
650+
}
651+
652+
//! Returns |true| iff the metadata of #schemaField stores
//! #YtTypeMetadataValueEmptyStruct under #YtTypeMetadataKey.
bool HasEmptyStructTypeInMetadata(const std::shared_ptr<arrow20::Field>& schemaField)
{
    const auto ytType = GetYtTypeFromMetadata(schemaField);
    return ytType.has_value() && *ytType == YtTypeMetadataValueEmptyStruct;
}
656+
657+
//! Returns |true| iff the metadata of #schemaField stores
//! #YtTypeMetadataValueNestedOptional under #YtTypeMetadataKey.
bool HasNestedOptionalTypeInMetadata(const std::shared_ptr<arrow20::Field>& schemaField)
{
    const auto ytType = GetYtTypeFromMetadata(schemaField);
    return ytType.has_value() && *ytType == YtTypeMetadataValueNestedOptional;
}
661+
662+
////////////////////////////////////////////////////////////////////////////////
663+
638664
class TArraySimpleVisitor
639665
: public arrow20::TypeVisitor
640666
{
@@ -1076,11 +1102,13 @@ class TArrayCompositeVisitor
10761102
TArrayCompositeVisitor(
10771103
TLogicalTypePtr ytType,
10781104
const std::shared_ptr<arrow20::Array>& array,
1105+
const std::shared_ptr<arrow20::Field>& schemaField,
10791106
NYson::TCheckedInDebugYsonTokenWriter* writer,
10801107
int rowIndex)
10811108
: YTType_(DenullifyLogicalType(ytType))
10821109
, RowIndex_(rowIndex)
10831110
, Array_(array)
1111+
, SchemaField_(schemaField)
10841112
, Writer_(writer)
10851113
{
10861114
YT_VERIFY(writer != nullptr);
@@ -1098,6 +1126,7 @@ class TArrayCompositeVisitor
10981126
TArrayCompositeVisitor visitor(
10991127
YTType_,
11001128
dictionary,
1129+
SchemaField_,
11011130
Writer_,
11021131
dictionaryArrayColumn->GetValueIndex(RowIndex_));
11031132
ThrowOnError(dictionary->type()->Accept(&visitor));
@@ -1293,6 +1322,7 @@ class TArrayCompositeVisitor
12931322
const int RowIndex_;
12941323

12951324
std::shared_ptr<arrow20::Array> Array_;
1325+
std::shared_ptr<arrow20::Field> SchemaField_;
12961326
NYson::TCheckedInDebugYsonTokenWriter* Writer_ = nullptr;
12971327

12981328
template <typename ArrayType>
@@ -1513,7 +1543,7 @@ class TArrayCompositeVisitor
15131543

15141544
auto listValue = array->value_slice(RowIndex_);
15151545
for (int offset = 0; offset < listValue->length(); ++offset) {
1516-
TArrayCompositeVisitor visitor(YTType_->AsListTypeRef().GetElement(), listValue, Writer_, offset);
1546+
TArrayCompositeVisitor visitor(YTType_->AsListTypeRef().GetElement(), listValue, array->type()->field(0), Writer_, offset);
15171547
try {
15181548
ThrowOnError(listValue->type()->Accept(&visitor));
15191549
} catch (const std::exception& ex) {
@@ -1548,12 +1578,15 @@ class TArrayCompositeVisitor
15481578
auto keyList = allKeys->Slice(offset, length);
15491579
auto valueList = allValues->Slice(offset, length);
15501580

1581+
// Map is represented as list of pairs.
1582+
auto pairType = array->type()->field(0)->type();
1583+
15511584
Writer_->WriteBeginList();
15521585

15531586
for (int offset = 0; offset < keyList->length(); ++offset) {
15541587
Writer_->WriteBeginList();
15551588

1556-
TArrayCompositeVisitor keyVisitor(YTType_->AsDictTypeRef().GetKey(), keyList, Writer_, offset);
1589+
TArrayCompositeVisitor keyVisitor(YTType_->AsDictTypeRef().GetKey(), keyList, pairType->field(0), Writer_, offset);
15571590
try {
15581591
ThrowOnError(keyList->type()->Accept(&keyVisitor));
15591592
} catch (const std::exception& ex) {
@@ -1564,7 +1597,7 @@ class TArrayCompositeVisitor
15641597

15651598
Writer_->WriteItemSeparator();
15661599

1567-
TArrayCompositeVisitor valueVisitor(YTType_->AsDictTypeRef().GetValue(), valueList, Writer_, offset);
1600+
TArrayCompositeVisitor valueVisitor(YTType_->AsDictTypeRef().GetValue(), valueList, pairType->field(1), Writer_, offset);
15681601
try {
15691602
ThrowOnError(valueList->type()->Accept(&valueVisitor));
15701603
} catch (const std::exception& ex) {
@@ -1584,29 +1617,41 @@ class TArrayCompositeVisitor
15841617
return arrow20::Status::OK();
15851618
}
15861619

1587-
arrow20::Status ParseStruct()
1620+
void ParseStructForStruct()
15881621
{
1589-
if (YTType_->GetMetatype() != ELogicalMetatype::Struct) {
1590-
THROW_ERROR_EXCEPTION("Unexpected arrow type \"struct\" for YT metatype %Qlv",
1591-
YTType_->GetMetatype());
1592-
}
15931622
auto array = std::static_pointer_cast<arrow20::StructArray>(Array_);
15941623
if (array->IsNull(RowIndex_)) {
15951624
Writer_->WriteEntity();
15961625
} else {
15971626
Writer_->WriteBeginList();
1598-
auto structFields = YTType_->AsStructTypeRef().GetFields();
1599-
if (std::ssize(structFields) != array->num_fields()) {
1600-
THROW_ERROR_EXCEPTION("The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type")
1601-
<< TErrorAttribute("arrow_field_count", array->num_fields())
1602-
<< TErrorAttribute("yt_field_count", std::ssize(structFields));
1627+
1628+
const auto& structFields = YTType_->AsStructTypeRef().GetFields();
1629+
1630+
if (structFields.empty()) {
    // A YT struct without fields must be explicitly tagged in the Arrow field metadata.
    if (!HasEmptyStructTypeInMetadata(SchemaField_)) {
        THROW_ERROR_EXCEPTION(
            "YT \"struct\" type has no fields, but no metadata found with the key \'%v\' and the value \'%v\'",
            YtTypeMetadataKey,
            YtTypeMetadataValueEmptyStruct);
    }
    // The Arrow struct must consist of exactly one dummy field of the null type.
    // The field count is tested first, so |field(0)| is never evaluated for a
    // struct without fields; any other shape is rejected.
    if (array->num_fields() != 1 || !array->field(0)->type()->Equals(arrow20::null())) {
        THROW_ERROR_EXCEPTION("YT \"struct\" type has no fields, but Arrow \"struct\" type does not have a single dummy null field");
    }
} else {
    // For a non-empty YT struct the Arrow struct must have the same number of fields.
    if (std::ssize(structFields) != array->num_fields()) {
        THROW_ERROR_EXCEPTION("The number of fields in the Arrow \"struct\" type does not match the number of fields in the YT \"struct\" type")
            << TErrorAttribute("arrow_field_count", array->num_fields())
            << TErrorAttribute("yt_field_count", std::ssize(structFields));
    }
}
1647+
1648+
const auto& structType = std::static_pointer_cast<arrow20::StructType>(array->type());
16041649
for (const auto& field : structFields) {
16051650
auto arrowField = array->GetFieldByName(field.Name);
16061651
if (!arrowField) {
16071652
THROW_ERROR_EXCEPTION("Field %Qv is not found in arrow type \"struct\"", field.Name);
16081653
}
1609-
TArrayCompositeVisitor visitor(field.Type, arrowField, Writer_, RowIndex_);
1654+
TArrayCompositeVisitor visitor(field.Type, arrowField, structType->GetFieldByName(field.Name), Writer_, RowIndex_);
16101655
try {
16111656
ThrowOnError(arrowField->type()->Accept(&visitor));
16121657
} catch (const std::exception& ex) {
@@ -1619,6 +1664,53 @@ class TArrayCompositeVisitor
16191664

16201665
Writer_->WriteEndList();
16211666
}
1667+
}
1668+
1669+
void ParseStructForOptional()
1670+
{
1671+
auto array = std::static_pointer_cast<arrow20::StructArray>(Array_);
1672+
if (array->IsNull(RowIndex_)) {
1673+
Writer_->WriteEntity();
1674+
} else {
1675+
Writer_->WriteBeginList();
1676+
if (!HasNestedOptionalTypeInMetadata(SchemaField_)) {
1677+
THROW_ERROR_EXCEPTION(
1678+
"The element of YT \"optional\" type is nullable, but no metadata found with the key \'%v\' and the value \'%v\'",
1679+
YtTypeMetadataKey,
1680+
YtTypeMetadataValueNestedOptional);
1681+
}
1682+
if (array->num_fields() != 1) {
1683+
THROW_ERROR_EXCEPTION("The number of fields in the Arrow \"struct\" type is not equal to 1 for the YT \"optional\" type")
1684+
<< TErrorAttribute("arrow_field_count", array->num_fields());
1685+
}
1686+
1687+
auto arrowField = array->field(0);
1688+
TArrayCompositeVisitor visitor(YTType_->GetElement(), arrowField, array->type()->field(0), Writer_, RowIndex_);
1689+
try {
1690+
ThrowOnError(arrowField->type()->Accept(&visitor));
1691+
} catch (const std::exception& ex) {
1692+
THROW_ERROR_EXCEPTION("Failed to parse arrow struct field for the YT \"optional\" type")
1693+
<< ex;
1694+
}
1695+
1696+
Writer_->WriteItemSeparator();
1697+
Writer_->WriteEndList();
1698+
}
1699+
}
1700+
1701+
//! Dispatches parsing of an Arrow struct array by the metatype of the YT type.
/*!
 *  "struct" and "optional" metatypes are delegated to the dedicated helpers;
 *  any other metatype results in an error being thrown.
 *  Returns an OK status when the helper completes without throwing.
 */
arrow20::Status ParseStruct()
{
    const auto metatype = YTType_->GetMetatype();
    if (metatype == ELogicalMetatype::Struct) {
        ParseStructForStruct();
    } else if (metatype == ELogicalMetatype::Optional) {
        ParseStructForOptional();
    } else {
        THROW_ERROR_EXCEPTION("Unexpected arrow type \"struct\" for YT metatype %Qlv",
            metatype);
    }
    return arrow20::Status::OK();
}
16241716

@@ -1650,6 +1742,7 @@ void PrepareArrayForComplexType(
16501742
const TLogicalTypePtr& denullifiedLogicalType,
16511743
const std::shared_ptr<TChunkedOutputStream>& bufferForStringLikeValues,
16521744
const std::shared_ptr<arrow20::Array>& column,
1745+
const std::shared_ptr<arrow20::Field>& schemaField,
16531746
TUnversionedRowValues& rowValues,
16541747
int columnId)
16551748
{
@@ -1699,6 +1792,16 @@ void PrepareArrayForComplexType(
16991792
break;
17001793

17011794
case ELogicalMetatype::Optional:
1795+
CheckArrowType(
1796+
metatype,
1797+
{
1798+
arrow20::Type::STRUCT,
1799+
arrow20::Type::BINARY
1800+
},
1801+
column->type()->name(),
1802+
column->type_id());
1803+
break;
1804+
17021805
case ELogicalMetatype::Tuple:
17031806
case ELogicalMetatype::VariantTuple:
17041807
case ELogicalMetatype::VariantStruct:
@@ -1735,7 +1838,7 @@ void PrepareArrayForComplexType(
17351838
TBufferOutput out(valueBuffer);
17361839
NYson::TCheckedInDebugYsonTokenWriter writer(&out);
17371840

1738-
TArrayCompositeVisitor visitor(denullifiedLogicalType, column, &writer, rowIndex);
1841+
TArrayCompositeVisitor visitor(denullifiedLogicalType, column, schemaField, &writer, rowIndex);
17391842

17401843
ThrowOnError(column->type()->Accept(&visitor));
17411844

@@ -1760,14 +1863,15 @@ void PrepareArray(
17601863
const TLogicalTypePtr& denullifiedLogicalType,
17611864
const std::shared_ptr<TChunkedOutputStream>& bufferForStringLikeValues,
17621865
const std::shared_ptr<arrow20::Array>& column,
1866+
const std::shared_ptr<arrow20::Field>& schemaField,
17631867
TUnversionedRowValues& rowValues,
17641868
int columnId)
17651869
{
17661870
if (column->type()->id() == arrow20::Type::DICTIONARY) {
17671871
auto dictionaryArrayColumn = std::static_pointer_cast<arrow20::DictionaryArray>(column);
17681872
auto dictionary = dictionaryArrayColumn->dictionary();
17691873
TUnversionedRowValues dictionaryValues(dictionary->length());
1770-
PrepareArray(denullifiedLogicalType, bufferForStringLikeValues, dictionary, dictionaryValues, columnId);
1874+
PrepareArray(denullifiedLogicalType, bufferForStringLikeValues, dictionary, schemaField, dictionaryValues, columnId);
17711875

17721876
for (int offset = 0; offset < std::ssize(rowValues); ++offset) {
17731877
if (dictionaryArrayColumn->IsNull(offset)) {
@@ -1802,6 +1906,7 @@ void PrepareArray(
18021906
denullifiedLogicalType,
18031907
bufferForStringLikeValues,
18041908
column,
1909+
schemaField,
18051910
rowValues,
18061911
columnId);
18071912

@@ -1869,6 +1974,7 @@ class TListener
18691974
denullifiedColumnType,
18701975
bufferForStringLikeValues,
18711976
batch->column(columnIndex),
1977+
batch->schema()->field(columnIndex),
18721978
rowsValues[columnIndex],
18731979
columnId);
18741980
} catch (const std::exception& ex) {

0 commit comments

Comments
 (0)