Skip to content

Commit 0aa0399

Browse files
authored
csv: Add support for exporting and decoding CSV slices (#101)
Enhanced CSV export and decode functions to handle slices of primitive types (e.g., []string, []int) without requiring field names. Updated writer and reader logic to properly marshal and unmarshal single-column CSV data. Added corresponding tests to verify slice handling.
1 parent a888977 commit 0aa0399

File tree

5 files changed

+115
-31
lines changed

5 files changed

+115
-31
lines changed

csv/export.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package csv
22

33
import (
4-
"fmt"
54
"io"
65
"os"
6+
"reflect"
77
)
88

99
// Export writes slice as csv format with fieldnames to writer w.
@@ -37,17 +37,30 @@ func ExportUTF8File[S ~[]E, E any](fieldnames []string, slice S, file string) er
3737
}
3838

3939
func export[S ~[]E, E any](fieldnames []string, slice S, w io.Writer, utf8bom bool) (err error) {
40-
csvWriter := NewWriter(w, utf8bom)
40+
var fields any
4141
if len(fieldnames) == 0 {
42-
if len(slice) == 0 {
43-
return fmt.Errorf("can't get struct fieldnames from zero length slice")
42+
t := reflect.TypeFor[E]()
43+
for t.Kind() == reflect.Pointer {
44+
t = t.Elem()
45+
}
46+
if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map {
47+
fields = reflect.Zero(t).Interface()
4448
}
45-
err = csvWriter.WriteFields(slice[0])
4649
} else {
47-
err = csvWriter.WriteFields(fieldnames)
50+
fields = fieldnames
4851
}
49-
if err != nil {
50-
return
52+
csvWriter := NewWriter(w, utf8bom)
53+
if fields != nil {
54+
if err = csvWriter.WriteFields(fields); err != nil {
55+
return
56+
}
57+
} else {
58+
csvWriter.fieldsWritten = true
59+
csvWriter.zero = make([]string, 1)
60+
csvWriter.pool.New = func() *[]string {
61+
s := make([]string, 1)
62+
return &s
63+
}
5164
}
5265
return csvWriter.WriteAll(slice)
5366
}

csv/export_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ func testExport[E any](t *testing.T, tc testcase[E], result string) {
1515
var b bytes.Buffer
1616
if err := Export(tc.fieldnames, tc.slice, &b); err != nil {
1717
t.Error(tc.name, err)
18+
return
1819
}
1920
if r := b.String(); r != result {
2021
t.Errorf("%s expected %q; got %q", tc.name, result, r)
@@ -111,3 +112,24 @@ a,b
111112
t.Errorf("expected %q; got %q", result, r)
112113
}
113114
}
115+
116+
func TestExportSlice(t *testing.T) {
117+
result := `1
118+
2
119+
3
120+
`
121+
var b bytes.Buffer
122+
if err := Export(nil, []string{"1", "2", "3"}, &b); err != nil {
123+
t.Fatal(err)
124+
}
125+
if r := b.String(); r != result {
126+
t.Errorf("expected %q; got %q", result, r)
127+
}
128+
b.Reset()
129+
if err := Export(nil, []int{1, 2, 3}, &b); err != nil {
130+
t.Fatal(err)
131+
}
132+
if r := b.String(); r != result {
133+
t.Errorf("expected %q; got %q", result, r)
134+
}
135+
}

csv/reader.go

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"os"
8+
"reflect"
89
"strings"
910
"sync"
1011
)
@@ -14,8 +15,9 @@ type Reader struct {
1415
*csv.Reader
1516
closer io.Closer
1617

17-
once sync.Once
18-
fields []string
18+
once sync.Once
19+
fields []string
20+
hasFields bool
1921

2022
next []string
2123
nextErr error
@@ -33,6 +35,7 @@ func NewReader(r io.Reader, hasFields bool) (*Reader, error) {
3335
if err != nil {
3436
return nil, err
3537
}
38+
reader.hasFields = true
3639
}
3740
return reader, nil
3841
}
@@ -75,6 +78,7 @@ func (r *Reader) Read() (record []string, err error) {
7578
// SetFields sets csv fields.
7679
func (r *Reader) SetFields(fields []string) {
7780
r.fields = fields
81+
r.hasFields = true
7882
}
7983

8084
// Next prepares the next record for reading with the Scan or Decode method.
@@ -106,22 +110,25 @@ func (r *Reader) Scan(dest ...any) error {
106110
// Decode will unmarshal the current record into dest.
107111
// If column's value is like "[...]", it will be treated as slice.
108112
func (r *Reader) Decode(dest any) error {
109-
if len(r.fields) == 0 {
110-
return fmt.Errorf("csv fields is not parsed")
111-
}
112-
if r.next == nil && r.nextErr == nil {
113-
return fmt.Errorf("Decode called without calling Next")
114-
}
115-
if r.nextErr != nil {
116-
return r.nextErr
117-
}
118-
m := make(map[string]string)
119-
for i, field := range r.fields {
120-
if len(r.next) > i {
121-
m[field] = r.next[i]
113+
if r.hasFields {
114+
if len(r.fields) == 0 {
115+
return fmt.Errorf("csv fields is not parsed")
122116
}
117+
if r.next == nil && r.nextErr == nil {
118+
return fmt.Errorf("Decode called without calling Next")
119+
}
120+
if r.nextErr != nil {
121+
return r.nextErr
122+
}
123+
m := make(map[string]string)
124+
for i, field := range r.fields {
125+
if len(r.next) > i {
126+
m[field] = r.next[i]
127+
}
128+
}
129+
return setRow(dest, m)
123130
}
124-
return setRow(dest, m)
131+
return setCell(dest, r.next[0])
125132
}
126133

127134
// Close closes the underlying reader if it implements the io.Closer interface.
@@ -134,7 +141,16 @@ func (r *Reader) Close() error {
134141

135142
// DecodeAll decodes each record from r into dest.
136143
func DecodeAll[S ~[]E, E any](r io.Reader, dest *S) (err error) {
137-
reader, err := NewReader(r, true)
144+
t := reflect.TypeFor[E]()
145+
for t.Kind() == reflect.Pointer {
146+
t = t.Elem()
147+
}
148+
var reader *Reader
149+
if kind := t.Kind(); kind == reflect.Struct || kind == reflect.Map {
150+
reader, err = NewReader(r, true)
151+
} else {
152+
reader, err = NewReader(r, false)
153+
}
138154
if err != nil {
139155
return
140156
}

csv/reader_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,24 @@ b,2,"[3,4]"
7373
t.Errorf("expected %q; got %q", expect, res)
7474
}
7575
}
76+
77+
func TestDecodeSlice(t *testing.T) {
78+
csv := `1
79+
2
80+
3
81+
`
82+
var s1 []string
83+
if err := DecodeAll(strings.NewReader(csv), &s1); err != nil {
84+
t.Fatal(err)
85+
}
86+
if expect := []string{"1", "2", "3"}; !reflect.DeepEqual(expect, s1) {
87+
t.Errorf("expected %v; got %v", expect, s1)
88+
}
89+
var s2 []int
90+
if err := DecodeAll(strings.NewReader(csv), &s2); err != nil {
91+
t.Fatal(err)
92+
}
93+
if expect := []int{1, 2, 3}; !reflect.DeepEqual(expect, s2) {
94+
t.Errorf("expected %v; got %v", expect, s2)
95+
}
96+
}

csv/writer.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func (w *Writer) WriteFields(fields any) error {
5151
}
5252
default:
5353
v := reflect.ValueOf(fields)
54-
if v.Kind() == reflect.Pointer {
55-
v = reflect.Indirect(v)
54+
for v.Kind() == reflect.Pointer {
55+
v = v.Elem()
5656
if !v.IsValid() {
5757
return fmt.Errorf("can not get fieldnames from nil pointer struct")
5858
}
@@ -107,6 +107,8 @@ func (w *Writer) Write(record any) error {
107107
return fmt.Errorf("fieldnames has not be written yet")
108108
}
109109
switch d := record.(type) {
110+
case string:
111+
return w.Writer.Write([]string{d})
110112
case []string:
111113
if len(d) == 0 {
112114
return nil
@@ -117,15 +119,19 @@ func (w *Writer) Write(record any) error {
117119
if v.Kind() == reflect.Interface {
118120
v = v.Elem()
119121
}
120-
if v.Kind() == reflect.Pointer {
121-
v = reflect.Indirect(v)
122+
for v.Kind() == reflect.Pointer {
123+
v = v.Elem()
122124
if !v.IsValid() {
123125
return nil
124126
}
125127
}
126128
r := w.pool.Get()
127129
defer w.pool.Put(r)
128130
switch v.Kind() {
131+
case reflect.Slice:
132+
for i := range v.Len() {
133+
(*r)[i], _ = marshalText(v.Index(i).Interface())
134+
}
129135
case reflect.Map:
130136
if keyType := reflect.TypeOf(v.Interface()).Key(); keyType.Kind() == reflect.String {
131137
for i, field := range w.fields {
@@ -163,7 +169,7 @@ func (w *Writer) Write(record any) error {
163169
}
164170
}
165171
default:
166-
return fmt.Errorf("not support record format: %s", v.Kind())
172+
(*r)[0], _ = marshalText(v.Interface())
167173
}
168174
if slices.Equal(*r, w.zero) {
169175
return nil
@@ -175,9 +181,15 @@ func (w *Writer) Write(record any) error {
175181
// WriteAll writes multiple CSV records to w using Write and then calls Flush, returning any error from the Flush.
176182
func (w *Writer) WriteAll(records any) error {
177183
switch s := records.(type) {
184+
case []string:
185+
for _, i := range s {
186+
if err := w.Writer.Write([]string{i}); err != nil {
187+
return err
188+
}
189+
}
178190
case [][]string:
179191
for _, i := range s {
180-
if err := w.Write(i); err != nil {
192+
if err := w.Writer.Write(i); err != nil {
181193
return err
182194
}
183195
}

0 commit comments

Comments
 (0)