Skip to content

Commit b27b2ad

Browse files
authored
Fix csv writer (#12)
This pull request fixes a bug where the CSV format in partition mode did not honor the `write_to_output` option and would write the column header anyway. Also, the logic for skipping some columns has been moved to the model writer.
1 parent c43cbf5 commit b27b2ad

File tree

5 files changed

+107
-38
lines changed

5 files changed

+107
-38
lines changed

internal/generator/output/general/model_writer.go

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type ModelWriter struct {
3333
basePath string
3434
continueGeneration bool
3535

36-
numberColumnsToDiscard int
36+
columnsToDiscard map[string]struct{}
3737
partitionColumnsIndexes []int
3838
orderedColumnNames []string
3939

@@ -61,12 +61,12 @@ func newModelWriter(
6161
orderedColumnNames = append(orderedColumnNames, column.Name)
6262
}
6363

64-
numberColumnsToDiscard := 0
64+
columnsToDiscard := make(map[string]struct{})
6565
partitionOrderedColumnNames := make([]string, 0, len(model.PartitionColumns))
6666

6767
for _, column := range model.PartitionColumns {
6868
if !column.WriteToOutput {
69-
numberColumnsToDiscard++
69+
columnsToDiscard[column.Name] = struct{}{}
7070
}
7171

7272
partitionOrderedColumnNames = append(partitionOrderedColumnNames, column.Name)
@@ -97,7 +97,7 @@ func newModelWriter(
9797
config: config,
9898
basePath: basePath,
9999
continueGeneration: continueGeneration,
100-
numberColumnsToDiscard: numberColumnsToDiscard,
100+
columnsToDiscard: columnsToDiscard,
101101
partitionColumnsIndexes: partitionColumnsIndexes,
102102
orderedColumnNames: orderedColumnNames,
103103
checkpointTicker: ticker,
@@ -192,7 +192,7 @@ func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) err
192192

193193
// discard not writeable columns
194194
sendRow := &models.DataRow{
195-
Values: row.Values[:len(row.Values)-w.numberColumnsToDiscard],
195+
Values: row.Values[:len(row.Values)-len(w.columnsToDiscard)],
196196
}
197197

198198
if err := dataWriter.WriteRow(sendRow); err != nil {
@@ -237,19 +237,30 @@ func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Wri
237237
var dataWriter writer.Writer
238238

239239
switch w.config.Type {
240+
case "devnull":
241+
dataWriter = devnull.NewWriter(
242+
w.model,
243+
w.config.DevNullParams,
244+
)
240245
case "csv":
241246
dataWriter = csv.NewWriter(
242247
ctx,
243248
w.model,
244249
w.config.CSVParams,
250+
w.columnsToDiscard,
245251
outPath,
246252
w.continueGeneration,
247253
w.writtenRowsChan,
248254
)
249-
case "devnull":
250-
dataWriter = devnull.NewWriter(
255+
case "parquet":
256+
dataWriter = parquet.NewWriter(
251257
w.model,
252-
w.config.DevNullParams,
258+
w.config.ParquetParams,
259+
w.columnsToDiscard,
260+
parquet.NewFileSystem(),
261+
outPath,
262+
w.continueGeneration,
263+
w.writtenRowsChan,
253264
)
254265
case "http":
255266
dataWriter = http.NewWriter(
@@ -265,15 +276,6 @@ func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Wri
265276
w.config.TCSParams,
266277
w.writtenRowsChan,
267278
)
268-
case "parquet":
269-
dataWriter = parquet.NewWriter(
270-
w.model,
271-
w.config.ParquetParams,
272-
parquet.NewFileSystem(),
273-
outPath,
274-
w.continueGeneration,
275-
w.writtenRowsChan,
276-
)
277279
default:
278280
return nil, errors.Errorf("unknown output type: %q", w.config.Type)
279281
}

internal/generator/output/general/writer/csv/csv.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type Writer struct {
3333
ctx context.Context //nolint:containedctx
3434

3535
model *models.Model
36+
columnsToDiscard map[string]struct{}
3637
config *models.CSVConfig
3738
outputPath string
3839
continueGeneration bool
@@ -58,13 +59,15 @@ func NewWriter(
5859
ctx context.Context,
5960
model *models.Model,
6061
config *models.CSVConfig,
62+
columnsToDiscard map[string]struct{},
6163
outputPath string,
6264
continueGeneration bool,
6365
writtenRowsChan chan<- uint64,
6466
) *Writer {
6567
return &Writer{
6668
ctx: ctx,
6769
model: model,
70+
columnsToDiscard: columnsToDiscard,
6871
config: config,
6972
outputPath: outputPath,
7073
continueGeneration: continueGeneration,
@@ -363,12 +366,7 @@ func (w *Writer) replaceFile(fileName string) error {
363366
w.fileDescriptor = file
364367

365368
if !w.config.WithoutHeaders && (!w.continueGeneration || !fileExists) {
366-
header := make([]string, len(w.model.Columns))
367-
for i, column := range w.model.Columns {
368-
header[i] = column.Name
369-
}
370-
371-
err = w.csvWriter.Write(header)
369+
err = w.csvWriter.Write(w.getHeaders())
372370
if err != nil {
373371
return errors.New(err.Error())
374372
}
@@ -377,6 +375,20 @@ func (w *Writer) replaceFile(fileName string) error {
377375
return nil
378376
}
379377

378+
func (w *Writer) getHeaders() []string {
379+
headers := make([]string, 0, len(w.model.Columns)-len(w.columnsToDiscard))
380+
381+
for _, column := range w.model.Columns {
382+
if _, exists := w.columnsToDiscard[column.Name]; exists {
383+
continue
384+
}
385+
386+
headers = append(headers, column.Name)
387+
}
388+
389+
return headers
390+
}
391+
380392
// WriteRow function sends row to internal queue.
381393
func (w *Writer) WriteRow(row *models.DataRow) error {
382394
select {

internal/generator/output/general/writer/csv/csv_test.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,15 @@ func TestWriteRow(t *testing.T) {
130130

131131
csvConfig.WithoutHeaders = tc.withoutHeaders
132132

133-
csvWriter := NewWriter(context.Background(), tc.model, csvConfig, "./", false, nil)
133+
csvWriter := NewWriter(
134+
context.Background(),
135+
tc.model,
136+
csvConfig,
137+
getColumnsToDiscard(tc.model.PartitionColumns),
138+
"./",
139+
false,
140+
nil,
141+
)
134142

135143
err := csvWriter.Init()
136144
require.NoError(t, err)
@@ -307,7 +315,15 @@ func TestWriteToCorrectFiles(t *testing.T) {
307315
}
308316

309317
write := func(from, to int, continueGeneration bool) {
310-
writer := NewWriter(context.Background(), model, config, dir, continueGeneration, nil)
318+
writer := NewWriter(
319+
context.Background(),
320+
model,
321+
config,
322+
getColumnsToDiscard(model.PartitionColumns),
323+
dir,
324+
continueGeneration,
325+
nil,
326+
)
311327
require.NoError(t, writer.Init())
312328

313329
for i := from; i < to; i++ {
@@ -370,3 +386,15 @@ func getFileNumber(rows, rowsPerFile int) int {
370386

371387
return fileNumber
372388
}
389+
390+
func getColumnsToDiscard(partitionColumns []*models.PartitionColumn) map[string]struct{} {
391+
columnsToDiscard := make(map[string]struct{})
392+
393+
for _, column := range partitionColumns {
394+
if !column.WriteToOutput {
395+
columnsToDiscard[column.Name] = struct{}{}
396+
}
397+
}
398+
399+
return columnsToDiscard
400+
}

internal/generator/output/general/writer/parquet/parquet.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ var _ writer.Writer = (*Writer)(nil)
5858
// Writer type is implementation of Writer to parquet file.
5959
type Writer struct {
6060
model *models.Model
61+
columnsToDiscard map[string]struct{}
6162
config *models.ParquetConfig
6263
outputPath string
6364
continueGeneration bool
@@ -91,13 +92,15 @@ type FileSystem interface {
9192
func NewWriter(
9293
model *models.Model,
9394
config *models.ParquetConfig,
95+
columnsToDiscard map[string]struct{},
9496
fs FileSystem,
9597
outputPath string,
9698
continueGeneration bool,
9799
writtenRowsChan chan<- uint64,
98100
) *Writer {
99101
return &Writer{
100102
model: model,
103+
columnsToDiscard: columnsToDiscard,
101104
config: config,
102105
outputPath: outputPath,
103106
continueGeneration: continueGeneration,
@@ -122,14 +125,9 @@ func (w *Writer) generateModelSchema() (*arrow.Schema, []parquet.WriterProperty,
122125

123126
arrowFields := make([]arrow.Field, 0, len(w.model.Columns))
124127

125-
partitionColumnsByName := map[string]*models.PartitionColumn{}
126-
for _, column := range w.model.PartitionColumns {
127-
partitionColumnsByName[column.Name] = column
128-
}
129-
130128
for _, column := range w.model.Columns {
131-
colSettings, ok := partitionColumnsByName[column.Name]
132-
if ok && !colSettings.WriteToOutput { // filter partition columns in schema
129+
// filter partition columns in schema
130+
if _, exists := w.columnsToDiscard[column.Name]; exists {
133131
continue
134132
}
135133

internal/generator/output/general/writer/parquet/parquet_test.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,11 @@ func TestGetModelSchema(t *testing.T) {
460460
require.NotEqual(t, "", tc.model.Name)
461461

462462
writer := &Writer{
463-
model: tc.model,
464-
config: tc.cfg,
465-
fs: fsMock,
466-
outputPath: "./",
463+
model: tc.model,
464+
columnsToDiscard: getColumnsToDiscard(tc.model.PartitionColumns),
465+
config: tc.cfg,
466+
fs: fsMock,
467+
outputPath: "./",
467468
}
468469

469470
modelSchemaPointer, writerProperties, err := writer.generateModelSchema()
@@ -616,7 +617,15 @@ func TestWriteRow(t *testing.T) {
616617
// WHEN
617618

618619
fsMock := newFileSystemMock()
619-
parquetWriter := NewWriter(tc.model, parquetConfig, fsMock, "./", false, nil)
620+
parquetWriter := NewWriter(
621+
tc.model,
622+
parquetConfig,
623+
getColumnsToDiscard(tc.model.PartitionColumns),
624+
fsMock,
625+
"./",
626+
false,
627+
nil,
628+
)
620629

621630
err := parquetWriter.Init()
622631
require.NoError(t, err)
@@ -825,7 +834,15 @@ func TestWriteToCorrectFiles(t *testing.T) {
825834
fsMock := newFileSystemMock()
826835

827836
write := func(from, to int, continueGeneration bool) {
828-
writer := NewWriter(model, config, fsMock, dir, continueGeneration, nil)
837+
writer := NewWriter(
838+
model,
839+
config,
840+
getColumnsToDiscard(model.PartitionColumns),
841+
fsMock,
842+
dir,
843+
continueGeneration,
844+
nil,
845+
)
829846
require.NoError(t, writer.Init())
830847

831848
for i := from; i < to; i++ {
@@ -914,3 +931,15 @@ func getExpected(rows []*models.DataRow, rowsPerFile uint64, writersCount int) (
914931

915932
return expectedFiles, expectedData
916933
}
934+
935+
func getColumnsToDiscard(partitionColumns []*models.PartitionColumn) map[string]struct{} {
936+
columnsToDiscard := make(map[string]struct{})
937+
938+
for _, column := range partitionColumns {
939+
if !column.WriteToOutput {
940+
columnsToDiscard[column.Name] = struct{}{}
941+
}
942+
}
943+
944+
return columnsToDiscard
945+
}

0 commit comments

Comments
 (0)