Skip to content

Commit 65556f4

Browse files
bani-intelaipgsayantan-nervana
authored andcommitted
merge NGTF-2442 & NGTF-2515 into r0.19 (#455)
1 parent 150554d commit 65556f4

File tree

2 files changed

+116
-11
lines changed

2 files changed

+116
-11
lines changed

ngraph_bridge/ngraph_utils.h

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,32 +128,34 @@ Status ValuesFromConstNode(const NodeDef& node,
128128
n_elements *= shape.dim(i).size();
129129
}
130130
values->resize(n_elements);
131+
auto val_lastsaved = (T)0; // cast
131132
for (auto i = 0; i < n_elements; i++) {
132133
auto& tensor = node.attr().at("value").tensor();
133134
auto dt = node.attr().at("dtype").type();
135+
int64 val_size = 0;
136+
auto val_i = (T)0; // cast
134137
switch (dt) {
135138
// TODO(amprocte/NGRAPH-2502): there are more element types to support
136139
// here
137140
case DT_INT32:
138-
(*values)[i] = (tensor.int_val_size() == 1 ? tensor.int_val()[0]
139-
: tensor.int_val()[i]);
141+
val_size = tensor.int_val_size();
142+
if (val_size > i) val_i = tensor.int_val()[i];
140143
break;
141144
case DT_INT64:
142-
(*values)[i] = (tensor.int64_val_size() == 1 ? tensor.int64_val()[0]
143-
: tensor.int64_val()[i]);
145+
val_size = tensor.int64_val_size();
146+
if (val_size > i) val_i = tensor.int64_val()[i];
144147
break;
145148
case DT_FLOAT:
146-
(*values)[i] = (tensor.float_val_size() == 1 ? tensor.float_val()[0]
147-
: tensor.float_val()[i]);
149+
val_size = tensor.float_val_size();
150+
if (val_size > i) val_i = tensor.float_val()[i];
148151
break;
149152
case DT_BOOL:
150-
(*values)[i] = (tensor.bool_val_size() == 1 ? tensor.bool_val()[0]
151-
: tensor.bool_val()[i]);
153+
val_size = tensor.bool_val_size();
154+
if (val_size > i) val_i = tensor.bool_val()[i];
152155
break;
153156
case DT_DOUBLE:
154-
(*values)[i] =
155-
(tensor.double_val_size() == 1 ? tensor.double_val()[0]
156-
: tensor.double_val()[i]);
157+
val_size = tensor.double_val_size();
158+
if (val_size > i) val_i = tensor.double_val()[i];
157159
break;
158160
default:
159161
NGRAPH_VLOG(0)
@@ -165,6 +167,14 @@ Status ValuesFromConstNode(const NodeDef& node,
165167
DataType_Name(dt),
166168
" on an empty tensor");
167169
}
170+
if (val_size == 0) {
171+
return errors::InvalidArgument("Empty values vector");
172+
} else if (i < val_size) {
173+
(*values)[i] = val_i;
174+
val_lastsaved = val_i;
175+
} else {
176+
(*values)[i] = val_lastsaved;
177+
}
168178
}
169179
} else {
170180
values->resize(tensor_content_size / sizeof(VecT));

test/python/test_const.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# ==============================================================================
2+
# Copyright 2018-2019 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""nGraph TensorFlow bridge Const operation test
17+
18+
"""
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import pytest
24+
25+
import tensorflow as tf
26+
import os
27+
28+
from common import NgraphTest
29+
30+
# Uncomment for debugging; Also add -s in command like, e.g.
31+
# (venv-tf-py3) [build_cmake]$
32+
# NGRAPH_TF_LOG_PLACEMENT=1 NGRAPH_TF_VLOG_LEVEL=6 pytest -s -k test_const_scalarval ../test/python/test_const.py
33+
import logging
34+
logging.basicConfig(level=logging.DEBUG)
35+
36+
37+
class TestConstOperations(NgraphTest):
38+
39+
def test_const_listvals(self):
40+
zz = tf.constant([1, 2, 3, 4, 5, 6], dtype=float, shape=[2, 3])
41+
42+
def run_test(sess):
43+
return sess.run(zz)
44+
45+
assert (
46+
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()
47+
48+
def test_const_listvals_2(self):
49+
zz = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=float, shape=[2, 3])
50+
51+
def run_test(sess):
52+
return sess.run(zz)
53+
54+
assert (
55+
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()
56+
57+
def test_const_scalarval(self):
58+
zz = tf.constant(-3, dtype=float, shape=[2, 3])
59+
60+
def run_test(sess):
61+
return sess.run(zz)
62+
63+
assert (
64+
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()
65+
66+
def test_const_lastfill(self):
67+
zz = tf.constant([1, 2], dtype=float, shape=[2, 3])
68+
69+
def run_test(sess):
70+
return sess.run(zz)
71+
72+
assert (
73+
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()
74+
75+
def test_const_empty(self):
76+
log = logging.getLogger('test_const_empty')
77+
zz = tf.constant([], dtype=float, shape=[2, 3])
78+
79+
def run_test(sess):
80+
log.debug('Invoking sess.run(zz)')
81+
return sess.run(zz)
82+
83+
# Ideally we want same behavior for both TF & NG, but for now we are deviating,
84+
# NGraph will throw error, but TF will fill in zeros
85+
# assert (
86+
# self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()
87+
88+
# Test to see that exception is raised in NG
89+
try:
90+
# This test is expected to fail currently
91+
res = self.with_ngraph(run_test)
92+
assert False, 'Failed, expected test to raise error'
93+
except:
94+
log.debug('Passed, expected NG to raise error...')
95+
assert True

0 commit comments

Comments
 (0)