From 8cd3b876d825f2c373a0d785c20821c77f13a0aa Mon Sep 17 00:00:00 2001 From: Katsuya Ishiyama Date: Fri, 16 Feb 2018 02:12:10 +0900 Subject: [PATCH] fix import error --- i2v/chainer_i2v.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/i2v/chainer_i2v.py b/i2v/chainer_i2v.py index 30ee815f..388500ce 100644 --- a/i2v/chainer_i2v.py +++ b/i2v/chainer_i2v.py @@ -2,11 +2,16 @@ import json import warnings import numpy as np +from distutils.version import StrictVersion from scipy.ndimage import zoom from skimage.transform import resize +import chainer from chainer import Variable from chainer.functions import average_pooling_2d, sigmoid -from chainer.functions.caffe import CaffeFunction +try: + from chainer.functions.caffe import CaffeFunction +except: + from chainer.links.caffe import CaffeFunction class ChainerI2V(Illustration2VecBase): @@ -47,7 +52,14 @@ def _forward(self, inputs, layername): input_ -= self.mean # subtract mean input_ = input_.transpose((0, 3, 1, 2)) # (N, H, W, C) -> (N, C, H, W) x = Variable(input_) - y, = self.net(inputs={'data': x}, outputs=[layername], train=False) + + # train argument is not supported from Ver2. + if StrictVersion(chainer.__version__) < StrictVersion('2.0.0'): + y, = self.net(inputs={'data': x}, outputs=[layername], train=False) + else: + chainer.using_config('train', False) + y, = self.net(inputs={'data': x}, outputs=[layername]) + return y def _extract(self, inputs, layername):