diff --git a/arrow_batch.go b/arrow_batch.go new file mode 100644 index 000000000..8d326afed --- /dev/null +++ b/arrow_batch.go @@ -0,0 +1,25 @@ +//go:build !nobatch +// +build !nobatch + +package gosnowflake + +import ( + "github.com/apache/arrow-go/v18/arrow" +) + +func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[]arrow.Record, error) { + var records []arrow.Record + defer arc.reader.Release() + + for arc.reader.Next() { + rawRecord := arc.reader.Record() + + record, err := arrowToRecord(scd.ctx, rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc) + if err != nil { + return nil, err + } + records = append(records, record) + } + + return &records, arc.reader.Err() +} diff --git a/arrow_chunk.go b/arrow_chunk.go index 1a94916a3..cb8e605ff 100644 --- a/arrow_chunk.go +++ b/arrow_chunk.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "time" - "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" ) @@ -51,23 +50,6 @@ func (arc *arrowResultChunk) decodeArrowChunk(ctx context.Context, rowType []exe return chunkRows, arc.reader.Err() } -func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[]arrow.Record, error) { - var records []arrow.Record - defer arc.reader.Release() - - for arc.reader.Next() { - rawRecord := arc.reader.Record() - - record, err := arrowToRecord(scd.ctx, rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc) - if err != nil { - return nil, err - } - records = append(records, record) - } - - return &records, arc.reader.Err() -} - // Build arrow chunk based on RowSet of base64 func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) (arrowResultChunk, error) { rowSetBytes, err := base64.StdEncoding.DecodeString(rowsetBase64) diff --git a/batch_downloader.go b/batch_downloader.go new file mode 100644 index 000000000..e53d43ae7 --- /dev/null +++ b/batch_downloader.go @@ -0,0 +1,811 @@ +//go:build !nobatch +// +build !nobatch + +package gosnowflake + +import ( + "bufio" + "compress/gzip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +var ( + errNoConnection = errors.New("failed to retrieve connection") +) + +type chunkDownloader interface { + totalUncompressedSize() (acc int64) + hasNextResultSet() bool + nextResultSet() error + start() error + next() (chunkRowType, error) + reset() + getChunkMetas() []execResponseChunk + getQueryResultFormat() resultFormat + getRowType() []execResponseRowType + setNextChunkDownloader(downloader chunkDownloader) + getNextChunkDownloader() chunkDownloader + getArrowBatches() []*ArrowBatch +} + +type snowflakeChunkDownloader struct { + sc *snowflakeConn + ctx context.Context + pool memory.Allocator + Total int64 + TotalRowIndex int64 + CellCount int + CurrentChunk []chunkRowType + CurrentChunkIndex int + CurrentChunkSize int + CurrentIndex int + ChunkHeader map[string]string + ChunkMetas []execResponseChunk + Chunks map[int][]chunkRowType + ChunksChan chan int + ChunksError chan *chunkError + ChunksErrorCounter int + ChunksFinalErrors []*chunkError + ChunksMutex *sync.Mutex + DoneDownloadCond *sync.Cond + FirstBatch *ArrowBatch + NextDownloader chunkDownloader + Qrmk string + QueryResultFormat string + ArrowBatches []*ArrowBatch + RowSet rowSetType + FuncDownload 
func(context.Context, *snowflakeChunkDownloader, int) + FuncDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error + FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) +} + +func (scd *snowflakeChunkDownloader) totalUncompressedSize() (acc int64) { + for _, c := range scd.ChunkMetas { + acc += c.UncompressedSize + } + return +} + +func (scd *snowflakeChunkDownloader) hasNextResultSet() bool { + if len(scd.ChunkMetas) == 0 && scd.NextDownloader == nil { + return false // no extra chunk + } + // next result set exists if current chunk has remaining result sets or there is another downloader + return scd.CurrentChunkIndex < len(scd.ChunkMetas) || scd.NextDownloader != nil +} + +func (scd *snowflakeChunkDownloader) nextResultSet() error { + // no error at all times as the next chunk/resultset is automatically read + if scd.CurrentChunkIndex < len(scd.ChunkMetas) { + return nil + } + return io.EOF +} + +func (scd *snowflakeChunkDownloader) start() error { + if usesArrowBatches(scd.ctx) && scd.getQueryResultFormat() == arrowFormat { + return scd.startArrowBatches() + } + scd.CurrentChunkSize = len(scd.RowSet.JSON) // cache the size + scd.CurrentIndex = -1 // initial chunks idx + scd.CurrentChunkIndex = -1 // initial chunk + + scd.CurrentChunk = make([]chunkRowType, scd.CurrentChunkSize) + populateJSONRowSet(scd.CurrentChunk, scd.RowSet.JSON) + + if scd.getQueryResultFormat() == arrowFormat && scd.RowSet.RowSetBase64 != "" { + params, err := scd.getConfigParams() + if err != nil { + return fmt.Errorf("getting config params: %w", err) + } + // if the rowsetbase64 retrieved from the server is empty, move on to downloading chunks + loc := getCurrentLocation(params) + firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) + if err != nil { + return fmt.Errorf("building first arrow chunk: %w", err) + } + higherPrecision := higherPrecisionEnabled(scd.ctx) + scd.CurrentChunk, err = firstArrowChunk.decodeArrowChunk(scd.ctx, scd.RowSet.RowType, higherPrecision, params) + scd.CurrentChunkSize = firstArrowChunk.rowCount + if err != nil { + return fmt.Errorf("decoding arrow chunk: %w", err) + } + } + + // start downloading chunks if exists + chunkMetaLen := len(scd.ChunkMetas) + if chunkMetaLen > 0 { + logger.WithContext(scd.ctx).Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers) + logger.WithContext(scd.ctx).Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize()) + scd.ChunksMutex = &sync.Mutex{} + scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) + scd.Chunks = make(map[int][]chunkRowType) + scd.ChunksChan = make(chan int, chunkMetaLen) + scd.ChunksError = make(chan *chunkError, MaxChunkDownloadWorkers) + for i := 0; i < chunkMetaLen; i++ { + chunk := scd.ChunkMetas[i] + logger.WithContext(scd.ctx).Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v", + i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat) + scd.ChunksChan <- i + } + for i := 0; i < intMin(MaxChunkDownloadWorkers, chunkMetaLen); i++ { + scd.schedule() + } + } + return nil +} + +func (scd *snowflakeChunkDownloader) schedule() { + select { + case nextIdx := <-scd.ChunksChan: + logger.WithContext(scd.ctx).Infof("schedule chunk: %v", nextIdx+1) + go GoroutineWrapper( + scd.ctx, + func() { + scd.FuncDownload(scd.ctx, scd, nextIdx) + }, + ) + default: + // no more download + logger.WithContext(scd.ctx).Info("no 
more download") + } +} + +func (scd *snowflakeChunkDownloader) checkErrorRetry() error { + select { + case errc := <-scd.ChunksError: + if scd.ChunksErrorCounter >= maxChunkDownloaderErrorCounter || + errors.Is(errc.Error, context.Canceled) || + errors.Is(errc.Error, context.DeadlineExceeded) { + + scd.ChunksFinalErrors = append(scd.ChunksFinalErrors, errc) + logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error) + return errc.Error + } + + // add the index to the chunks channel so that the download will be retried. + go GoroutineWrapper( + scd.ctx, + func() { + scd.FuncDownload(scd.ctx, scd, errc.Index) + }, + ) + scd.ChunksErrorCounter++ + logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...", + errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter) + return nil + default: + logger.WithContext(scd.ctx).Info("no error is detected.") + return nil + } +} + +func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) { + for { + scd.CurrentIndex++ + if scd.CurrentIndex < scd.CurrentChunkSize { + return scd.CurrentChunk[scd.CurrentIndex], nil + } + scd.CurrentChunkIndex++ // next chunk + scd.CurrentIndex = -1 // reset + if scd.CurrentChunkIndex >= len(scd.ChunkMetas) { + break + } + + scd.ChunksMutex.Lock() + if scd.CurrentChunkIndex > 0 { + scd.Chunks[scd.CurrentChunkIndex-1] = nil // detach the previously used chunk + } + + for scd.Chunks[scd.CurrentChunkIndex] == nil { + logger.WithContext(scd.ctx).Debugf("waiting for chunk idx: %v/%v", + scd.CurrentChunkIndex+1, len(scd.ChunkMetas)) + + if err := scd.checkErrorRetry(); err != nil { + scd.ChunksMutex.Unlock() + return chunkRowType{}, fmt.Errorf("checking for error: %w", err) + } + + // wait for chunk downloader goroutine to broadcast the event, + // 1) one chunk download finishes or 2) an error occurs. + scd.DoneDownloadCond.Wait() + } + logger.WithContext(scd.ctx).Debugf("ready: chunk %v", scd.CurrentChunkIndex+1) + scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex] + scd.ChunksMutex.Unlock() + scd.CurrentChunkSize = len(scd.CurrentChunk) + + // kick off the next download + scd.schedule() + } + + logger.WithContext(scd.ctx).Debugf("no more data") + if len(scd.ChunkMetas) > 0 { + close(scd.ChunksError) + close(scd.ChunksChan) + } + return chunkRowType{}, io.EOF +} + +func (scd *snowflakeChunkDownloader) reset() { + scd.Chunks = nil // detach all chunks. No way to go backward without reinitialize it. +} + +func (scd *snowflakeChunkDownloader) getChunkMetas() []execResponseChunk { + return scd.ChunkMetas +} + +func (scd *snowflakeChunkDownloader) getQueryResultFormat() resultFormat { + return resultFormat(scd.QueryResultFormat) +} + +func (scd *snowflakeChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) { + scd.NextDownloader = nextDownloader +} + +func (scd *snowflakeChunkDownloader) getNextChunkDownloader() chunkDownloader { + return scd.NextDownloader +} + +func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType { + return scd.RowSet.RowType +} + +func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch { + if scd.FirstBatch == nil || scd.FirstBatch.rec == nil { + return scd.ArrowBatches + } + return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...) 
+} + +func (scd *snowflakeChunkDownloader) getConfigParams() (map[string]*string, error) { + if scd.sc == nil || scd.sc.cfg == nil { + return map[string]*string{}, errNoConnection + } + return scd.sc.cfg.Params, nil +} + +func getChunk( + ctx context.Context, + sc *snowflakeConn, + fullURL string, + headers map[string]string, + timeout time.Duration) ( + *http.Response, error, +) { + u, err := url.Parse(fullURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute() +} + +func (scd *snowflakeChunkDownloader) startArrowBatches() error { + var loc *time.Location + params, err := scd.getConfigParams() + if err != nil { + return fmt.Errorf("getting config params: %w", err) + } + loc = getCurrentLocation(params) + if scd.RowSet.RowSetBase64 != "" { + firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) + if err != nil { + return fmt.Errorf("building first arrow chunk: %w", err) + } + scd.FirstBatch = &ArrowBatch{ + idx: 0, + scd: scd, + funcDownloadHelper: scd.FuncDownloadHelper, + loc: loc, + } + // decode first chunk if possible + if firstArrowChunk.allocator != nil { + scd.FirstBatch.rec, err = firstArrowChunk.decodeArrowBatch(scd) + if err != nil { + return fmt.Errorf("decoding arrow batch: %w", err) + } + } + } + chunkMetaLen := len(scd.ChunkMetas) + scd.ArrowBatches = make([]*ArrowBatch, chunkMetaLen) + for i := range scd.ArrowBatches { + scd.ArrowBatches[i] = &ArrowBatch{ + idx: i, + scd: scd, + funcDownloadHelper: scd.FuncDownloadHelper, + loc: loc, + } + } + return nil +} + +/* largeResultSetReader is a reader that wraps the large result set with leading and tailing brackets. */ +type largeResultSetReader struct { + status int + body io.Reader +} + +func (r *largeResultSetReader) Read(p []byte) (n int, err error) { + if r.status == 0 { + p[0] = 0x5b // initial 0x5b ([) + r.status = 1 + return 1, nil + } + if r.status == 1 { + var len int + len, err = r.body.Read(p) + if err == io.EOF { + r.status = 2 + return len, nil + } + if err != nil { + return 0, fmt.Errorf("reading body: %w", err) + } + return len, nil + } + if r.status == 2 { + p[0] = 0x5d // tail 0x5d (]) + r.status = 3 + return 1, nil + } + // ensure no data and EOF + return 0, io.EOF +} + +func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { + logger.WithContext(ctx).Infof("download start chunk: %v", idx+1) + defer scd.DoneDownloadCond.Broadcast() + + if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil { + logger.WithContext(ctx).Errorf( + "failed to extract HTTP response body. 
URL: %v, err: %v", scd.ChunkMetas[idx].URL, err) + scd.ChunksError <- &chunkError{Index: idx, Error: err} + } else if errors.Is(scd.ctx.Err(), context.Canceled) || errors.Is(scd.ctx.Err(), context.DeadlineExceeded) { + scd.ChunksError <- &chunkError{Index: idx, Error: scd.ctx.Err()} + } +} + +func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx int) error { + headers := make(map[string]string) + if len(scd.ChunkHeader) > 0 { + logger.WithContext(ctx).Debug("chunk header is provided.") + for k, v := range scd.ChunkHeader { + logger.WithContext(ctx).Debugf("adding header: %v, value: %v", k, v) + + headers[k] = v + } + } else { + headers[headerSseCAlgorithm] = headerSseCAes + headers[headerSseCKey] = scd.Qrmk + } + + resp, err := scd.FuncGet(ctx, scd.sc, scd.ChunkMetas[idx].URL, headers, scd.sc.rest.RequestTimeout) + if err != nil { + return fmt.Errorf("getting chunk: %w", err) + } + defer resp.Body.Close() + logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL) + if resp.StatusCode != http.StatusOK { + b, err := io.ReadAll(resp.Body) + if err != nil { + logger.WithContext(ctx).Warnf("reading response body: %v", err) + } + logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, b) + logger.WithContext(ctx).Infof("Header: %v", resp.Header) + return &SnowflakeError{ + Number: ErrFailedToGetChunk, + SQLState: SQLStateConnectionFailure, + Message: errMsgFailedToGetChunk, + MessageArgs: []any{idx}, + } + } + + bufStream := bufio.NewReader(resp.Body) + return decodeChunk(ctx, scd, idx, bufStream) +} + +func decodeChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int, bufStream *bufio.Reader) error { + gzipMagic, err := bufStream.Peek(2) + if err != nil { + return fmt.Errorf("peeking for gzip magic bytes: %w", err) + } + start := time.Now() + var source io.Reader + if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { + // detects and uncompresses Gzip format data + bufStream0, err := gzip.NewReader(bufStream) + if err != nil { + return fmt.Errorf("creating gzip reader: %w", err) + } + defer bufStream0.Close() + source = bufStream0 + } else { + source = bufStream + } + st := &largeResultSetReader{ + status: 0, + body: source, + } + var respd []chunkRowType + if scd.getQueryResultFormat() != arrowFormat { + var decRespd [][]*string + if !CustomJSONDecoderEnabled { + dec := json.NewDecoder(st) + for { + if err := dec.Decode(&decRespd); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("decoding json: %w", err) + } + } + } else { + decRespd, err = decodeLargeChunk(st, scd.ChunkMetas[idx].RowCount, scd.CellCount) + if err != nil { + return fmt.Errorf("decoding large chunk: %w", err) + } + } + respd = make([]chunkRowType, len(decRespd)) + populateJSONRowSet(respd, decRespd) + } else { + ipcReader, err := ipc.NewReader(source, ipc.WithAllocator(scd.pool)) + if err != nil { + return fmt.Errorf("creating ipc reader: %w", err) + } + var loc *time.Location + params, err := scd.getConfigParams() + if err != nil { + return fmt.Errorf("getting config params: %w", err) + } + loc = getCurrentLocation(params) + arc := arrowResultChunk{ + ipcReader, + 0, + loc, + scd.pool, + } + if usesArrowBatches(scd.ctx) { + var err error + scd.ArrowBatches[idx].rec, err = arc.decodeArrowBatch(scd) + if err != nil { + return fmt.Errorf("decoding Arrow batch: %w", err) + } + // updating metadata + scd.ArrowBatches[idx].rowCount = countArrowBatchRows(scd.ArrowBatches[idx].rec) + 
return nil + } + highPrec := higherPrecisionEnabled(scd.ctx) + respd, err = arc.decodeArrowChunk(ctx, scd.RowSet.RowType, highPrec, params) + if err != nil { + return fmt.Errorf("decoding arrow chunk: %w", err) + } + } + logger.WithContext(scd.ctx).Debugf( + "decoded %d rows w/ %d bytes in %s (chunk %v)", + scd.ChunkMetas[idx].RowCount, + scd.ChunkMetas[idx].UncompressedSize, + time.Since(start), idx+1, + ) + + scd.ChunksMutex.Lock() + defer scd.ChunksMutex.Unlock() + scd.Chunks[idx] = respd + return nil +} + +func populateJSONRowSet(dst []chunkRowType, src [][]*string) { + // populate string rowset from src to dst's chunkRowType struct's RowSet field + for i, row := range src { + dst[i].RowSet = row + } +} + +type streamChunkDownloader struct { + ctx context.Context + id int64 + fetcher streamChunkFetcher + readErr error + rowStream chan []*string + Total int64 + ChunkMetas []execResponseChunk + NextDownloader chunkDownloader + RowSet rowSetType +} + +func (scd *streamChunkDownloader) totalUncompressedSize() (acc int64) { + return -1 +} + +func (scd *streamChunkDownloader) hasNextResultSet() bool { + return scd.readErr == nil +} + +func (scd *streamChunkDownloader) nextResultSet() error { + return scd.readErr +} + +func (scd *streamChunkDownloader) start() error { + go GoroutineWrapper( + scd.ctx, + func() { + readErr := io.EOF + + logger.WithContext(scd.ctx).Infof( + "start downloading. downloader id: %v, %v/%v rows, %v chunks", + scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas)) + t := time.Now() + + defer func() { + if readErr == io.EOF { + logger.WithContext(scd.ctx).Infof("downloading done. downloader id: %v", scd.id) + } else { + logger.WithContext(scd.ctx).Debugf("downloading error. downloader id: %v", scd.id) + } + scd.readErr = readErr + close(scd.rowStream) + + if r := recover(); r != nil { + if err, ok := r.(error); ok { + readErr = err + } else { + readErr = fmt.Errorf("%v", r) + } + } + }() + + logger.WithContext(scd.ctx).Infof("sending initial set of rows in %vms", time.Since(t).Microseconds()) + t = time.Now() + for _, row := range scd.RowSet.JSON { + scd.rowStream <- row + } + scd.RowSet.JSON = nil + + // Download and parse one chunk at a time. The fetcher will send each + // parsed row to the row stream. 
When an error occurs, the fetcher will + // stop writing to the row stream so we can stop processing immediately + for i, chunk := range scd.ChunkMetas { + logger.WithContext(scd.ctx).Infof("starting chunk fetch %d (%d rows)", i, chunk.RowCount) + if err := scd.fetcher.fetch(chunk.URL, scd.rowStream); err != nil { + logger.WithContext(scd.ctx).Debugf( + "failed chunk fetch %d: %#v, downloader id: %v, %v/%v rows, %v chunks", + i, err, scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas)) + readErr = fmt.Errorf("chunk fetch: %w", err) + break + } + logger.WithContext(scd.ctx).Infof("fetched chunk %d (%d rows) in %vms", i, chunk.RowCount, time.Since(t).Microseconds()) + t = time.Now() + } + }, + ) + return nil +} + +func (scd *streamChunkDownloader) next() (chunkRowType, error) { + if row, ok := <-scd.rowStream; ok { + return chunkRowType{RowSet: row}, nil + } + return chunkRowType{}, scd.readErr +} + +func (scd *streamChunkDownloader) reset() {} + +func (scd *streamChunkDownloader) getChunkMetas() []execResponseChunk { + return scd.ChunkMetas +} + +func (scd *streamChunkDownloader) getQueryResultFormat() resultFormat { + return jsonFormat +} + +func (scd *streamChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) { + scd.NextDownloader = nextDownloader +} + +func (scd *streamChunkDownloader) getNextChunkDownloader() chunkDownloader { + return scd.NextDownloader +} + +func (scd *streamChunkDownloader) getRowType() []execResponseRowType { + return scd.RowSet.RowType +} + +func (scd *streamChunkDownloader) getArrowBatches() []*ArrowBatch { + return nil +} + +func useStreamDownloader(ctx context.Context) bool { + val := ctx.Value(streamChunkDownload) + if val == nil { + return false + } + s, ok := val.(bool) + return s && ok +} + +type streamChunkFetcher interface { + fetch(url string, rows chan<- []*string) error +} + +type httpStreamChunkFetcher struct { + ctx context.Context + client *http.Client + clientIP net.IP + headers map[string]string + qrmk string +} + +func newStreamChunkDownloader( + ctx context.Context, + fetcher streamChunkFetcher, + total int64, + rowType []execResponseRowType, + firstRows [][]*string, + chunks []execResponseChunk, +) *streamChunkDownloader { + return &streamChunkDownloader{ + ctx: ctx, + id: rand.Int63(), + fetcher: fetcher, + readErr: nil, + rowStream: make(chan []*string), + Total: total, + ChunkMetas: chunks, + RowSet: rowSetType{RowType: rowType, JSON: firstRows}, + } +} + +func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error { + if len(f.headers) == 0 { + f.headers = map[string]string{ + headerSseCAlgorithm: headerSseCAes, + headerSseCKey: f.qrmk, + } + } + + fullURL, err := url.Parse(URL) + if err != nil { + return fmt.Errorf("parsing URL: %w", err) + } + res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, 0, defaultTimeProvider, nil).execute() + if err != nil { + return fmt.Errorf("executing HTTP request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + b, err := io.ReadAll(res.Body) + if err != nil { + logger.Warnf("httpStreamChunkFetcher.fetch: reading response body: %v", err) + } + return fmt.Errorf("status (%d): %s", res.StatusCode, string(b)) + } + if err = copyChunkStream(res.Body, rows); err != nil { + return fmt.Errorf("copying chunk stream: %w", err) + } + return nil +} + +func copyChunkStream(body io.Reader, rows chan<- []*string) error { + bufStream := bufio.NewReader(body) + gzipMagic, err := bufStream.Peek(2) + 
if err != nil { + return fmt.Errorf("peeking for gzip magic bytes: %w", err) + } + var source io.Reader = bufStream + if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { + // detect and decompress Gzip format data + bufStream0, err := gzip.NewReader(bufStream) + if err != nil { + return fmt.Errorf("creating gzip reader: %w", err) + } + defer bufStream0.Close() + source = bufStream0 + } + r := io.MultiReader(strings.NewReader("["), source, strings.NewReader("]")) + dec := json.NewDecoder(r) + openToken := json.Delim('[') + closeToken := json.Delim(']') + for { + if t, err := dec.Token(); err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("delim open: %w", err) + } else if t != openToken { + return fmt.Errorf("delim open: got %T", t) + } + + for dec.More() { + var row []*string + if err := dec.Decode(&row); err != nil { + return fmt.Errorf("decoding row: %w", err) + } + rows <- row + } + + if t, err := dec.Token(); err != nil { + return fmt.Errorf("delim close: %w", err) + } else if t != closeToken { + return fmt.Errorf("delim close: got %T", t) + } + } + return nil +} + +// ArrowBatch object represents a chunk of data, or subset of rows, retrievable in arrow.Record format +type ArrowBatch struct { + rec *[]arrow.Record + idx int + rowCount int + scd *snowflakeChunkDownloader + funcDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error + ctx context.Context + loc *time.Location +} + +// WithContext sets the context which will be used for this ArrowBatch. +func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch { + rb.ctx = ctx + return rb +} + +// Fetch returns an array of records representing a chunk in the query +func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) { + // chunk has already been downloaded + if rb.rec != nil { + // updating metadata + rb.rowCount = countArrowBatchRows(rb.rec) + return rb.rec, nil + } + var ctx context.Context + if rb.ctx != nil { + ctx = rb.ctx + } else { + ctx = context.Background() + } + if err := rb.funcDownloadHelper(ctx, rb.scd, rb.idx); err != nil { + return nil, fmt.Errorf("running download helper: %w", err) + } + return rb.rec, nil +} + +// GetRowCount returns the number of rows in an arrow batch +func (rb *ArrowBatch) GetRowCount() int { + return rb.rowCount +} + +func getAllocator(ctx context.Context) memory.Allocator { + pool, ok := ctx.Value(arrowAlloc).(memory.Allocator) + if !ok { + return memory.DefaultAllocator + } + return pool +} + +func usesArrowBatches(ctx context.Context) bool { + val := ctx.Value(arrowBatches) + if val == nil { + return false + } + a, ok := val.(bool) + return a && ok +} + +func countArrowBatchRows(recs *[]arrow.Record) (cnt int) { + for _, r := range *recs { + cnt += int(r.NumRows()) + } + return +} diff --git a/chunk_downloader.go b/chunk_downloader.go index 45503beee..bf594371f 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -1,3 +1,6 @@ +//go:build nobatch +// +build nobatch + package gosnowflake import ( @@ -95,9 +98,6 @@ func (scd *snowflakeChunkDownloader) nextResultSet() error { } func (scd *snowflakeChunkDownloader) start() error { - if usesArrowBatches(scd.ctx) && scd.getQueryResultFormat() == arrowFormat { - return scd.startArrowBatches() - } scd.CurrentChunkSize = len(scd.RowSet.JSON) // cache the size scd.CurrentIndex = -1 // initial chunks idx scd.CurrentChunkIndex = -1 // initial chunk @@ -292,45 +292,6 @@ func getChunk( return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, 
sc.currentTimeProvider, sc.cfg).execute() } -func (scd *snowflakeChunkDownloader) startArrowBatches() error { - var loc *time.Location - params, err := scd.getConfigParams() - if err != nil { - return fmt.Errorf("getting config params: %w", err) - } - loc = getCurrentLocation(params) - if scd.RowSet.RowSetBase64 != "" { - firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool) - if err != nil { - return fmt.Errorf("building first arrow chunk: %w", err) - } - scd.FirstBatch = &ArrowBatch{ - idx: 0, - scd: scd, - funcDownloadHelper: scd.FuncDownloadHelper, - loc: loc, - } - // decode first chunk if possible - if firstArrowChunk.allocator != nil { - scd.FirstBatch.rec, err = firstArrowChunk.decodeArrowBatch(scd) - if err != nil { - return fmt.Errorf("decoding arrow batch: %w", err) - } - } - } - chunkMetaLen := len(scd.ChunkMetas) - scd.ArrowBatches = make([]*ArrowBatch, chunkMetaLen) - for i := range scd.ArrowBatches { - scd.ArrowBatches[i] = &ArrowBatch{ - idx: i, - scd: scd, - funcDownloadHelper: scd.FuncDownloadHelper, - loc: loc, - } - } - return nil -} - /* largeResultSetReader is a reader that wraps the large result set with leading and tailing brackets. */ type largeResultSetReader struct { status int @@ -475,16 +436,6 @@ func decodeChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int, bu loc, scd.pool, } - if usesArrowBatches(scd.ctx) { - var err error - scd.ArrowBatches[idx].rec, err = arc.decodeArrowBatch(scd) - if err != nil { - return fmt.Errorf("decoding Arrow batch: %w", err) - } - // updating metadata - scd.ArrowBatches[idx].rowCount = countArrowBatchRows(scd.ArrowBatches[idx].rec) - return nil - } highPrec := higherPrecisionEnabled(scd.ctx) respd, err = arc.decodeArrowChunk(ctx, scd.RowSet.RowType, highPrec, params) if err != nil { @@ -791,15 +742,6 @@ func getAllocator(ctx context.Context) memory.Allocator { return pool } -func usesArrowBatches(ctx context.Context) bool { - val := ctx.Value(arrowBatches) - if val == nil { - return false - } - a, ok := val.(bool) - return a && ok -} - func countArrowBatchRows(recs *[]arrow.Record) (cnt int) { for _, r := range *recs { cnt += int(r.NumRows()) diff --git a/connection.go b/connection.go index f3b7e4445..e55ddb789 100644 --- a/connection.go +++ b/connection.go @@ -1,13 +1,9 @@ package gosnowflake import ( - "bufio" - "bytes" - "compress/gzip" "context" "database/sql" "database/sql/driver" - "encoding/base64" "encoding/json" "io" "net/http" @@ -16,11 +12,8 @@ import ( "strconv" "sync" "sync/atomic" - "time" "go.opentelemetry.io/otel/propagation" - - "github.com/apache/arrow-go/v18/arrow/ipc" ) const ( @@ -515,63 +508,6 @@ func (sc *snowflakeConn) GetQueryStatus( }, nil } -// QueryArrowStream returns batches which can be queried for their raw arrow -// ipc stream of bytes. This way consumers don't need to be using the exact -// same version of Arrow as the connection is using internally in order -// to consume Arrow data. 
-func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bindings ...driver.NamedValue) (ArrowStreamLoader, error) { - ctx = WithArrowBatches(context.WithValue(ctx, asyncMode, false)) - ctx = setResultType(ctx, queryResultType) - isDesc := isDescribeOnly(ctx) - isInternal := isInternal(ctx) - data, err := sc.exec(ctx, query, false, isInternal, isDesc, bindings) - if err != nil { - logger.WithContext(ctx).Errorf("error: %v", err) - if data != nil { - code, e := strconv.Atoi(data.Code) - if e != nil { - return nil, e - } - return nil, (&SnowflakeError{ - Number: code, - SQLState: data.Data.SQLState, - Message: err.Error(), - QueryID: data.Data.QueryID, - }).exceptionTelemetry(sc) - } - return nil, err - } - - return &snowflakeArrowStreamChunkDownloader{ - sc: sc, - ChunkMetas: data.Data.Chunks, - Total: data.Data.Total, - Qrmk: data.Data.Qrmk, - ChunkHeader: data.Data.ChunkHeaders, - FuncGet: getChunk, - RowSet: rowSetType{ - RowType: data.Data.RowType, - JSON: data.Data.RowSet, - RowSetBase64: data.Data.RowSetBase64, - }, - }, nil -} - -// ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded -// Arrow IPC stream. Call `GetStream` to download and retrieve an io.Reader -// that can be used with ipc.NewReader to get record batch results. -type ArrowStreamBatch struct { - idx int - numrows int64 - scd *snowflakeArrowStreamChunkDownloader - Loc *time.Location - rr io.ReadCloser -} - -// NumRows returns the total number of rows that the metadata stated should -// be in this stream of record batches. -func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows } - // gzip.Reader.Close does NOT close the underlying reader, so we // need to wrap with wrapReader so that closing will close the // response body (or any other reader that we want to gzip uncompress) @@ -589,195 +525,6 @@ func (w *wrapReader) Close() error { return w.wrapped.Close() } -func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error { - headers := make(map[string]string) - if len(asb.scd.ChunkHeader) > 0 { - logger.WithContext(ctx).Debug("chunk header is provided") - for k, v := range asb.scd.ChunkHeader { - logger.Debugf("adding header: %v, value: %v", k, v) - - headers[k] = v - } - } else { - headers[headerSseCAlgorithm] = headerSseCAes - headers[headerSseCKey] = asb.scd.Qrmk - } - - resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout) - if err != nil { - return err - } - logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL) - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b) - logger.WithContext(ctx).Infof("Header: %v", resp.Header) - return &SnowflakeError{ - Number: ErrFailedToGetChunk, - SQLState: SQLStateConnectionFailure, - Message: errMsgFailedToGetChunk, - MessageArgs: []interface{}{asb.idx}, - } - } - - defer func() { - if asb.rr == nil { - resp.Body.Close() - } - }() - - bufStream := bufio.NewReader(resp.Body) - gzipMagic, err := bufStream.Peek(2) - if err != nil { - return err - } - - if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { - // detect and uncompress gzip - bufStream0, err := gzip.NewReader(bufStream) - if err != nil { - return err - } - // gzip.Reader.Close() does NOT close the underlying - // reader, so we 
need to wrap it and ensure close will - // close the response body. Otherwise we'll leak it. - asb.rr = &wrapReader{Reader: bufStream0, wrapped: resp.Body} - } else { - asb.rr = &wrapReader{Reader: bufStream, wrapped: resp.Body} - } - return nil -} - -// GetStream returns a stream of bytes consisting of an Arrow IPC Record -// batch stream. Close should be called on the returned stream when done -// to ensure no leaked memory. -func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { - if asb.rr == nil { - if err := asb.downloadChunkStreamHelper(ctx); err != nil { - return nil, err - } - } - - return asb.rr, nil -} - -// ArrowStreamLoader is a convenience interface for downloading -// Snowflake results via multiple Arrow Record Batch streams. -// -// Some queries from Snowflake do not return Arrow data regardless -// of the settings, such as "SHOW WAREHOUSES". In these cases, -// you'll find TotalRows() > 0 but GetBatches returns no batches -// and no errors. In this case, the data is accessible via JSONData -// with the actual types matching up to the metadata in RowTypes. -type ArrowStreamLoader interface { - GetBatches() ([]ArrowStreamBatch, error) - TotalRows() int64 - RowTypes() []execResponseRowType - Location() *time.Location - JSONData() [][]*string -} - -type snowflakeArrowStreamChunkDownloader struct { - sc *snowflakeConn - ChunkMetas []execResponseChunk - Total int64 - Qrmk string - ChunkHeader map[string]string - FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) - RowSet rowSetType -} - -func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location { - if scd.sc != nil && scd.sc.cfg != nil { - return getCurrentLocation(scd.sc.cfg.Params) - } - return nil -} -func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total } -func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType { - return scd.RowSet.RowType -} -func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string { - return scd.RowSet.JSON -} - -// the server might have had an empty first batch, check if we can decode -// that first batch, if not we skip it. -func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() ([]byte, error) { - if scd.RowSet.RowSetBase64 == "" { - return nil, nil - } - - // first batch - rowSetBytes, err := base64.StdEncoding.DecodeString(scd.RowSet.RowSetBase64) - if err != nil { - // match logic in buildFirstArrowChunk - // assume there's no first chunk if we can't decode the base64 string - logger.Warnf("skipping first batch as it is not a valid base64 response. %v", err) - return nil, err - } - - // verify it's a valid ipc stream, otherwise skip it - rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes)) - if err != nil { - logger.Warnf("skipping first batch as it is not a valid IPC stream. %v", err) - return nil, err - } - rr.Release() - - return rowSetBytes, nil -} - -func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) { - chunkMetaLen := len(scd.ChunkMetas) - loc := scd.Location() - - out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1) - toFill := out - rowSetBytes, err := scd.maybeFirstBatch() - if err != nil { - return nil, err - } - // if there was no first batch in the response from the server, - // skip it and move on. toFill == out - // otherwise expand out by one to account for the first batch - // and fill it in. 
have toFill refer to the slice of out excluding - // the first batch. - if len(rowSetBytes) > 0 { - out = out[:chunkMetaLen+1] - out[0] = ArrowStreamBatch{ - scd: scd, - Loc: loc, - rr: io.NopCloser(bytes.NewReader(rowSetBytes)), - } - toFill = out[1:] - } - - var totalCounted int64 - for i := range toFill { - toFill[i] = ArrowStreamBatch{ - idx: i, - numrows: int64(scd.ChunkMetas[i].RowCount), - Loc: loc, - scd: scd, - } - logger.Debugf("batch %v, numrows: %v", i, toFill[i].numrows) - totalCounted += int64(scd.ChunkMetas[i].RowCount) - } - - if len(rowSetBytes) > 0 { - // if we had a first batch, fill in the numrows - out[0].numrows = scd.Total - totalCounted - logger.Debugf("first batch, numrows: %v", out[0].numrows) - } - return -} - // buildSnowflakeConn creates a new snowflakeConn. // The provided context is used only for establishing the initial connection. func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { diff --git a/connection_batch.go b/connection_batch.go new file mode 100644 index 000000000..1def179da --- /dev/null +++ b/connection_batch.go @@ -0,0 +1,265 @@ +//go:build !nobatch +// +build !nobatch + +package gosnowflake + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "database/sql/driver" + "encoding/base64" + "io" + "net/http" + "strconv" + "time" + + "github.com/apache/arrow-go/v18/arrow/ipc" +) + +// QueryArrowStream returns batches which can be queried for their raw arrow +// ipc stream of bytes. This way consumers don't need to be using the exact +// same version of Arrow as the connection is using internally in order +// to consume Arrow data. +func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bindings ...driver.NamedValue) (ArrowStreamLoader, error) { + ctx = WithArrowBatches(context.WithValue(ctx, asyncMode, false)) + ctx = setResultType(ctx, queryResultType) + isDesc := isDescribeOnly(ctx) + isInternal := isInternal(ctx) + data, err := sc.exec(ctx, query, false, isInternal, isDesc, bindings) + if err != nil { + logger.WithContext(ctx).Errorf("error: %v", err) + if data != nil { + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e + } + return nil, (&SnowflakeError{ + Number: code, + SQLState: data.Data.SQLState, + Message: err.Error(), + QueryID: data.Data.QueryID, + }).exceptionTelemetry(sc) + } + return nil, err + } + + return &snowflakeArrowStreamChunkDownloader{ + sc: sc, + ChunkMetas: data.Data.Chunks, + Total: data.Data.Total, + Qrmk: data.Data.Qrmk, + ChunkHeader: data.Data.ChunkHeaders, + FuncGet: getChunk, + RowSet: rowSetType{ + RowType: data.Data.RowType, + JSON: data.Data.RowSet, + RowSetBase64: data.Data.RowSetBase64, + }, + }, nil +} + +// ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded +// Arrow IPC stream. Call `GetStream` to download and retrieve an io.Reader +// that can be used with ipc.NewReader to get record batch results. +type ArrowStreamBatch struct { + idx int + numrows int64 + scd *snowflakeArrowStreamChunkDownloader + Loc *time.Location + rr io.ReadCloser +} + +// NumRows returns the total number of rows that the metadata stated should +// be in this stream of record batches. 
+func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows } + +func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error { + headers := make(map[string]string) + if len(asb.scd.ChunkHeader) > 0 { + logger.WithContext(ctx).Debug("chunk header is provided") + for k, v := range asb.scd.ChunkHeader { + logger.Debugf("adding header: %v, value: %v", k, v) + + headers[k] = v + } + } else { + headers[headerSseCAlgorithm] = headerSseCAes + headers[headerSseCKey] = asb.scd.Qrmk + } + + resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout) + if err != nil { + return err + } + logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL) + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b) + logger.WithContext(ctx).Infof("Header: %v", resp.Header) + return &SnowflakeError{ + Number: ErrFailedToGetChunk, + SQLState: SQLStateConnectionFailure, + Message: errMsgFailedToGetChunk, + MessageArgs: []interface{}{asb.idx}, + } + } + + defer func() { + if asb.rr == nil { + resp.Body.Close() + } + }() + + bufStream := bufio.NewReader(resp.Body) + gzipMagic, err := bufStream.Peek(2) + if err != nil { + return err + } + + if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { + // detect and uncompress gzip + bufStream0, err := gzip.NewReader(bufStream) + if err != nil { + return err + } + // gzip.Reader.Close() does NOT close the underlying + // reader, so we need to wrap it and ensure close will + // close the response body. Otherwise we'll leak it. + asb.rr = &wrapReader{Reader: bufStream0, wrapped: resp.Body} + } else { + asb.rr = &wrapReader{Reader: bufStream, wrapped: resp.Body} + } + return nil +} + +// GetStream returns a stream of bytes consisting of an Arrow IPC Record +// batch stream. Close should be called on the returned stream when done +// to ensure no leaked memory. +func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { + if asb.rr == nil { + if err := asb.downloadChunkStreamHelper(ctx); err != nil { + return nil, err + } + } + + return asb.rr, nil +} + +// ArrowStreamLoader is a convenience interface for downloading +// Snowflake results via multiple Arrow Record Batch streams. +// +// Some queries from Snowflake do not return Arrow data regardless +// of the settings, such as "SHOW WAREHOUSES". In these cases, +// you'll find TotalRows() > 0 but GetBatches returns no batches +// and no errors. In this case, the data is accessible via JSONData +// with the actual types matching up to the metadata in RowTypes. 
+type ArrowStreamLoader interface { + GetBatches() ([]ArrowStreamBatch, error) + TotalRows() int64 + RowTypes() []execResponseRowType + Location() *time.Location + JSONData() [][]*string +} + +type snowflakeArrowStreamChunkDownloader struct { + sc *snowflakeConn + ChunkMetas []execResponseChunk + Total int64 + Qrmk string + ChunkHeader map[string]string + FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) + RowSet rowSetType +} + +func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location { + if scd.sc != nil && scd.sc.cfg != nil { + return getCurrentLocation(scd.sc.cfg.Params) + } + return nil +} +func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total } +func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType { + return scd.RowSet.RowType +} +func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string { + return scd.RowSet.JSON +} + +// the server might have had an empty first batch, check if we can decode +// that first batch, if not we skip it. +func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() ([]byte, error) { + if scd.RowSet.RowSetBase64 == "" { + return nil, nil + } + + // first batch + rowSetBytes, err := base64.StdEncoding.DecodeString(scd.RowSet.RowSetBase64) + if err != nil { + // match logic in buildFirstArrowChunk + // assume there's no first chunk if we can't decode the base64 string + logger.Warnf("skipping first batch as it is not a valid base64 response. %v", err) + return nil, err + } + + // verify it's a valid ipc stream, otherwise skip it + rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes)) + if err != nil { + logger.Warnf("skipping first batch as it is not a valid IPC stream. %v", err) + return nil, err + } + rr.Release() + + return rowSetBytes, nil +} + +func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) { + chunkMetaLen := len(scd.ChunkMetas) + loc := scd.Location() + + out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1) + toFill := out + rowSetBytes, err := scd.maybeFirstBatch() + if err != nil { + return nil, err + } + // if there was no first batch in the response from the server, + // skip it and move on. toFill == out + // otherwise expand out by one to account for the first batch + // and fill it in. have toFill refer to the slice of out excluding + // the first batch. 
+ if len(rowSetBytes) > 0 { + out = out[:chunkMetaLen+1] + out[0] = ArrowStreamBatch{ + scd: scd, + Loc: loc, + rr: io.NopCloser(bytes.NewReader(rowSetBytes)), + } + toFill = out[1:] + } + + var totalCounted int64 + for i := range toFill { + toFill[i] = ArrowStreamBatch{ + idx: i, + numrows: int64(scd.ChunkMetas[i].RowCount), + Loc: loc, + scd: scd, + } + logger.Debugf("batch %v, numrows: %v", i, toFill[i].numrows) + totalCounted += int64(scd.ChunkMetas[i].RowCount) + } + + if len(rowSetBytes) > 0 { + // if we had a first batch, fill in the numrows + out[0].numrows = scd.Total - totalCounted + logger.Debugf("first batch, numrows: %v", out[0].numrows) + } + return +} diff --git a/converter.go b/converter.go index 440a463c5..147c9a37b 100644 --- a/converter.go +++ b/converter.go @@ -16,13 +16,10 @@ import ( "strconv" "strings" "time" - "unicode/utf8" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/arrow-go/v18/arrow/compute" "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/memory" ) const format = "2006-01-02 15:04:05.999999999" @@ -46,21 +43,6 @@ const ( TimeType ) -type snowflakeArrowBatchesTimestampOption int - -const ( - // UseNanosecondTimestamp uses arrow.Timestamp in nanosecond precision, could cause ErrTooHighTimestampPrecision if arrow.Timestamp cannot fit original timestamp values. - UseNanosecondTimestamp snowflakeArrowBatchesTimestampOption = iota - // UseMicrosecondTimestamp uses arrow.Timestamp in microsecond precision - UseMicrosecondTimestamp - // UseMillisecondTimestamp uses arrow.Timestamp in millisecond precision - UseMillisecondTimestamp - // UseSecondTimestamp uses arrow.Timestamp in second precision - UseSecondTimestamp - // UseOriginalTimestamp uses original timestamp struct returned by Snowflake. It can be used in case arrow.Timestamp cannot fit original timestamp values. 
- UseOriginalTimestamp -) - type interfaceArrayBinding struct { hasTimezone bool tzType timezoneType @@ -1453,13 +1435,6 @@ func decimalToBigFloat(num decimal128.Num, scale int64) *big.Float { return new(big.Float).Quo(f, s) } -// ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time -func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time { - scale := int(rb.scd.RowSet.RowType[colIdx].Scale) - dbType := rb.scd.RowSet.RowType[colIdx].Type - return arrowSnowflakeTimestampToTime(rec.Column(colIdx), getSnowflakeType(dbType), scale, recIdx, rb.loc) -} - func arrowSnowflakeTimestampToTime( column arrow.Array, sfType snowflakeType, @@ -2862,381 +2837,6 @@ func higherPrecisionEnabled(ctx context.Context) bool { return ok && d } -func arrowBatchesUtf8ValidationEnabled(ctx context.Context) bool { - v := ctx.Value(enableArrowBatchesUtf8Validation) - if v == nil { - return false - } - d, ok := v.(bool) - return ok && d -} - -func getArrowBatchesTimestampOption(ctx context.Context) snowflakeArrowBatchesTimestampOption { - v := ctx.Value(arrowBatchesTimestampOption) - if v == nil { - return UseNanosecondTimestamp - } - o, ok := v.(snowflakeArrowBatchesTimestampOption) - if !ok { - return UseNanosecondTimestamp - } - return o -} - -func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) { - arrowBatchesTimestampOption := getArrowBatchesTimestampOption(ctx) - higherPrecisionEnabled := higherPrecisionEnabled(ctx) - - s, err := recordToSchema(record.Schema(), rowType, loc, arrowBatchesTimestampOption, higherPrecisionEnabled) - if err != nil { - return nil, err - } - - var cols []arrow.Array - numRows := record.NumRows() - ctxAlloc := compute.WithAllocator(ctx, pool) - - for i, col := range record.Columns() { - fieldMetadata := rowType[i].toFieldMetadata() - - newCol, err := arrowToRecordSingleColumn(ctxAlloc, s.Field(i), col, fieldMetadata, higherPrecisionEnabled, arrowBatchesTimestampOption, pool, loc, numRows) - if err != nil { - return nil, err - } - cols = append(cols, newCol) - defer newCol.Release() - } - newRecord := array.NewRecord(s, cols, numRows) - return newRecord, nil -} - -func arrowToRecordSingleColumn(ctx context.Context, field arrow.Field, col arrow.Array, fieldMetadata fieldMetadata, higherPrecisionEnabled bool, timestampOption snowflakeArrowBatchesTimestampOption, pool memory.Allocator, loc *time.Location, numRows int64) (arrow.Array, error) { - var err error - newCol := col - snowflakeType := getSnowflakeType(fieldMetadata.Type) - switch snowflakeType { - case fixedType: - if higherPrecisionEnabled { - // do nothing - return decimal as is - col.Retain() - } else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 { - var toType arrow.DataType - if fieldMetadata.Scale == 0 { - toType = arrow.PrimitiveTypes.Int64 - } else { - toType = arrow.PrimitiveTypes.Float64 - } - // we're fine truncating so no error for data loss here. - // so we use UnsafeCastOptions. 
- newCol, err = compute.CastArray(ctx, col, compute.UnsafeCastOptions(toType)) - if err != nil { - return nil, err - } - } else if fieldMetadata.Scale != 0 && col.DataType().ID() != arrow.INT64 { - result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true}, - &compute.ArrayDatum{Value: newCol.Data()}, - compute.NewDatum(math.Pow10(int(fieldMetadata.Scale)))) - if err != nil { - return nil, err - } - defer result.Release() - newCol = result.(*compute.ArrayDatum).MakeArray() - } else if fieldMetadata.Scale != 0 && col.DataType().ID() == arrow.INT64 { - // gosnowflake driver uses compute.Divide() which could bring `integer value not in range: -9007199254740992 to 9007199254740992` error - // if we convert int64 to BigDecimal and then use compute.CastArray to convert BigDecimal to float64, we won't have enough precision. - // e.g 0.1 as (38,19) will result 0.09999999999999999 - values := col.(*array.Int64).Int64Values() - floatValues := make([]float64, len(values)) - for i, val := range values { - floatValues[i], _ = intToBigFloat(val, int64(fieldMetadata.Scale)).Float64() - } - builder := array.NewFloat64Builder(pool) - builder.AppendValues(floatValues, nil) - newCol = builder.NewArray() - builder.Release() - } else { - col.Retain() - } - case timeType: - newCol, err = compute.CastArray(ctx, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns)) - if err != nil { - return nil, err - } - case timestampNtzType, timestampLtzType, timestampTzType: - if timestampOption == UseOriginalTimestamp { - // do nothing - return timestamp as is - col.Retain() - } else { - var unit arrow.TimeUnit - switch timestampOption { - case UseMicrosecondTimestamp: - unit = arrow.Microsecond - case UseMillisecondTimestamp: - unit = arrow.Millisecond - case UseSecondTimestamp: - unit = arrow.Second - case UseNanosecondTimestamp: - unit = arrow.Nanosecond - } - var tb *array.TimestampBuilder - if snowflakeType == timestampLtzType { - tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit, TimeZone: loc.String()}) - } else { - tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit}) - } - defer tb.Release() - - for i := 0; i < int(numRows); i++ { - ts := arrowSnowflakeTimestampToTime(col, snowflakeType, int(fieldMetadata.Scale), i, loc) - if ts != nil { - var ar arrow.Timestamp - switch timestampOption { - case UseMicrosecondTimestamp: - ar = arrow.Timestamp(ts.UnixMicro()) - case UseMillisecondTimestamp: - ar = arrow.Timestamp(ts.UnixMilli()) - case UseSecondTimestamp: - ar = arrow.Timestamp(ts.Unix()) - case UseNanosecondTimestamp: - ar = arrow.Timestamp(ts.UnixNano()) - // in case of overflow in arrow timestamp return error - // this could only happen for nanosecond case - if ts.UTC().Year() != ar.ToTime(arrow.Nanosecond).Year() { - return nil, &SnowflakeError{ - Number: ErrTooHighTimestampPrecision, - SQLState: SQLStateInvalidDataTimeFormat, - Message: fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. 
Please use context with WithOriginalTimestamp.", ts.UTC(), fieldMetadata.Name), - } - } - } - tb.Append(ar) - } else { - tb.AppendNull() - } - } - newCol = tb.NewArray() - } - case textType: - if stringCol, ok := col.(*array.String); ok { - newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) - } - case objectType: - if structCol, ok := col.(*array.Struct); ok { - var internalCols []arrow.Array - for i := 0; i < structCol.NumField(); i++ { - internalCol := structCol.Field(i) - newInternalCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.StructType).Field(i), internalCol, fieldMetadata.Fields[i], higherPrecisionEnabled, timestampOption, pool, loc, numRows) - if err != nil { - return nil, err - } - internalCols = append(internalCols, newInternalCol) - defer newInternalCol.Release() - } - var fieldNames []string - for _, f := range field.Type.(*arrow.StructType).Fields() { - fieldNames = append(fieldNames, f.Name) - } - nullBitmap := memory.NewBufferBytes(structCol.NullBitmapBytes()) - numberOfNulls := structCol.NullN() - return array.NewStructArrayWithNulls(internalCols, fieldNames, nullBitmap, numberOfNulls, 0) - } else if stringCol, ok := col.(*array.String); ok { - newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) - } - case arrayType: - if listCol, ok := col.(*array.List); ok { - newCol, err = arrowToRecordSingleColumn(ctx, field.Type.(*arrow.ListType).ElemField(), listCol.ListValues(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows) - if err != nil { - return nil, err - } - defer newCol.Release() - newData := array.NewData(arrow.ListOf(newCol.DataType()), listCol.Len(), listCol.Data().Buffers(), []arrow.ArrayData{newCol.Data()}, listCol.NullN(), 0) - defer newData.Release() - return array.NewListData(newData), nil - } else if stringCol, ok := col.(*array.String); ok { - newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) - } - case mapType: - if mapCol, ok := col.(*array.Map); ok { - keyCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).KeyField(), mapCol.Keys(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows) - if err != nil { - return nil, err - } - defer keyCol.Release() - valueCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).ItemField(), mapCol.Items(), fieldMetadata.Fields[1], higherPrecisionEnabled, timestampOption, pool, loc, numRows) - if err != nil { - return nil, err - } - defer valueCol.Release() - - structArr, err := array.NewStructArray([]arrow.Array{keyCol, valueCol}, []string{"k", "v"}) - if err != nil { - return nil, err - } - defer structArr.Release() - newData := array.NewData(arrow.MapOf(keyCol.DataType(), valueCol.DataType()), mapCol.Len(), mapCol.Data().Buffers(), []arrow.ArrayData{structArr.Data()}, mapCol.NullN(), 0) - defer newData.Release() - return array.NewMapData(newData), nil - } else if stringCol, ok := col.(*array.String); ok { - newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) - } - default: - col.Retain() - } - return newCol, nil -} - -// returns n arrow array which will be new and populated if we converted the array to valid utf8 -// or if we didn't covnert it, it will return the original column. 
-func arrowStringRecordToColumn( - ctx context.Context, - stringCol *array.String, - mem memory.Allocator, - numRows int64, - fieldMetadata fieldMetadata, -) arrow.Array { - if arrowBatchesUtf8ValidationEnabled(ctx) && stringCol.DataType().ID() == arrow.STRING { - tb := array.NewStringBuilder(mem) - defer tb.Release() - - for i := 0; i < int(numRows); i++ { - if stringCol.IsValid(i) { - stringValue := stringCol.Value(i) - if !utf8.ValidString(stringValue) { - logger.WithContext(ctx).Error("Invalid UTF-8 characters detected while reading query response, column: ", fieldMetadata.Name) - stringValue = strings.ToValidUTF8(stringValue, "�") - } - tb.Append(stringValue) - } else { - tb.AppendNull() - } - } - arr := tb.NewArray() - return arr - } - stringCol.Retain() - return stringCol -} - -func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) (*arrow.Schema, error) { - fields := recordToSchemaRecursive(sc.Fields(), rowType, loc, timestampOption, withHigherPrecision) - meta := sc.Metadata() - return arrow.NewSchema(fields, &meta), nil -} - -func recordToSchemaRecursive(inFields []arrow.Field, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) []arrow.Field { - var outFields []arrow.Field - for i, f := range inFields { - fieldMetadata := rowType[i].toFieldMetadata() - converted, t := recordToSchemaSingleField(fieldMetadata, f, withHigherPrecision, timestampOption, loc) - - newField := f - if converted { - newField = arrow.Field{ - Name: f.Name, - Type: t, - Nullable: f.Nullable, - Metadata: f.Metadata, - } - } - outFields = append(outFields, newField) - } - return outFields -} - -func recordToSchemaSingleField(fieldMetadata fieldMetadata, f arrow.Field, withHigherPrecision bool, timestampOption snowflakeArrowBatchesTimestampOption, loc *time.Location) (bool, arrow.DataType) { - t := f.Type - converted := true - switch getSnowflakeType(fieldMetadata.Type) { - case fixedType: - switch f.Type.ID() { - case arrow.DECIMAL: - if withHigherPrecision { - converted = false - } else if fieldMetadata.Scale == 0 { - t = &arrow.Int64Type{} - } else { - t = &arrow.Float64Type{} - } - default: - if withHigherPrecision { - converted = false - } else if fieldMetadata.Scale != 0 { - t = &arrow.Float64Type{} - } else { - converted = false - } - } - case timeType: - t = &arrow.Time64Type{Unit: arrow.Nanosecond} - case timestampNtzType, timestampTzType: - if timestampOption == UseOriginalTimestamp { - // do nothing - return timestamp as is - converted = false - } else if timestampOption == UseMicrosecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Microsecond} - } else if timestampOption == UseMillisecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Millisecond} - } else if timestampOption == UseSecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Second} - } else { - t = &arrow.TimestampType{Unit: arrow.Nanosecond} - } - case timestampLtzType: - if timestampOption == UseOriginalTimestamp { - // do nothing - return timestamp as is - converted = false - } else if timestampOption == UseMicrosecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: loc.String()} - } else if timestampOption == UseMillisecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: loc.String()} - } else if timestampOption == UseSecondTimestamp { - t = &arrow.TimestampType{Unit: arrow.Second, 
TimeZone: loc.String()} - } else { - t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()} - } - case objectType: - converted = false - if f.Type.ID() == arrow.STRUCT { - var internalFields []arrow.Field - for idx, internalField := range f.Type.(*arrow.StructType).Fields() { - internalConverted, convertedDataType := recordToSchemaSingleField(fieldMetadata.Fields[idx], internalField, withHigherPrecision, timestampOption, loc) - converted = converted || internalConverted - if internalConverted { - newInternalField := arrow.Field{ - Name: internalField.Name, - Type: convertedDataType, - Metadata: internalField.Metadata, - Nullable: internalField.Nullable, - } - internalFields = append(internalFields, newInternalField) - } else { - internalFields = append(internalFields, internalField) - } - } - t = arrow.StructOf(internalFields...) - } - case arrayType: - if _, ok := f.Type.(*arrow.ListType); ok { - converted, dataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.ListType).ElemField(), withHigherPrecision, timestampOption, loc) - if converted { - t = arrow.ListOf(dataType) - } - } else { - t = f.Type - } - case mapType: - convertedKey, keyDataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.MapType).KeyField(), withHigherPrecision, timestampOption, loc) - convertedValue, valueDataType := recordToSchemaSingleField(fieldMetadata.Fields[1], f.Type.(*arrow.MapType).ItemField(), withHigherPrecision, timestampOption, loc) - converted = convertedKey || convertedValue - if converted { - t = arrow.MapOf(keyDataType, valueDataType) - } - default: - converted = false - } - return converted, t -} - // TypedNullTime is required to properly bind the null value with the snowflakeType as the Snowflake functions // require the type of the field to be provided explicitly for the null values type TypedNullTime struct { diff --git a/converter_batch.go b/converter_batch.go new file mode 100644 index 000000000..ee0368db7 --- /dev/null +++ b/converter_batch.go @@ -0,0 +1,415 @@ +//go:build !nobatch +// +build !nobatch + +package gosnowflake + +import ( + "context" + "fmt" + "math" + "strings" + "time" + "unicode/utf8" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/compute" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +type snowflakeArrowBatchesTimestampOption int + +const ( + // UseNanosecondTimestamp uses arrow.Timestamp in nanosecond precision, could cause ErrTooHighTimestampPrecision if arrow.Timestamp cannot fit original timestamp values. + UseNanosecondTimestamp snowflakeArrowBatchesTimestampOption = iota + // UseMicrosecondTimestamp uses arrow.Timestamp in microsecond precision + UseMicrosecondTimestamp + // UseMillisecondTimestamp uses arrow.Timestamp in millisecond precision + UseMillisecondTimestamp + // UseSecondTimestamp uses arrow.Timestamp in second precision + UseSecondTimestamp + // UseOriginalTimestamp uses original timestamp struct returned by Snowflake. It can be used in case arrow.Timestamp cannot fit original timestamp values. 
+ UseOriginalTimestamp +) + +// ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time +func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time { + scale := int(rb.scd.RowSet.RowType[colIdx].Scale) + dbType := rb.scd.RowSet.RowType[colIdx].Type + return arrowSnowflakeTimestampToTime(rec.Column(colIdx), getSnowflakeType(dbType), scale, recIdx, rb.loc) +} + +func arrowBatchesUtf8ValidationEnabled(ctx context.Context) bool { + v := ctx.Value(enableArrowBatchesUtf8Validation) + if v == nil { + return false + } + d, ok := v.(bool) + return ok && d +} + +func getArrowBatchesTimestampOption(ctx context.Context) snowflakeArrowBatchesTimestampOption { + v := ctx.Value(arrowBatchesTimestampOption) + if v == nil { + return UseNanosecondTimestamp + } + o, ok := v.(snowflakeArrowBatchesTimestampOption) + if !ok { + return UseNanosecondTimestamp + } + return o +} + +func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) { + arrowBatchesTimestampOption := getArrowBatchesTimestampOption(ctx) + higherPrecisionEnabled := higherPrecisionEnabled(ctx) + + s, err := recordToSchema(record.Schema(), rowType, loc, arrowBatchesTimestampOption, higherPrecisionEnabled) + if err != nil { + return nil, err + } + + var cols []arrow.Array + numRows := record.NumRows() + ctxAlloc := compute.WithAllocator(ctx, pool) + + for i, col := range record.Columns() { + fieldMetadata := rowType[i].toFieldMetadata() + + newCol, err := arrowToRecordSingleColumn(ctxAlloc, s.Field(i), col, fieldMetadata, higherPrecisionEnabled, arrowBatchesTimestampOption, pool, loc, numRows) + if err != nil { + return nil, err + } + cols = append(cols, newCol) + defer newCol.Release() + } + newRecord := array.NewRecord(s, cols, numRows) + return newRecord, nil +} + +func arrowToRecordSingleColumn(ctx context.Context, field arrow.Field, col arrow.Array, fieldMetadata fieldMetadata, higherPrecisionEnabled bool, timestampOption snowflakeArrowBatchesTimestampOption, pool memory.Allocator, loc *time.Location, numRows int64) (arrow.Array, error) { + var err error + newCol := col + snowflakeType := getSnowflakeType(fieldMetadata.Type) + switch snowflakeType { + case fixedType: + if higherPrecisionEnabled { + // do nothing - return decimal as is + col.Retain() + } else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 { + var toType arrow.DataType + if fieldMetadata.Scale == 0 { + toType = arrow.PrimitiveTypes.Int64 + } else { + toType = arrow.PrimitiveTypes.Float64 + } + // we're fine truncating so no error for data loss here. + // so we use UnsafeCastOptions. 
+ newCol, err = compute.CastArray(ctx, col, compute.UnsafeCastOptions(toType)) + if err != nil { + return nil, err + } + } else if fieldMetadata.Scale != 0 && col.DataType().ID() != arrow.INT64 { + result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true}, + &compute.ArrayDatum{Value: newCol.Data()}, + compute.NewDatum(math.Pow10(int(fieldMetadata.Scale)))) + if err != nil { + return nil, err + } + defer result.Release() + newCol = result.(*compute.ArrayDatum).MakeArray() + } else if fieldMetadata.Scale != 0 && col.DataType().ID() == arrow.INT64 { + // gosnowflake driver uses compute.Divide() which could bring `integer value not in range: -9007199254740992 to 9007199254740992` error + // if we convert int64 to BigDecimal and then use compute.CastArray to convert BigDecimal to float64, we won't have enough precision. + // e.g 0.1 as (38,19) will result 0.09999999999999999 + values := col.(*array.Int64).Int64Values() + floatValues := make([]float64, len(values)) + for i, val := range values { + floatValues[i], _ = intToBigFloat(val, int64(fieldMetadata.Scale)).Float64() + } + builder := array.NewFloat64Builder(pool) + builder.AppendValues(floatValues, nil) + newCol = builder.NewArray() + builder.Release() + } else { + col.Retain() + } + case timeType: + newCol, err = compute.CastArray(ctx, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns)) + if err != nil { + return nil, err + } + case timestampNtzType, timestampLtzType, timestampTzType: + if timestampOption == UseOriginalTimestamp { + // do nothing - return timestamp as is + col.Retain() + } else { + var unit arrow.TimeUnit + switch timestampOption { + case UseMicrosecondTimestamp: + unit = arrow.Microsecond + case UseMillisecondTimestamp: + unit = arrow.Millisecond + case UseSecondTimestamp: + unit = arrow.Second + case UseNanosecondTimestamp: + unit = arrow.Nanosecond + } + var tb *array.TimestampBuilder + if snowflakeType == timestampLtzType { + tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit, TimeZone: loc.String()}) + } else { + tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit}) + } + defer tb.Release() + + for i := 0; i < int(numRows); i++ { + ts := arrowSnowflakeTimestampToTime(col, snowflakeType, int(fieldMetadata.Scale), i, loc) + if ts != nil { + var ar arrow.Timestamp + switch timestampOption { + case UseMicrosecondTimestamp: + ar = arrow.Timestamp(ts.UnixMicro()) + case UseMillisecondTimestamp: + ar = arrow.Timestamp(ts.UnixMilli()) + case UseSecondTimestamp: + ar = arrow.Timestamp(ts.Unix()) + case UseNanosecondTimestamp: + ar = arrow.Timestamp(ts.UnixNano()) + // in case of overflow in arrow timestamp return error + // this could only happen for nanosecond case + if ts.UTC().Year() != ar.ToTime(arrow.Nanosecond).Year() { + return nil, &SnowflakeError{ + Number: ErrTooHighTimestampPrecision, + SQLState: SQLStateInvalidDataTimeFormat, + Message: fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. 
Please use context with WithOriginalTimestamp.", ts.UTC(), fieldMetadata.Name), + } + } + } + tb.Append(ar) + } else { + tb.AppendNull() + } + } + newCol = tb.NewArray() + } + case textType: + if stringCol, ok := col.(*array.String); ok { + newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) + } + case objectType: + if structCol, ok := col.(*array.Struct); ok { + var internalCols []arrow.Array + for i := 0; i < structCol.NumField(); i++ { + internalCol := structCol.Field(i) + newInternalCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.StructType).Field(i), internalCol, fieldMetadata.Fields[i], higherPrecisionEnabled, timestampOption, pool, loc, numRows) + if err != nil { + return nil, err + } + internalCols = append(internalCols, newInternalCol) + defer newInternalCol.Release() + } + var fieldNames []string + for _, f := range field.Type.(*arrow.StructType).Fields() { + fieldNames = append(fieldNames, f.Name) + } + nullBitmap := memory.NewBufferBytes(structCol.NullBitmapBytes()) + numberOfNulls := structCol.NullN() + return array.NewStructArrayWithNulls(internalCols, fieldNames, nullBitmap, numberOfNulls, 0) + } else if stringCol, ok := col.(*array.String); ok { + newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) + } + case arrayType: + if listCol, ok := col.(*array.List); ok { + newCol, err = arrowToRecordSingleColumn(ctx, field.Type.(*arrow.ListType).ElemField(), listCol.ListValues(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows) + if err != nil { + return nil, err + } + defer newCol.Release() + newData := array.NewData(arrow.ListOf(newCol.DataType()), listCol.Len(), listCol.Data().Buffers(), []arrow.ArrayData{newCol.Data()}, listCol.NullN(), 0) + defer newData.Release() + return array.NewListData(newData), nil + } else if stringCol, ok := col.(*array.String); ok { + newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) + } + case mapType: + if mapCol, ok := col.(*array.Map); ok { + keyCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).KeyField(), mapCol.Keys(), fieldMetadata.Fields[0], higherPrecisionEnabled, timestampOption, pool, loc, numRows) + if err != nil { + return nil, err + } + defer keyCol.Release() + valueCol, err := arrowToRecordSingleColumn(ctx, field.Type.(*arrow.MapType).ItemField(), mapCol.Items(), fieldMetadata.Fields[1], higherPrecisionEnabled, timestampOption, pool, loc, numRows) + if err != nil { + return nil, err + } + defer valueCol.Release() + + structArr, err := array.NewStructArray([]arrow.Array{keyCol, valueCol}, []string{"k", "v"}) + if err != nil { + return nil, err + } + defer structArr.Release() + newData := array.NewData(arrow.MapOf(keyCol.DataType(), valueCol.DataType()), mapCol.Len(), mapCol.Data().Buffers(), []arrow.ArrayData{structArr.Data()}, mapCol.NullN(), 0) + defer newData.Release() + return array.NewMapData(newData), nil + } else if stringCol, ok := col.(*array.String); ok { + newCol = arrowStringRecordToColumn(ctx, stringCol, pool, numRows, fieldMetadata) + } + default: + col.Retain() + } + return newCol, nil +} + +// returns an arrow array which will be new and populated if we converted the array to valid utf8, +// or the original column if we did not convert it.
+func arrowStringRecordToColumn( + ctx context.Context, + stringCol *array.String, + mem memory.Allocator, + numRows int64, + fieldMetadata fieldMetadata, +) arrow.Array { + if arrowBatchesUtf8ValidationEnabled(ctx) && stringCol.DataType().ID() == arrow.STRING { + tb := array.NewStringBuilder(mem) + defer tb.Release() + + for i := 0; i < int(numRows); i++ { + if stringCol.IsValid(i) { + stringValue := stringCol.Value(i) + if !utf8.ValidString(stringValue) { + logger.WithContext(ctx).Error("Invalid UTF-8 characters detected while reading query response, column: ", fieldMetadata.Name) + stringValue = strings.ToValidUTF8(stringValue, "�") + } + tb.Append(stringValue) + } else { + tb.AppendNull() + } + } + arr := tb.NewArray() + return arr + } + stringCol.Retain() + return stringCol +} + +func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) (*arrow.Schema, error) { + fields := recordToSchemaRecursive(sc.Fields(), rowType, loc, timestampOption, withHigherPrecision) + meta := sc.Metadata() + return arrow.NewSchema(fields, &meta), nil +} + +func recordToSchemaRecursive(inFields []arrow.Field, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) []arrow.Field { + var outFields []arrow.Field + for i, f := range inFields { + fieldMetadata := rowType[i].toFieldMetadata() + converted, t := recordToSchemaSingleField(fieldMetadata, f, withHigherPrecision, timestampOption, loc) + + newField := f + if converted { + newField = arrow.Field{ + Name: f.Name, + Type: t, + Nullable: f.Nullable, + Metadata: f.Metadata, + } + } + outFields = append(outFields, newField) + } + return outFields +} + +func recordToSchemaSingleField(fieldMetadata fieldMetadata, f arrow.Field, withHigherPrecision bool, timestampOption snowflakeArrowBatchesTimestampOption, loc *time.Location) (bool, arrow.DataType) { + t := f.Type + converted := true + switch getSnowflakeType(fieldMetadata.Type) { + case fixedType: + switch f.Type.ID() { + case arrow.DECIMAL: + if withHigherPrecision { + converted = false + } else if fieldMetadata.Scale == 0 { + t = &arrow.Int64Type{} + } else { + t = &arrow.Float64Type{} + } + default: + if withHigherPrecision { + converted = false + } else if fieldMetadata.Scale != 0 { + t = &arrow.Float64Type{} + } else { + converted = false + } + } + case timeType: + t = &arrow.Time64Type{Unit: arrow.Nanosecond} + case timestampNtzType, timestampTzType: + if timestampOption == UseOriginalTimestamp { + // do nothing - return timestamp as is + converted = false + } else if timestampOption == UseMicrosecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Microsecond} + } else if timestampOption == UseMillisecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Millisecond} + } else if timestampOption == UseSecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Second} + } else { + t = &arrow.TimestampType{Unit: arrow.Nanosecond} + } + case timestampLtzType: + if timestampOption == UseOriginalTimestamp { + // do nothing - return timestamp as is + converted = false + } else if timestampOption == UseMicrosecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: loc.String()} + } else if timestampOption == UseMillisecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: loc.String()} + } else if timestampOption == UseSecondTimestamp { + t = &arrow.TimestampType{Unit: arrow.Second, 
TimeZone: loc.String()} + } else { + t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()} + } + case objectType: + converted = false + if f.Type.ID() == arrow.STRUCT { + var internalFields []arrow.Field + for idx, internalField := range f.Type.(*arrow.StructType).Fields() { + internalConverted, convertedDataType := recordToSchemaSingleField(fieldMetadata.Fields[idx], internalField, withHigherPrecision, timestampOption, loc) + converted = converted || internalConverted + if internalConverted { + newInternalField := arrow.Field{ + Name: internalField.Name, + Type: convertedDataType, + Metadata: internalField.Metadata, + Nullable: internalField.Nullable, + } + internalFields = append(internalFields, newInternalField) + } else { + internalFields = append(internalFields, internalField) + } + } + t = arrow.StructOf(internalFields...) + } + case arrayType: + if _, ok := f.Type.(*arrow.ListType); ok { + converted, dataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.ListType).ElemField(), withHigherPrecision, timestampOption, loc) + if converted { + t = arrow.ListOf(dataType) + } + } else { + t = f.Type + } + case mapType: + convertedKey, keyDataType := recordToSchemaSingleField(fieldMetadata.Fields[0], f.Type.(*arrow.MapType).KeyField(), withHigherPrecision, timestampOption, loc) + convertedValue, valueDataType := recordToSchemaSingleField(fieldMetadata.Fields[1], f.Type.(*arrow.MapType).ItemField(), withHigherPrecision, timestampOption, loc) + converted = convertedKey || convertedValue + if converted { + t = arrow.MapOf(keyDataType, valueDataType) + } + default: + converted = false + } + return converted, t +}
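Editor's note on usage: the timestamp options and context switches moved into converter_batch.go are selected per query via context values. The sketch below shows how a consumer of this build might fetch Arrow batches with microsecond timestamps; it relies on the driver's documented public helpers (WithArrowBatches, WithArrowBatchesTimestampOption, SnowflakeRows.GetArrowBatches, ArrowBatch.Fetch), none of which are introduced by this diff, so treat it as an assumed-API sketch rather than code from this change. Because GetArrowBatches needs the driver-level rows, callers typically reach driver.QueryerContext through (*sql.Conn).Raw.

// Usage sketch (assumed public API, not part of this diff): read a result set
// as Arrow batches with microsecond timestamps instead of the default
// nanosecond unit, which can overflow arrow.Timestamp for distant dates.
package example

import (
	"context"
	"database/sql/driver"
	"log"

	sf "github.com/snowflakedb/gosnowflake"
)

func fetchArrowBatches(conn driver.QueryerContext, query string) error {
	ctx := sf.WithArrowBatchesTimestampOption(
		sf.WithArrowBatches(context.Background()),
		sf.UseMicrosecondTimestamp,
	)

	rows, err := conn.QueryContext(ctx, query, nil)
	if err != nil {
		return err
	}
	defer rows.Close()

	batches, err := rows.(sf.SnowflakeRows).GetArrowBatches()
	if err != nil {
		return err
	}
	for _, batch := range batches {
		// Fetch downloads the chunk and runs each record through arrowToRecord.
		records, err := batch.Fetch()
		if err != nil {
			return err
		}
		for _, rec := range *records {
			log.Printf("got %d rows, %d columns", rec.NumRows(), rec.NumCols())
			rec.Release()
		}
	}
	return nil
}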
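ArrowBatch.ArrowSnowflakeTimestampToTime, added in this file, is the escape hatch for the UseOriginalTimestamp path: the record keeps Snowflake's native struct encoding and each cell is converted on demand. A minimal sketch, assuming a batch and record obtained as in the previous example and that column 0 is a TIMESTAMP_NTZ/LTZ/TZ column:

// Sketch: with sf.UseOriginalTimestamp the timestamp column stays in Snowflake's
// original struct form; convert individual cells lazily.
package example

import (
	"time"

	"github.com/apache/arrow-go/v18/arrow"
	sf "github.com/snowflakedb/gosnowflake"
)

// firstColumnTimestamps converts every non-null value in column 0 of rec.
func firstColumnTimestamps(batch *sf.ArrowBatch, rec arrow.Record) []time.Time {
	out := make([]time.Time, 0, rec.NumRows())
	for row := 0; row < int(rec.NumRows()); row++ {
		if ts := batch.ArrowSnowflakeTimestampToTime(rec, 0, row); ts != nil {
			out = append(out, *ts)
		}
	}
	return out
}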
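The UTF-8 validation path in arrowStringRecordToColumn only rebuilds a string column when validation is enabled on the context; otherwise the original column is retained and returned. The per-value sanitation step is just the standard library's strings.ToValidUTF8, as in this stand-alone sketch (the helper name is illustrative, not from the driver):

// sanitizeUTF8 mirrors the per-value check in arrowStringRecordToColumn:
// valid UTF-8 passes through unchanged, invalid bytes become U+FFFD.
package example

import (
	"strings"
	"unicode/utf8"
)

func sanitizeUTF8(values []string) []string {
	out := make([]string, len(values))
	for i, v := range values {
		if utf8.ValidString(v) {
			out[i] = v
			continue
		}
		out[i] = strings.ToValidUTF8(v, "\uFFFD")
	}
	return out
}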
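For scaled NUMBER columns delivered as plain int64 values, the fixedType branch above deliberately avoids compute.Divide and the Decimal-to-float cast, scaling instead through the driver's internal intToBigFloat helper before the final float64 conversion. The function below is a local stand-in written only to illustrate that idea; it is not the driver's implementation.

// fixedToFloat64 scales an int64 fixed-point value (NUMBER(p,s) with s > 0)
// to float64 via math/big, the same approach the new fixedType branch takes
// with the driver's internal intToBigFloat helper.
package example

import (
	"math"
	"math/big"
)

func fixedToFloat64(val int64, scale int) float64 {
	f := new(big.Float).SetInt64(val)
	divisor := new(big.Float).SetFloat64(math.Pow10(scale))
	result, _ := new(big.Float).Quo(f, divisor).Float64()
	return result
}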