wavlm-large模型onnx和mnn版本的导出与使用

在音频处理领域,WavLM是一个强大的预训练模型,可以提取高质量的音频特征表示。然而,在实际应用中,我们常常面临模型体积大、推理速度慢以及部署复杂等问题。本文将详细记录我对s3prl项目中的WavLM模型进行优化,包括简化推理过程、ONNX导出以及MNN转换(包括FP16和INT8量化)的过程,用于解决上面提到的问题。

代码仓库:https://github.com/vra/s3prl, README详细记录了导出过程与使用说明

ONNX模型:https://huggingface.co/yunfengwang/wavlm-large-onnx

MNN模型:https://huggingface.co/yunfengwang/wavlm-large-mnn

Pytorch版本的问题

在开始优化之前,让我们先了解原始WavLM模型在s3prl项目中存在的几个问题:

网络依赖复杂:原始代码依赖于torch.hub加载模型,这种模式会从GitHub上拉取代码,这在国内环境中会遇到连接问题。
推理流程繁琐:需要加载完整的s3prl库和相关依赖,代码冗长且不直观。
模型体积大:原始PyTorch模型体积较大,不适合在资源受限的设备上部署。
推理速度慢:没有针对推理场景进行优化,存在不必要的计算开销。

优化过程

针对上述问题,实施了以下优化方案:

简化推理代码:创建了独立的demo.py文件,封装了一个简洁的VoiceEmbeddingExtractor类。
ONNX导出:修改了s3prl/upstream/wavlm/expert.py文件,支持将模型导出为ONNX格式。
MNN转换:将ONNX模型转换为MNN格式,并进行FP16和INT8量化。
推理示例代码:分别为PyTorch、ONNX和MNN模型提供了简洁的推理示例代码。
下面,我将详细介绍每一步的实现过程和技术细节。

简化推理代码

首先,参考UniSpeech中的代码,创建了一个利用WavLM进行音频embedding提取的demo.py文件,其中包含网络的定义,预训练权重的加载,以及一个VoiceEmbeddingExtractor类,用于封装WavLM模型的推理过程。这个类的设计非常简洁,主要包含以下功能:

初始化函数:加载预训练的WavLM模型。
extract_embedding函数:接收音频文件路径,返回提取的音频特征向量。
export_onnx函数:将PyTorch模型导出为ONNX格式。
关键代码片段如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class VoiceEmbeddingExtractor:
def __init__(self, ckpt_path, model_name="wavlm_large", device="cuda") -> None:
self.device = device
self.model = init_model(model_name, ckpt_path).to(self.device).eval()

def extract_embedding(self, audio_path):
audio = whisper.load_audio(audio_path)
audio = torch.from_numpy(audio).to(self.device).unsqueeze(0).float()
with torch.no_grad():
embed = self.model(audio)
return embed

def export_onnx(self, onnx_save_path):
dynamic_axes = {
"audio": {0: "batch_size", 1: "seq_len"},
"embed": {0: "batch_size"},
}
dummy_audio = torch.randn((1, 1024)).float().to(self.device)
torch.onnx.export(
self.model,
(dummy_audio,),
onnx_save_path,
input_names=["audio"],
output_names=["embed"],
dynamic_axes=dynamic_axes,
)

这段代码大大简化了WavLM模型的使用方式,用户只需提供预训练模型路径和音频文件路径,即可轻松提取音频特征。

ONNX导出

为了支持ONNX导出,修改s3prl/upstream/wavlm/expert.py文件,主要修改包括:

确保模型的forward函数支持ONNX导出格式。
处理动态输入尺寸,使导出的ONNX模型能够处理不同长度的音频输入。
简化模型结构,移除训练时特有的组件,只保留推理所需的部分。
在demo.py中,我实现了export_onnx方法,用于将PyTorch模型导出为ONNX格式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def export_onnx(self, onnx_save_path):
dynamic_axes = {
"audio": {0: "batch_size", 1: "seq_len"},
"embed": {0: "batch_size"},
}
dummy_audio = torch.randn((1, 1024)).float().to(self.device)
torch.onnx.export(
self.model,
(dummy_audio,),
onnx_save_path,
input_names=["audio"],
output_names=["embed"],
dynamic_axes=dynamic_axes,
)

导出ONNX模型的命令非常简单:

1
python demo.py -t export -m /path/to/wavlm_large_finetune.pth -o wavlm_large.onnx

MNN转换

将ONNX模型转换为MNN格式可以进一步优化模型体积和推理速度。MNN是阿里巴巴开源的一个轻量级深度学习推理引擎,专为移动设备设计,具有高性能、低内存占用的特点。

我使用MNN的转换工具将ONNX模型转换为MNN格式,并进行了FP16和INT8量化:

1
2
3
4
5
6
7
8
9

#安装MNN python包
pip install MNN

# FP16量化
mnnconvert -f ONNX --modelFile wavlm_large.onnx --MNNModel wavlm_large_fp16.mnn --fp16

# INT8量化
mnnconvert -f ONNX --modelFile wavlm_large.onnx --MNNModel wavlm_large_int8.mnn --weightQuantBits 8 --weightQuantAsymmetric

FP16量化将模型的权重从32位浮点数转换为16位浮点数,可以将模型体积减小约50%,同时保持较高的精度。INT8量化则将权重量化为8位整数,可以将模型体积进一步减小,但可能会导致一定的精度损失。

推理示例代码
为了方便用户使用优化后的模型,我分别为PyTorch、ONNX和MNN模型提供了简洁的推理示例代码。

PyTorch模型推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# 使用demo.py中的VoiceEmbeddingExtractor类
extractor = VoiceEmbeddingExtractor(ckpt_path=args.ckpt_path)
embed = extractor.extract_embedding(args.input_audio_path)
print("embedding's mean:", embed.mean())
ONNX模型推理 (infer_onnx.py)
class ONNXModel:
def __init__(self, onnx_path):
self.onnx_session = onnxruntime.InferenceSession(onnx_path)

def process(self, x):
pred_audio = self.onnx_session.run(
None,
input_feed={
"audio": x,
},
)
embed = pred_audio[0]
return embed

# 使用示例
model = ONNXModel(onnx_path=args.onnx_path)
audio = whisper.load_audio(args.input_audio_path)[None]
embed = model.process(audio)
print("embedding's mean:", embed.mean())
MNN模型推理 (infer_mnn.py)
class MNNModel:
def __init__(self, mnn_model_path: str) -> None:
input_names = ["audio"]
output_names = ["embed"]

self.chinese_bert = MNN.nn.load_module_from_file(
mnn_model_path,
input_names,
output_names,
)

def process(self, audio):
mnn_audio = MNN.expr.placeholder(audio.shape, dtype=MNN.numpy.float32)
mnn_audio.write(audio)
output = self.chinese_bert.forward([mnn_audio])
output = MNN.expr.convert(output[0], MNN.expr.NCHW)
embed = output.read()
return embed

# 使用示例
model = MNNModel(mnn_model_path=args.mnn_model_path)
audio = whisper.load_audio(args.input_audio_path)[None]
embed = model.process(audio)
print("embedding's mean:", embed.mean())

优化效果对比

通过上述优化,我们取得了显著的效果:

代码简化:推理代码从复杂的s3prl库依赖简化为不到60行的独立代码。
模型体积:
原始PyTorch模型:约1.2GB
ONNX模型:约1.1GB
MNN FP16模型:约600MB
MNN INT8模型:约300MB
部署便捷性:不再依赖torch.hub和完整的s3prl库,只需要一个.onnx或.mnn文件和少量代码即可完成部署。

总结

通过对s3prl项目中WavLM模型的优化,我们成功地简化了推理过程,减小了模型体积,提高了推理速度,并提供了更便捷的部署方式。这些优化使得WavLM模型更适合在资源受限的环境中部署,也为其他预训练音频模型的优化提供了参考。

主要贡献包括:

修改了s3prl/upstream/wavlm/expert.py文件,支持ONNX导出。
创建了独立的demo.py文件,封装了简洁的推理接口。
提供了ONNX和MNN模型的转换方法和推理示例代码。
实现了FP16和INT8量化,大幅减小了模型体积。
这些优化不仅提高了WavLM模型的实用性,也为其他研究者和开发者提供了有价值的参考。

资源链接
优化后的模型已上传至Hugging Face:

ONNX模型:https://huggingface.co/yunfengwang/wavlm-large-onnx
MNN模型:https://huggingface.co/yunfengwang/wavlm-large-mnn