Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions speechbrain/lobes/models/huggingface_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class HuggingFaceWhisper(nn.Module):
HuggingFace hub name: e.g "openai/whisper-tiny"
save_path : str
Path (dir) of the downloaded model.
output_all_hiddens: bool (default: False)
If True, the forward function outputs the hidden states from all transformer layers of the encoder.
For example whisper-base has 6 transformer layers and the output is of shape (7, B, T, C),
where the output of the CNN output is added to the beginning.
If False, the forward function outputs the hidden states only from the last transformer layer of the encoder.
Example
-------
>>> model_hub = "openai/whisper-tiny"
Expand All @@ -65,13 +70,15 @@ def __init__(
freeze=False,
freeze_encoder=False,
output_attentions=True,
output_all_hiddens=False,
):
super().__init__()
self.sampling_rate = sampling_rate
self.encoder_only = encoder_only
self.freeze = freeze
self.freeze_encoder = freeze_encoder
self.output_attentions = output_attentions
self.output_all_hiddens = output_all_hiddens

self.tokenizer = None
# Download the tokenizer only if we are going to use the Decoder.
Expand Down Expand Up @@ -131,18 +138,29 @@ def forward(self, wav, decoder_input_ids=None):
out_encoder = self.forward_encoder(wav)
if self.encoder_only:
return out_encoder
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)

if self.output_all_hiddens:
logits, attn = self.forward_decoder(
out_encoder[-1], decoder_input_ids
)
else:
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
return out_encoder, logits, attn
else:
if self.encoder_only:
return self.forward_encoder(wav)
else:
out_encoder = self.forward_encoder(wav)
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
if self.output_all_hiddens:
logits, attn = self.forward_decoder(
out_encoder[-1], decoder_input_ids
)
else:
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
return out_encoder, logits, attn

def forward_encoder(self, wav):
Expand All @@ -155,10 +173,24 @@ def forward_encoder(self, wav):

if self.freeze_encoder:
with torch.no_grad():
mel = self._get_mel(wav)
return self.model.encoder(mel).last_hidden_state
return self._get_encoder_states(wav)
else:
return self._get_encoder_states(wav)

def _get_encoder_states(self, wav):
Comment thread
Adel-Moumen marked this conversation as resolved.
"""Takes an input waveform and return its corresponding encoder states.
Returns the last hidden state of the encoder or all hidden states if
output_all_hiddens is True.
Arguments
---------
wav : torch.Tensor (signal)
A batch of audio signals to transform to features.
"""
mel = self._get_mel(wav)
if self.output_all_hiddens:
states = self.model.encoder(mel, output_hidden_states=True)
return torch.stack(states.hidden_states)
Comment thread
Hguimaraes marked this conversation as resolved.
else:
mel = self._get_mel(wav)
return self.model.encoder(mel).last_hidden_state

def _get_mel(self, wav):
Expand Down