diff --git a/predict.py b/predict.py index 5b2042be..5d130a7b 100644 --- a/predict.py +++ b/predict.py @@ -10,6 +10,12 @@ import pandas as pd from tqdm import tqdm +import tensorflow as tf +gpus = tf.config.experimental.list_physical_devices('GPU') +# Currently, memory growth needs to be the same across GPUs +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)]) def make_prediction(args): diff --git a/train.py b/train.py index 63608c4f..b61a88b3 100644 --- a/train.py +++ b/train.py @@ -14,6 +14,11 @@ import argparse import warnings +gpus = tf.config.experimental.list_physical_devices('GPU') +# Currently, memory growth needs to be the same across GPUs +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)]) class DataGenerator(tf.keras.utils.Sequence): def __init__(self, wav_paths, labels, sr, dt, n_classes,