@@ -42,7 +42,12 @@ database=""
4242database="{{.Database}}"
4343{{end}}
4444
45- conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
45+ session_cfg = {}
46+ {{ range $k, $v := .Session }}
47+ session_cfg["{{$k}}"] = "{{$v}}"
48+ {{end}}
49+
50+ conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}",session_cfg=session_cfg)
4651
4752feature_column_names = [{{range .X}}
4853"{{.FeatureName}}",
@@ -70,11 +75,6 @@ feature_metas["{{$value.FeatureName}}"] = {
7075}
7176{{end}}
7277
73- session_cfg = {}
74- {{ range $k, $v := .Session }}
75- session_cfg["{{$k}}"] = "{{$v}}"
76- {{end}}
77-
7878def get_dtype(type_str):
7979 if type_str == "float32":
8080 return tf.float32
@@ -104,7 +104,7 @@ def input_fn(datasetStr):
104104 else:
105105 feature_types.append(get_dtype(feature_metas[name]["dtype"]))
106106
107- gen = db_generator(driver, conn, session_cfg, datasetStr, feature_column_names, "{{.Y.FeatureName}}", feature_metas)
107+ gen = db_generator(driver, conn, datasetStr, feature_column_names, "{{.Y.FeatureName}}", feature_metas)
108108 dataset = tf.data.Dataset.from_generator(gen, (tuple(feature_types), tf.{{.Y.Dtype}}))
109109 ds_mapper = functools.partial(_parse_sparse_feature, feature_metas=feature_metas)
110110 return dataset.map(ds_mapper)
@@ -169,7 +169,12 @@ database="{{.Database}}"
169169database=""
170170{{end}}
171171
172- conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
172+ session_cfg = {}
173+ {{ range $k, $v := .Session }}
174+ session_cfg["{{$k}}"] = "{{$v}}"
175+ {{end}}
176+
177+ conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}",session_cfg=session_cfg)
173178
174179feature_column_names = [{{range .X}}
175180"{{.FeatureName}}",
@@ -197,11 +202,6 @@ feature_metas["{{$value.FeatureName}}"] = {
197202}
198203{{end}}
199204
200- session_cfg = {}
201- {{ range $k, $v := .Session }}
202- session_cfg["{{$k}}"] = "{{$v}}"
203- {{end}}
204-
205205def get_dtype(type_str):
206206 if type_str == "float32":
207207 return tf.float32
@@ -232,7 +232,7 @@ def eval_input_fn(batch_size):
232232 else:
233233 feature_types.append(get_dtype(feature_metas[name]["dtype"]))
234234
235- gen = db_generator(driver, conn, session_cfg, """{{.PredictionDatasetSQL}}""",
235+ gen = db_generator(driver, conn, """{{.PredictionDatasetSQL}}""",
236236 feature_column_names, "{{.Y.FeatureName}}", feature_metas)
237237 dataset = tf.data.Dataset.from_generator(gen, (tuple(feature_types), tf.{{.Y.Dtype}}))
238238 ds_mapper = functools.partial(_parse_sparse_feature, feature_metas=feature_metas)
@@ -322,7 +322,7 @@ class FastPredict:
322322
323323column_names = feature_column_names[:]
324324column_names.append("{{.Y.FeatureName}}")
325- pred_gen = db_generator(driver, conn, session_cfg, """{{.PredictionDatasetSQL}}""", feature_column_names, "{{.Y.FeatureName}}", feature_metas)()
325+ pred_gen = db_generator(driver, conn, """{{.PredictionDatasetSQL}}""", feature_column_names, "{{.Y.FeatureName}}", feature_metas)()
326326fast_predictor = FastPredict(classifier, fast_input_fn)
327327
328328with buffered_db_writer(driver, conn, "{{.TableName}}", column_names, 100) as w:
0 commit comments