diff --git a/models/audio/speech_recognition/conformer/igie/ci/prepare.sh b/models/audio/speech_recognition/conformer/igie/ci/prepare.sh
index 5553e220bffd854cda6da27e6ec1be5b9b1937cb..9736635bb0365c147de0140a37e64a148999040e 100644
--- a/models/audio/speech_recognition/conformer/igie/ci/prepare.sh
+++ b/models/audio/speech_recognition/conformer/igie/ci/prepare.sh
@@ -28,7 +28,7 @@
 pip3 install -r requirements.txt
 ln -s /mnt/deepspark/data/checkpoints/conformer_checkpoints.tar ./
 tar xf conformer_checkpoints.tar
-ln -s /mnt/deepspark/data/datasets/aishell_test_data ./aishell_test_data
+ln -sT /mnt/deepspark/data/datasets/aishell_test_data ./aishell_test_data
 # cp /mnt/deepspark/data/datasets/aishell_test_data.tar ./
 # tar xf aishell_test_data.tar
-# bash scripts/aishell_data_prepare.sh ./aishell_test_data ./tools
\ No newline at end of file
+# bash scripts/aishell_data_prepare.sh ./aishell_test_data ./tools
diff --git a/models/audio/speech_recognition/conformer/igie/wenet/processor.py b/models/audio/speech_recognition/conformer/igie/wenet/processor.py
index 9a542a3d204cdb3def8cf61ce0b0fd8bb31af32e..4337c99c8aebc1172560b7d130cc4ff61392f483 100644
--- a/models/audio/speech_recognition/conformer/igie/wenet/processor.py
+++ b/models/audio/speech_recognition/conformer/igie/wenet/processor.py
@@ -109,15 +109,54 @@ def tar_file_and_group(data):
         sample['stream'].close()
 
 
-def parse_raw(data):
-    """ Parse key/wav/txt from json line
-
-    Args:
-        data: Iterable[str], str is a json line has key/wav/txt
+#def parse_raw(data):
+#    """ Parse key/wav/txt from json line
+#
+#    Args:
+#        data: Iterable[str], str is a json line has key/wav/txt
+#
+#    Returns:
+#        Iterable[{key, wav, txt, sample_rate}]
+#    """
+#    #import soundfile as sf
+#    for sample in data:
+#        assert 'src' in sample
+#        json_line = sample['src']
+#        obj = json.loads(json_line)
+#        assert 'key' in obj
+#        assert 'wav' in obj
+#        assert 'txt' in obj
+#        key = obj['key']
+#        wav_file = obj['wav']
+#        txt = obj['txt']
+#        try:
+#            if 'start' in obj:
+#                assert 'end' in obj
+#                sample_rate = torchaudio.backend.sox_io_backend.info(
+#                    wav_file).sample_rate
+#                start_frame = int(obj['start'] * sample_rate)
+#                end_frame = int(obj['end'] * sample_rate)
+#                waveform, _ = torchaudio.backend.sox_io_backend.load(
+#                    filepath=wav_file,
+#                    num_frames=end_frame - start_frame,
+#                    frame_offset=start_frame,
+#                    backend="soundfile"
+#                )
+#            else:
+#                #waveform, sample_rate = torchaudio.load(wav_file, backend="soundfile")
+#                waveform, sample_rate = torchaudio.load(wav_file)
+#            example = dict(key=key,
+#                           txt=txt,
+#                           wav=waveform,
+#                           sample_rate=sample_rate)
+#            yield example
+#        except Exception as ex:
+#            #logging.warning('Failed to read {}'.format(wav_file))
+#            logging.warning('Failed to read {}. Error: {}'.format(wav_file, ex), exc_info=True)
 
-    Returns:
-        Iterable[{key, wav, txt, sample_rate}]
-    """
+def parse_raw(data):
+    import soundfile as sf
+    import numpy as np
     for sample in data:
         assert 'src' in sample
         json_line = sample['src']
@@ -128,26 +167,39 @@ def parse_raw(data):
         key = obj['key']
         wav_file = obj['wav']
         txt = obj['txt']
+
         try:
             if 'start' in obj:
                 assert 'end' in obj
-                sample_rate = torchaudio.backend.sox_io_backend.info(
-                    wav_file).sample_rate
+                # 1. Query the file metadata
+                audio_info = sf.info(wav_file)
+                sample_rate = audio_info.samplerate
+
+                # 2. Compute the frame range
                 start_frame = int(obj['start'] * sample_rate)
                 end_frame = int(obj['end'] * sample_rate)
-                waveform, _ = torchaudio.backend.sox_io_backend.load(
-                    filepath=wav_file,
-                    num_frames=end_frame - start_frame,
-                    frame_offset=start_frame)
+                num_frames = end_frame - start_frame
+
+                # 3. Read the audio (soundfile returns a numpy array)
+                # Note: soundfile returns data as [Time, Channels], while PyTorch expects [Channels, Time]
+                data_np, _ = sf.read(wav_file, start=start_frame, stop=end_frame)
             else:
-                waveform, sample_rate = torchaudio.load(wav_file)
-            example = dict(key=key,
-                           txt=txt,
-                           wav=waveform,
-                           sample_rate=sample_rate)
+                # Read the whole file
+                data_np, sample_rate = sf.read(wav_file)
+
+            # 4. Convert to a Tensor and fix the layout
+            # A 1-D array (mono) gets an extra dimension to become [1, Time]
+            if data_np.ndim == 1:
+                waveform = torch.from_numpy(data_np).unsqueeze(0)
+            else:
+                # Multi-channel (e.g. stereo): transpose to [Channels, Time]
+                waveform = torch.from_numpy(data_np.T)
+
+            example = dict(key=key, txt=txt, wav=waveform, sample_rate=sample_rate)
             yield example
+
         except Exception as ex:
-            logging.warning('Failed to read {}'.format(wav_file))
+            logging.warning(f'Failed to read {wav_file}. Error: {ex}', exc_info=True)
 
 
 def filter(data,
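
Below is a minimal sketch of how the rewritten soundfile-based parse_raw can be exercised end to end. The file name test.wav, the synthetic 16 kHz mono content, and the import path wenet.processor are assumptions for illustration; the real pipeline feeds samples produced by the WeNet data loader from the AISHELL lists.

import json

import numpy as np
import soundfile as sf

from wenet.processor import parse_raw  # assumes wenet/ is on PYTHONPATH

# Write a synthetic one-second 16 kHz mono clip so the sketch is self-contained
# (assumption; not part of the dataset).
sf.write("test.wav", np.zeros(16000, dtype="float32"), 16000)

# parse_raw expects dicts with a 'src' field holding a JSON line with key/wav/txt.
sample = {"src": json.dumps({"key": "utt1", "wav": "test.wav", "txt": "hello"})}

for example in parse_raw([sample]):
    # wav is a torch.Tensor shaped [channels, time]; sample_rate comes from soundfile.
    print(example["key"], example["wav"].shape, example["sample_rate"])

One behavioural difference from the torchaudio path worth keeping in mind: sf.read returns float64 by default, whereas torchaudio.load returns float32, so downstream feature extraction that assumes float32 may need an explicit waveform.float() cast.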