diff --git a/build.sh b/build.sh index 72a6621..af6be1e 100755 --- a/build.sh +++ b/build.sh @@ -1,9 +1,11 @@ #!/usr/bin/env bash echo Building Native ops... TF_INC=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') +TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) +TF_LFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) # TODO: GPU support #nvcc -std=c++11 -c -o count_sketch.cu.o count_sketch.cu.cc -I $TF_INC -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -D_MWAITXINTRIN_H_INCLUDED -g #g++ -std=c++11 -shared -o count_sketch.so count_sketch.cc count_sketch.cu.o -fPIC -lcudart -I $TF_INC -D_GLIBCXX_USE_CXX11_ABI=0 -g mkdir -p build -g++ -std=c++11 -shared -o build/count_sketch.so ops/count_sketch.cc -fPIC -I $TF_INC -D_GLIBCXX_USE_CXX11_ABI=0 \ No newline at end of file +g++ -std=c++11 -shared -o build/count_sketch.so ops/count_sketch.cc -fPIC -I $TF_INC -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 diff --git a/count_sketch.py b/count_sketch.py index 33da810..4970307 100644 --- a/count_sketch.py +++ b/count_sketch.py @@ -1,6 +1,10 @@ import tensorflow as tf +import pkg_resources -_sketch_op = tf.load_op_library('./build/count_sketch.so') +path = "./build/count_sketch.so" +filepath = pkg_resources.resource_filename(__name__, path) + +_sketch_op = tf.load_op_library(filepath) def count_sketch(probs, project_size): """ Calculates count-min sketch of a tensor. @@ -53,5 +57,5 @@ def bilinear_pool(x1, x2, output_size): pc1 = tf.complex(p1, tf.zeros_like(p1)) pc2 = tf.complex(p2, tf.zeros_like(p2)) - conved = tf.batch_ifft(tf.batch_fft(pc1) * tf.batch_fft(pc2)) + conved = tf.ifft(tf.fft(pc1) * tf.fft(pc2)) return tf.real(conved)