-
Notifications
You must be signed in to change notification settings - Fork 192
Open
Description
from tensorflow.keras.models import load_model
from clean import downsample_mono, envelope
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
import numpy as np
import os
import argparse
def _split_into_windows(signal, step):
    """Cut a 1-D signal into consecutive ``step``-sample windows.

    The last window is zero-padded on the right when the signal length is
    not a multiple of ``step``.

    Returns:
        A float32 array of shape ``(n_windows, step, 1)``.
    """
    flat = np.asarray(signal, dtype=np.float32).reshape(-1)
    leftover = flat.shape[0] % step
    if leftover:
        # np.pad's default mode is 'constant' with zeros.
        flat = np.pad(flat, (0, step - leftover))
    return flat.reshape(-1, step, 1)


def make_prediction(args):
    """Classify one WAV file with a saved Keras model and store the scores.

    The audio is resampled via ``downsample_mono``, samples rejected by
    ``envelope`` are dropped, and what remains is cut into windows of
    ``args.sr * args.dt`` samples. The model scores every window; the
    per-window scores are averaged, the arg-max class is printed, and the
    averaged score vector is saved with ``np.save`` under ``logs/``.

    Args:
        args: Namespace providing ``model_fn`` (model file), ``wav_fn``
            (audio file), ``sr`` (sample rate), ``dt`` (window length in
            seconds), ``threshold`` (passed to ``envelope``) and
            ``pred_fn`` (output file name inside ``logs``).

    Raises:
        FileNotFoundError: If ``args.wav_fn`` does not exist.
        ValueError: If the window length is not positive, or if no samples
            are left after applying the envelope mask.
    """
    wav_fn = args.wav_fn
    # Fail fast on a missing input before paying for the model load.
    if not os.path.exists(wav_fn):
        raise FileNotFoundError(f"File {wav_fn} tidak ditemukan.")

    step = int(args.sr * args.dt)
    if step <= 0:
        raise ValueError(
            f"Window length must be positive, got {step} samples "
            f"(sr={args.sr}, dt={args.dt})."
        )

    # The kapre layers are custom Keras layers, so they have to be
    # registered explicitly for deserialization.
    model = load_model(args.model_fn,
                       custom_objects={'STFT': STFT,
                                       'Magnitude': Magnitude,
                                       'ApplyFilterbank': ApplyFilterbank,
                                       'MagnitudeToDecibel': MagnitudeToDecibel})
    # Index i of the model output is reported as classes[i].
    classes = ['C', 'D', 'E', 'F', 'L', 'O', 'P', 'T', 'Z']

    rate, wav = downsample_mono(wav_fn, args.sr)
    # NOTE(review): assumes wav is a 1-D mono signal and mask is a boolean
    # selector of the same length - confirm against clean.py.
    mask, _ = envelope(wav, rate, threshold=args.threshold)
    clean_wav = wav[mask]
    if clean_wav.shape[0] == 0:
        # An empty batch would otherwise surface as an obscure shape error
        # (or NaN scores) further down.
        raise ValueError(
            f"No audio samples left in {wav_fn} after applying the "
            f"envelope threshold; nothing to classify."
        )

    X_batch = _split_into_windows(clean_wav, step)
    y_pred = model.predict(X_batch)
    # Average the per-window scores into a single score vector.
    y_mean = np.mean(y_pred, axis=0)
    y_pred_class = np.argmax(y_mean)
    print(f'Predicted class: {classes[y_pred_class]}')

    out_path = os.path.join('logs', args.pred_fn)
    # np.save does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.save(out_path, y_mean)
def build_parser():
    """Create the command-line parser for the prediction script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--model_fn``, ``--pred_fn``,
        ``--wav_fn``, ``--dt``, ``--sr`` and ``--threshold``.
    """
    parser = argparse.ArgumentParser(description='Audio Classification Prediction')
    parser.add_argument('--model_fn', type=str, default='models/lstm.h5',
                        help='model file to make predictions')
    parser.add_argument('--pred_fn', type=str, default='y_pred',
                        help='fn to write predictions in logs dir')
    parser.add_argument('--wav_fn', type=str, default='record/test.WAV',
                        help='file wav to predict')
    parser.add_argument('--dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # Parsed as int so a value given on the command line has the same type
    # as the default; a str would break the numeric comparison downstream.
    parser.add_argument('--threshold', type=int, default=20,
                        help='threshold magnitude for np.int16 dtype')
    return parser


if __name__ == '__main__':
    # parse_known_args leaves unrecognised arguments in the second return
    # value instead of exiting with an error.
    args, _ = build_parser().parse_known_args()
    make_prediction(args)
Metadata
Metadata
Assignees
Labels
No labels