Skip to content

Could not locate class 'Reshape' when running predict in Python 3.10 and TensorFlow 2.16  #78

@Jeprizal

Description

@Jeprizal

from tensorflow.keras.models import load_model
from clean import downsample_mono, envelope
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
import numpy as np
import os
import argparse

def make_prediction(args: argparse.Namespace) -> None:
    """Classify one WAV file with a saved Keras model and store the mean scores.

    The audio is resampled, masked with the signal envelope, cut into
    fixed-length windows of ``int(args.sr * args.dt)`` samples (the last
    window is zero-padded), and passed through the model as a single batch.
    The per-window predictions are averaged, the arg-max class label is
    printed, and the averaged vector is written to ``logs/<args.pred_fn>``
    with ``np.save`` (which appends ``.npy`` when the name lacks it).

    Args:
        args: Namespace providing ``model_fn``, ``pred_fn``, ``wav_fn``,
            ``dt``, ``sr`` and ``threshold``.

    Raises:
        FileNotFoundError: If ``args.wav_fn`` does not exist.
        ValueError: If no samples remain after envelope masking, so there
            is nothing to predict on.
    """
    wav_fn = args.wav_fn
    # Validate the input first so a bad path fails before the model is loaded.
    if not os.path.exists(wav_fn):
        raise FileNotFoundError(f"File {wav_fn} tidak ditemukan.")

    # The kapre layers must be registered so Keras can deserialize them.
    model = load_model(args.model_fn,
                       custom_objects={'STFT': STFT,
                                       'Magnitude': Magnitude,
                                       'ApplyFilterbank': ApplyFilterbank,
                                       'MagnitudeToDecibel': MagnitudeToDecibel})

    # Index i of the model output is reported as classes[i].
    classes = ['C', 'D', 'E', 'F', 'L', 'O', 'P', 'T', 'Z']

    rate, wav = downsample_mono(wav_fn, args.sr)
    mask, _ = envelope(wav, rate, threshold=args.threshold)
    clean_wav = wav[mask]
    step = int(args.sr * args.dt)

    batch = []
    for start in range(0, clean_wav.shape[0], step):
        sample = clean_wav[start:start + step].reshape(-1, 1)
        if sample.shape[0] < step:
            # Zero-pad the trailing partial window up to the full length.
            padded = np.zeros(shape=(step, 1), dtype=np.float32)
            padded[:sample.shape[0], :] = sample
            sample = padded
        batch.append(sample)

    if not batch:
        # Without this guard the failure surfaces inside model.predict.
        raise ValueError(
            f"No audio left in {wav_fn} after envelope masking; "
            "nothing to predict.")

    X_batch = np.array(batch, dtype=np.float32)
    y_pred = model.predict(X_batch)
    y_mean = np.mean(y_pred, axis=0)
    y_pred_class = np.argmax(y_mean)

    print(f'Predicted class: {classes[y_pred_class]}')

    # np.save does not create parent directories, so make sure it exists.
    os.makedirs('logs', exist_ok=True)
    np.save(os.path.join('logs', args.pred_fn), y_mean)

if __name__ == '__main__':
    # Script entry point: read the command-line options and run one prediction.
    parser = argparse.ArgumentParser(description='Audio Classification Prediction')
    parser.add_argument('--model_fn', type=str, default='models/lstm.h5',
                        help='model file to make predictions')
    parser.add_argument('--pred_fn', type=str, default='y_pred',
                        help='fn to write predictions in logs dir')
    parser.add_argument('--wav_fn', type=str, default='record/test.WAV',
                        help='file wav to predict')
    parser.add_argument('--dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # Parsed as int so a value given on the command line has the same type
    # as the numeric default instead of arriving as a string.
    parser.add_argument('--threshold', type=int, default=20,
                        help='threshold magnitude for np.int16 dtype')
    # parse_known_args ignores unrecognized options instead of exiting.
    args, _ = parser.parse_known_args()

    make_prediction(args)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions