@@ -39,7 +39,7 @@ use crate::store::StorageContexts;
39
39
use crate :: udaf:: PyAggregateUDF ;
40
40
use crate :: udf:: PyScalarUDF ;
41
41
use crate :: utils:: { get_tokio_runtime, wait_for_future} ;
42
- use datafusion:: arrow:: datatypes:: { DataType , Schema } ;
42
+ use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
43
43
use datafusion:: arrow:: pyarrow:: PyArrowType ;
44
44
use datafusion:: arrow:: record_batch:: RecordBatch ;
45
45
use datafusion:: datasource:: file_format:: file_compression_type:: FileCompressionType ;
@@ -344,9 +344,15 @@ impl PySessionContext {
344
344
& mut self ,
345
345
partitions : PyArrowType < Vec < Vec < RecordBatch > > > ,
346
346
name : Option < & str > ,
347
+ schema : Option < PyArrowType < Schema > > ,
347
348
py : Python ,
348
349
) -> PyResult < PyDataFrame > {
349
- let schema = partitions. 0 [ 0 ] [ 0 ] . schema ( ) ;
350
+ let schema = if let Some ( schema) = schema {
351
+ SchemaRef :: from ( schema. 0 )
352
+ } else {
353
+ partitions. 0 [ 0 ] [ 0 ] . schema ( )
354
+ } ;
355
+
350
356
let table = MemTable :: try_new ( schema, partitions. 0 ) . map_err ( DataFusionError :: from) ?;
351
357
352
358
// generate a random (unique) name for this table if none is provided
@@ -428,12 +434,15 @@ impl PySessionContext {
428
434
// Instantiate pyarrow Table object & convert to batches
429
435
let table = data. call_method0 ( py, "to_batches" ) ?;
430
436
437
+ let schema = data. getattr ( py, "schema" ) ?;
438
+ let schema = schema. extract :: < PyArrowType < Schema > > ( py) ?;
439
+
431
440
// Cast PyObject to RecordBatch type
432
441
// Because create_dataframe() expects a vector of vectors of record batches
433
442
// here we need to wrap the vector of record batches in an additional vector
434
443
let batches = table. extract :: < PyArrowType < Vec < RecordBatch > > > ( py) ?;
435
444
let list_of_batches = PyArrowType :: from ( vec ! [ batches. 0 ] ) ;
436
- self . create_dataframe ( list_of_batches, name, py)
445
+ self . create_dataframe ( list_of_batches, name, Some ( schema ) , py)
437
446
} )
438
447
}
439
448
0 commit comments