Commit 7204a35

bugfix: no panic on empty table (apache#613)
1 parent 6e570e2 commit 7204a35

2 files changed: +43 -3 lines changed
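For context, this is the user-visible failure being fixed (a minimal sketch, assuming the SessionContext entry point of the datafusion Python bindings; the ctx fixture in the tests below plays the same role):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# An empty pyarrow table converts to zero record batches, so the old
# Rust code, which read the schema from the first batch, panicked here.
table = pa.Table.from_pydict({"a": [], "b": []})
df = ctx.from_arrow_table(table)  # panicked before this fix; now an empty DataFrame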

datafusion/tests/test_context.py

Lines changed: 31 additions & 0 deletions

@@ -139,6 +139,37 @@ def test_from_arrow_table_with_name(ctx):
     assert tables[0] == "tbl"


+def test_from_arrow_table_empty(ctx):
+    data = {"a": [], "b": []}
+    schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
+    table = pa.Table.from_pydict(data, schema=schema)
+
+    # convert to DataFrame
+    df = ctx.from_arrow_table(table)
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert isinstance(df, DataFrame)
+    assert set(df.schema().names) == {"a", "b"}
+    assert len(df.collect()) == 0
+
+
+def test_from_arrow_table_empty_no_schema(ctx):
+    data = {"a": [], "b": []}
+    table = pa.Table.from_pydict(data)
+
+    # convert to DataFrame
+    df = ctx.from_arrow_table(table)
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert isinstance(df, DataFrame)
+    assert set(df.schema().names) == {"a", "b"}
+    assert len(df.collect()) == 0
+
+
 def test_from_pylist(ctx):
     # create a dataframe from Python list
     data = [
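The second test exercises the path with no explicit schema, where the schema attached to the DataFrame is whatever pyarrow inferred for the empty columns. A small illustration of the difference (plain pyarrow, independent of this repo; the printed types are what current pyarrow produces and are worth verifying against your version):

import pyarrow as pa

# With an explicit schema, the empty columns keep their declared types.
explicit = pa.Table.from_pydict(
    {"a": [], "b": []},
    schema=pa.schema([("a", pa.int32()), ("b", pa.string())]),
)
print(explicit.schema)  # a: int32, b: string

# Without one, pyarrow must infer types from empty lists,
# which come through as the null type.
inferred = pa.Table.from_pydict({"a": [], "b": []})
print(inferred.schema)  # a: null, b: null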

src/context.rs

Lines changed: 12 additions & 3 deletions

@@ -39,7 +39,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::utils::{get_tokio_runtime, wait_for_future};
-use datafusion::arrow::datatypes::{DataType, Schema};
+use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -344,9 +344,15 @@ impl PySessionContext {
         &mut self,
         partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
         name: Option<&str>,
+        schema: Option<PyArrowType<Schema>>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
-        let schema = partitions.0[0][0].schema();
+        let schema = if let Some(schema) = schema {
+            SchemaRef::from(schema.0)
+        } else {
+            partitions.0[0][0].schema()
+        };
+
         let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;

         // generate a random (unique) name for this table if none is provided
@@ -428,12 +434,15 @@ impl PySessionContext {
            // Instantiate pyarrow Table object & convert to batches
            let table = data.call_method0(py, "to_batches")?;

+           let schema = data.getattr(py, "schema")?;
+           let schema = schema.extract::<PyArrowType<Schema>>(py)?;
+
            // Cast PyObject to RecordBatch type
            // Because create_dataframe() expects a vector of vectors of record batches
            // here we need to wrap the vector of record batches in an additional vector
            let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
            let list_of_batches = PyArrowType::from(vec![batches.0]);
-           self.create_dataframe(list_of_batches, name, py)
+           self.create_dataframe(list_of_batches, name, Some(schema), py)
         })
     }
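The root cause is visible in the last hunk: from_arrow_table builds its partitions from table.to_batches(), and an empty pyarrow table produces no batches at all, so the old partitions.0[0][0].schema() indexed into an empty Vec and panicked. The fix threads the table's own schema through create_dataframe instead of reading it from the first batch. A quick check of the empty-batches behavior (plain pyarrow; output as observed with the pyarrow versions this bug was reported against):

import pyarrow as pa

table = pa.Table.from_pydict({"a": [], "b": []})
print(table.to_batches())  # [] -- an empty table yields no record batches
print(table.schema)        # the schema is still available on the table itself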
