# train.py
# Forked from thomasthebaud/speechLLM.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from trainer import SpeechLLMLightning
from dataset import InstructionalAudioDataset, MyCollator, CompositeAudioDataset, make_weighted_sampler_from_dataset
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import TQDMProgressBar
import torch.utils.data as data_utils
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import wandb
import argparse
import os
import torch
from utils import get_model_config
def _build_callbacks(model_config):
    """Create the checkpointing and early-stopping callbacks for a run.

    Both callbacks watch the same validation metric: the summary ROUGE
    average F1 (maximised) when ``model_config['use_summaries']`` is truthy,
    otherwise the validation loss (minimised).

    Args:
        model_config: Run configuration mapping; reads ``use_summaries``,
            ``group`` and ``model_name``.

    Returns:
        A ``(checkpoint_callback, early_stop_callback)`` tuple.
    """
    if model_config['use_summaries']:
        print("Using max ROUGE avg F1 score as target")
        monitor, mode = "val/summary/rouge_avg_f1", "max"
    else:
        print("Using min val loss as target")
        monitor, mode = "val/loss", "min"

    checkpoint_callback = ModelCheckpoint(
        dirpath=f"checkpoints/{model_config['group']}/{model_config['model_name']}",
        filename=model_config['model_name'] + 'epoch-{epoch}',
        save_top_k=3,
        mode=mode,  # explicit in both cases; "min" is also ModelCheckpoint's default
        monitor=monitor,
        save_last=True,
        every_n_epochs=2,
    )
    early_stop_callback = EarlyStopping(
        monitor=monitor,
        min_delta=0.00,
        patience=10,
        verbose=False,
        mode=mode,
    )
    return checkpoint_callback, early_stop_callback


def _build_loader(dataset, model_config, collator, num_workers, sampler=None):
    """Wrap ``dataset`` in a DataLoader using the script's shared settings.

    Args:
        dataset: Dataset to iterate over.
        model_config: Run configuration mapping; reads ``batch_size``.
        collator: Callable passed as ``collate_fn``.
        num_workers: Number of DataLoader worker processes.
        sampler: Optional sampler. When given, shuffling is left to it.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    return data_utils.DataLoader(
        dataset,
        batch_size=model_config['batch_size'],
        shuffle=False,
        sampler=sampler,
        collate_fn=collator,
        num_workers=num_workers,
        # Conservative settings kept from debugging a segmentation fault.
        persistent_workers=False,
        pin_memory=False,
        prefetch_factor=None,
    )


def main():
    """Train a SpeechLLM model end to end.

    Reads the run configuration, starts Weights & Biases logging, builds the
    model, the train/validation datasets and loaders, the checkpoint and
    early-stopping callbacks, and finally runs ``Trainer.fit``.
    """
    model_config = get_model_config()

    # Start the W&B run explicitly before handing the same identifiers to the
    # Lightning logger.
    wandb.init(project="speechllm", name=model_config['log_path'], group=model_config['group'])
    logger = WandbLogger(project="speechllm", name=model_config['log_path'], group=model_config['group'])
    print(model_config)

    model = SpeechLLMLightning(**model_config)
    tokenizer = model.llm_tokenizer

    train_dataset = CompositeAudioDataset(
        list_of_datasets=model_config['train_sets'],
        mode='train',
        random_keys_prob=0.2,
        max_len=model_config['max_number_seconds'],
        use_text=model_config['use_text'],
        prob_text=model_config['prob_text'],
    )
    val_dataset = CompositeAudioDataset(
        list_of_datasets=model_config['dev_sets'],
        mode='test',
        max_len=model_config['max_number_seconds'],
        max_size=model_config['max_size_per_dev_set'],
        use_text=model_config['use_text'],
        prob_text=model_config['prob_text'],
    )
    print(f"Train set:{len(train_dataset)}, val set:{len(val_dataset)}, batch size:{model_config['batch_size']}")

    num_workers = 0  # put to 0 for debugging
    my_collator = MyCollator(model_config['audio_encoder_name'], tokenizer)

    # The helper's result is checked against None below, so it may decline to
    # build a sampler; in that case fall back to plain shuffling.
    sampler = make_weighted_sampler_from_dataset(train_dataset)
    if sampler is None:
        # DataLoader does not allow shuffle=True together with a sampler, so
        # shuffling is only requested when no sampler is available.
        train_loader = data_utils.DataLoader(
            train_dataset,
            batch_size=model_config['batch_size'],
            shuffle=True,
            sampler=None,
            collate_fn=my_collator,
            num_workers=num_workers,
            persistent_workers=False,
            pin_memory=False,
            prefetch_factor=None,
        )
    else:
        train_loader = _build_loader(
            train_dataset, model_config, my_collator, num_workers, sampler=sampler
        )
    val_loader = _build_loader(val_dataset, model_config, my_collator, num_workers)

    checkpoint_callback, early_stop_callback = _build_callbacks(model_config)

    trainer = Trainer(
        max_epochs=model_config['total_training_epoch'],
        devices=1,
        accelerator="gpu",
        # NOTE(review): a trailing comment in the original hinted that
        # find_unused_parameters could follow model_config['finetune_encoder'].
        strategy=DDPStrategy(find_unused_parameters=True),
        limit_train_batches=model_config['train_batch_per_epoch'],
        log_every_n_steps=100,
        enable_checkpointing=True,
        enable_progress_bar=True,
        callbacks=[
            checkpoint_callback,
            # Previously constructed but never registered, so early stopping
            # silently never ran.
            early_stop_callback,
            TQDMProgressBar(refresh_rate=10 * model_config['grad_accumulate_steps']),
        ],
        fast_dev_run=False,
        logger=logger,
        accumulate_grad_batches=model_config['grad_accumulate_steps'],
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()