@@ -8,33 +8,36 @@ import (
88 "sqlflow.org/gohive/hiveserver2"
99)
1010
11- // Options for opened Hive sessions.
12- type Options struct {
11+ // hiveOptions for opened Hive sessions.
12+ type hiveOptions struct {
1313 PollIntervalSeconds int64
1414 BatchSize int64
1515}
1616
17- type Connection struct {
17+ type hiveConnection struct {
1818 thrift * hiveserver2.TCLIServiceClient
1919 session * hiveserver2.TSessionHandle
20- options Options
20+ options hiveOptions
2121}
2222
23- func (c * Connection ) Begin () (driver.Tx , error ) {
23+ func (c * hiveConnection ) Begin () (driver.Tx , error ) {
2424 return nil , nil
2525}
2626
27- func (c * Connection ) Prepare (query string ) (driver.Stmt , error ) {
28- return nil , nil
27+ func (c * hiveConnection ) Prepare (qry string ) (driver.Stmt , error ) {
28+ if ! c .isOpen () {
29+ return nil , fmt .Errorf ("driver: bad connection" )
30+ }
31+ return & hiveStmt {hc : c , query : qry }, nil
2932}
3033
31- func (c * Connection ) isOpen () bool {
34+ func (c * hiveConnection ) isOpen () bool {
3235 return c .session != nil
3336}
3437
3538// As hiveserver2 thrift api does not provide Ping method,
3639// we use GetInfo instead to check the health of hiveserver2.
37- func (c * Connection ) Ping (ctx context.Context ) (err error ) {
40+ func (c * hiveConnection ) Ping (ctx context.Context ) (err error ) {
3841 getInfoReq := hiveserver2 .NewTGetInfoReq ()
3942 getInfoReq .SessionHandle = c .session
4043 getInfoReq .InfoType = hiveserver2 .TGetInfoType_CLI_SERVER_NAME
@@ -52,7 +55,7 @@ func (c *Connection) Ping(ctx context.Context) (err error) {
5255 return nil
5356}
5457
55- func (c * Connection ) Close () error {
58+ func (c * hiveConnection ) Close () error {
5659 if c .isOpen () {
5760 closeReq := hiveserver2 .NewTCloseSessionReq ()
5861 closeReq .SessionHandle = c .session
@@ -74,7 +77,7 @@ func removeLastSemicolon(s string) string {
7477 return s
7578}
7679
77- func (c * Connection ) execute (ctx context.Context , query string , args []driver.NamedValue ) (* hiveserver2.TExecuteStatementResp , error ) {
80+ func (c * hiveConnection ) execute (ctx context.Context , query string , args []driver.NamedValue ) (* hiveserver2.TExecuteStatementResp , error ) {
7881 executeReq := hiveserver2 .NewTExecuteStatementReq ()
7982 executeReq .SessionHandle = c .session
8083 executeReq .Statement = removeLastSemicolon (query )
@@ -90,15 +93,15 @@ func (c *Connection) execute(ctx context.Context, query string, args []driver.Na
9093 return resp , nil
9194}
9295
93- func (c * Connection ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
96+ func (c * hiveConnection ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
9497 resp , err := c .execute (ctx , query , args )
9598 if err != nil {
9699 return nil , err
97100 }
98101 return newRows (c .thrift , resp .OperationHandle , c .options ), nil
99102}
100103
101- func (c * Connection ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
104+ func (c * hiveConnection ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
102105 resp , err := c .execute (ctx , query , args )
103106 if err != nil {
104107 return nil , err
0 commit comments