wenet训练笔记
1 | # 这两个是配置环境,在path.sh下修改 /KALDI_ROOT的地址 |
step -1
- 这一步是下载数据
1 | # data是数据存放地址,自己下载了的话,就修改位置 |
step 0
- 这个是数据准备,在wenet的s0里面只调用了
local/aishell_data_prep.sh,在s1里面还用了kaldi
1 | # aishell_data_prep.sh |

text里面包括:第一列是音频的文件名,第二列是对应的转写文本(transcript)

transcripts.txt和text一样。utt.list长这样,是每个音频的文件名

wav.flist是音频的存放地址。wav.scp就是合并utt.list和wav.flist
- 最后把
data/local/train中的wav.scp和text放到data/train/目录下
step 1
- 这步是提取特征,这里设置了cmvn=true,这里我只跑了s0
- 在这里调用了kaldi的脚本
make_fbank.sh,以后再看这个脚本吧 - 当cmvn=true时,统计cmvn,会有两个输出的东西,
feats.scp和global_cmvn - 下面拿s0的输出结果做演示,输出的东西放在
$feat_dir/$train_set/global_cmvn里面 - 为了方便展示,删了很多
mean_stat和var_stat的值

step 2
- 这一步是创建字典
- 用了
tools/text2token.py,自动生成字典,字典长这样

step 3
- 这一步是将数据整理成想要的格式,也就是格式化
用的是
tools/format_data.sh,在s0里面,进一步调用了tools/remove_longshortdata.py
- 这个脚本输入了四个参数,
nj,data/$train_set/feats.scp,data/$train_set,$dict - 注意在s0里面,第二个参数是
data/$train_set/wav.scp - s0的输出地方是
$feat_dir/$train_set/format.data.tmp,在调用了tools/remove_longshortdata.py输出到$feat_dir/$train_set/format.data - s1的输出地方是
data/$train_set/formate.data formate.data长这样,分三张图来展示- 一行一个样本
- utt 是音频文件名
- feat 是音频地址
- feat_shape 是音频时长
- text 是说话的内容
- token 是分词内容
- tokenid 是分词后每个 token 的 id
- token_shape 是 [token的个数, 字典最后一个token的id]
- 字典最后一个是
eos和sos,可以在字典最后面看到



step 4
- 这个部分就是训练的步骤啦
- 这里用的是wenet的train.py,但是这个train.py的单机多卡写得比较垃圾,所以重新写了一份,这里感谢志平
- step 4的启动方式又改了下
mkdir -p $dir
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
echo "有${num_gpus}个GPU dir: ${dir}"
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="nccl"
cp data/${train_set}/global_cmvn $dir
cmvn_opts=
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"
python -m torch.distributed.launch \
--nproc_per_node=${num_gpus} $PWD/wenet/bin/train_ddp.py \
--gpu $num_gpus \
--config $train_config \
--train_data data/$train_set/format.data \
--cv_data data/$dev_set/format.data \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
# 注意: 上面的 --model_dir $dir 是模型的保存地址(实际脚本中 \ 行尾后不能再跟注释)
--ddp.world_size $num_gpus \
--ddp.rank 0 \
--ddp.dist_backend $dist_backend \
--num_workers 0 \
# 注意: num_workers 有时候设置为 8 会崩溃,所以这里用 0
$cmvn_opts

train_ddp.py的定义
238# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from wenet.dataset.dataset import AudioDataset, CollateFunc
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint, save_checkpoint
from wenet.utils.executor import Executor
from wenet.utils.scheduler import WarmupLR
if __name__ == '__main__':
    # Single-node multi-GPU DDP training entry point, meant to be launched
    # via ``python -m torch.distributed.launch --nproc_per_node=N``.
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    # torch.distributed.launch injects --local_rank into every spawned
    # process; it doubles as the CUDA device index below.
    parser.add_argument('--local_rank',
                        default=0,
                        type=int,
                        help='node rank for distributed training')
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='node rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--cmvn', default=None, help='global cmvn file')
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Fixed seed for reproducibility.
    torch.manual_seed(777)
    print(args)
    with open(args.config, 'r') as fin:
        # FIX: yaml.load() without an explicit Loader is deprecated and a
        # TypeError in PyYAML >= 6; FullLoader also avoids constructing
        # arbitrary Python objects from the config file.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # The launch script sets world_size to the GPU count, so >= 1 means DDP
    # is used even on a single GPU (default -1 => non-distributed/CPU path).
    distributed = args.world_size >= 1

    raw_wav = configs['raw_wav']
    train_collate_func = CollateFunc(**configs['collate_conf'],
                                     raw_wav=raw_wav)
    cv_collate_conf = copy.deepcopy(configs['collate_conf'])
    # No augmentation on the cv set.
    cv_collate_conf['spec_aug'] = False
    cv_collate_conf['spec_sub'] = False
    if raw_wav:
        cv_collate_conf['feature_dither'] = 0.0
        cv_collate_conf['speed_perturb'] = False
        cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

    dataset_conf = configs.get('dataset_conf', {})
    train_dataset = AudioDataset(args.train_data,
                                 **dataset_conf,
                                 raw_wav=raw_wav)
    cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)

    if distributed:
        logging.info('training on multiple gpu, this gpu {}'.format(
            args.local_rank))
        dist.init_process_group(args.dist_backend)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True)
        cv_sampler = torch.utils.data.distributed.DistributedSampler(
            cv_dataset, shuffle=False)
    else:
        train_sampler = None
        cv_sampler = None

    # batch_size=1: presumably AudioDataset yields pre-batched chunks that
    # the collate function unpacks — TODO confirm against dataset.py.
    train_data_loader = DataLoader(train_dataset,
                                   collate_fn=train_collate_func,
                                   sampler=train_sampler,
                                   shuffle=(train_sampler is None),
                                   batch_size=1,
                                   num_workers=args.num_workers)
    cv_data_loader = DataLoader(cv_dataset,
                                collate_fn=cv_collate_func,
                                sampler=cv_sampler,
                                shuffle=False,
                                batch_size=1,
                                num_workers=args.num_workers)

    if raw_wav:
        input_dim = configs['collate_conf']['feature_extraction_conf'][
            'mel_bins']
    else:
        input_dim = train_dataset.input_dim
    vocab_size = train_dataset.output_dim

    # Save configs to model_dir/train.yaml for inference and export;
    # only rank 0 writes files to avoid clobbering.
    configs['input_dim'] = input_dim
    configs['output_dim'] = vocab_size
    configs['cmvn_file'] = args.cmvn
    configs['is_json_cmvn'] = raw_wav
    if args.local_rank == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    # Init asr model from configs.
    model = init_asr_model(configs)
    print(model)

    # !!!IMPORTANT!!!
    # Try to export the model by script; if it fails, the model code must be
    # refined to satisfy the TorchScript export requirements.
    script_model = torch.jit.script(model)
    script_model.save(os.path.join(args.model_dir, 'init.zip'))

    executor = Executor()
    # If a checkpoint is specified, resume epoch/step/cv_loss from it.
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    else:
        infos = {}
    start_epoch = infos.get('epoch', -1) + 1
    cv_loss = infos.get('cv_loss', 0.0)
    step = infos.get('step', -1)

    num_epochs = configs.get('max_epoch', 100)
    model_dir = args.model_dir
    writer = None
    # Only rank 0 owns the tensorboard writer and checkpoint directory.
    if args.local_rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        exp_id = os.path.basename(model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:
        # assert (torch.cuda.is_available())
        # cuda model is required for nn.parallel.DistributedDataParallel
        device = torch.device('cuda', args.local_rank)
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            find_unused_parameters=True,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu', args.local_rank)
        model = model.to(device)

    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    final_epoch = None
    configs['rank'] = args.local_rank
    if start_epoch == 0 and args.local_rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    executor.step = step
    scheduler.set_step(step)
    for epoch in range(start_epoch, num_epochs):
        if distributed:
            # Re-seed the sampler so each epoch shuffles differently.
            train_sampler.set_epoch(epoch)
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        executor.train(model, optimizer, scheduler, train_data_loader, device,
                       writer, configs)
        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
                                                configs)
        if args.world_size > 1:
            # all_reduce expected a sequence parameter, so we use [num_seen_utts].
            num_seen_utts = torch.Tensor([num_seen_utts]).to(device)
            # the default operator in all_reduce function is sum.
            dist.all_reduce(num_seen_utts)
            total_loss = torch.Tensor([total_loss]).to(device)
            dist.all_reduce(total_loss)
            cv_loss = total_loss[0] / num_seen_utts[0]
            cv_loss = cv_loss.item()
        else:
            cv_loss = total_loss / num_seen_utts

        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
        if args.local_rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(
                model, save_model_path, {
                    'epoch': epoch,
                    'lr': lr,
                    'cv_loss': cv_loss,
                    'step': executor.step
                })
            writer.add_scalars('epoch', {'cv_loss': cv_loss, 'lr': lr}, epoch)
        final_epoch = epoch

    if final_epoch is not None and args.local_rank == 0:
        final_model_path = os.path.join(model_dir, 'final.pt')
        # FIX: os.symlink raises FileExistsError if final.pt already exists
        # (e.g. on a rerun); remove any stale link first.
        if os.path.lexists(final_model_path):
            os.remove(final_model_path)
        os.symlink('{}.pt'.format(final_epoch), final_model_path)
- 最后把