Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit bb239fa

Browse files
committed
Using tf.nn.rnn* instead of direct import
1 parent 9bea801 commit bb239fa

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

skflow/models.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from __future__ import division, print_function, absolute_import
1717

1818
import tensorflow as tf
19-
from tensorflow.models.rnn import rnn, rnn_cell
2019

2120
from skflow.ops import mean_squared_error_regressor, softmax_classifier, dnn
2221

@@ -133,16 +132,16 @@ def rnn_estimator(X, y):
133132
raise ValueError("cell_type {} is not supported. ".format(cell_type))
134133
if bidirection:
135134
# forward direction cell
136-
rnn_fw_cell = rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
135+
rnn_fw_cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
137136
# backward direction cell
138-
rnn_bw_cell = rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
137+
rnn_bw_cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
139138
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
140-
encoding = rnn.bidirectional_rnn(rnn_fw_cell, rnn_bw_cell,
139+
encoding = tf.nn.rnn.bidirectional_rnn(rnn_fw_cell, rnn_bw_cell,
141140
sequence_length=sequence_length,
142141
initial_state=initial_state)
143142
else:
144-
cell = rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
145-
_, encoding = rnn.rnn(cell, X, dtype=tf.float32,
143+
cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
144+
_, encoding = tf.nn.rnn.rnn(cell, X, dtype=tf.float32,
146145
sequence_length=sequence_length,
147146
initial_state=initial_state)
148147
return target_predictor_fn(encoding[-1], y)

0 commit comments

Comments
 (0)