# Hugging Face Space page header: "Running on Zero" (ZeroGPU hardware).
import os
import time
from dataclasses import dataclass

import librosa
import requests
import torch
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM
@dataclass
class Config:
    """Default settings for the KaniTTS pipeline.

    Declared as a dataclass so callers can override individual values per
    instance (e.g. ``Config(temperature=0.8)``); the defaults remain readable
    as class attributes (``Config.model_name``).
    """

    # Hugging Face id of the causal LM that emits text + speech tokens.
    model_name: str = "nineninesix/kani-tts-450m-0.1-pt"
    # Hugging Face / NeMo id of the audio codec used to decode speech tokens.
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    # Passed as ``device_map`` to ``AutoModelForCausalLM.from_pretrained``.
    device_map: str = "auto"
    # Size of the text vocabulary; special and audio token ids are allocated
    # as offsets above this value.
    tokeniser_length: int = 64400
    # Token ids delimiting the text span in model input/output.
    start_of_text: int = 1
    end_of_text: int = 2
    # Generation defaults; presumably used to seed run_model's
    # max_tok / t / top_p / rp arguments (e.g. UI defaults) -- confirm at callers.
    max_new_tokens: int = 1200
    temperature: float = 1.4
    top_p: float = .95
    repetition_penalty: float = 1.1
class NemoAudioPlayer:
    """Turn KaniTTS language-model output ids into an audio waveform.

    Wraps a NeMo ``AudioCodecModel`` and knows the special-token layout the
    language model uses (speech / human / AI delimiters and the audio-code id
    range). If ``text_tokenizer_name`` is given, the text span of an output can
    be decoded as well.
    """

    # Number of consecutive tokens per audio frame, one per codec codebook.
    NUM_CODEBOOKS = 4

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        """Load the codec model and derive the special-token ids.

        Args:
            config: settings object exposing ``audiocodec_name``,
                ``tokeniser_length``, ``start_of_text`` and ``end_of_text``.
            text_tokenizer_name: optional Hugging Face tokenizer name used by
                ``get_text``; when falsy no tokenizer is loaded.
        """
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load NeMo codec model in eval mode (inference only).
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)
        # Token configuration: special ids are allocated directly above the
        # text vocabulary (tokeniser_length + k).
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        # Audio codes start here; codebook i occupies the id range
        # [audio_tokens_start + i * codebook_size,
        #  audio_tokens_start + (i + 1) * codebook_size).
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    @staticmethod
    def _first_index(ids, token_id):
        """Return the position of the first ``token_id`` in 1-D ``ids``, or None."""
        hits = (ids == token_id).nonzero(as_tuple=True)[0]
        return hits[0].item() if hits.numel() else None

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens.

        Raises:
            ValueError: if either the start-of-speech or the end-of-speech
                token is absent from ``out_ids``.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from model output.

        Takes the ids between the first start-of-speech and the first
        end-of-speech token. It groups them into frames of ``NUM_CODEBOOKS``
        tokens and maps each token back to its per-codebook code index.

        Args:
            out_ids: 1-D integer tensor of generated token ids.

        Returns:
            ``(codes, lengths)`` where ``codes`` has shape
            ``(1, NUM_CODEBOOKS, frames)`` and ``lengths`` is ``tensor([frames])``.

        Raises:
            ValueError: if the delimiters are missing or out of order, the
                span is empty or not a whole number of frames, or any code
                falls outside ``[0, codebook_size)``.
        """
        # Explicit empty checks: ``.item()`` on an empty index tensor raises
        # RuntimeError, which would bypass the ValueError callers expect.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')
        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if audio_codes.numel() == 0:
            raise ValueError('No audio codes found between speech tokens!')
        if len(audio_codes) % self.NUM_CODEBOOKS:
            raise ValueError('Audio codes length must be multiple of 4!')
        audio_codes = audio_codes.reshape(-1, self.NUM_CODEBOOKS)
        # Undo the per-codebook id offset, then the audio-range base offset.
        offsets = torch.arange(self.NUM_CODEBOOKS, device=audio_codes.device) * self.codebook_size
        audio_codes = audio_codes - offsets - self.audio_tokens_start
        # A code belonging to another codebook (or a non-audio token) lands
        # outside [0, codebook_size); reject it before it reaches the codec.
        if ((audio_codes < 0) | (audio_codes >= self.codebook_size)).any():
            raise ValueError('Invalid audio tokens detected!')
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Extract text from model output.

        Decodes the ids from the first start-of-text through the first
        end-of-text token (inclusive) with ``self.tokenizer``, which only
        exists when ``text_tokenizer_name`` was given.

        Raises:
            ValueError: if either text delimiter is missing.
        """
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = self._first_index(out_ids, self.end_of_text)
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        return self.tokenizer.decode(txt_tokens, skip_special_tokens=True)

    def get_waveform(self, out_ids):
        """Convert model output to audio waveform.

        Args:
            out_ids: generated token ids (any shape; flattened first).

        Returns:
            ``(audio, text)``: ``audio`` is the squeezed NumPy array produced
            by the codec; ``text`` is the decoded text span when a text
            tokenizer was configured, otherwise ``None``.

        Raises:
            ValueError: propagated from validation / code extraction.
        """
        out_ids = out_ids.flatten()
        self.output_validation(out_ids)
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
        if self.text_tokenizer_name:
            return output_audio, self.get_text(out_ids)
        return output_audio, None
class KaniModel:
    """Text-to-speech pipeline: prompt -> causal-LM speech tokens -> waveform.

    Wraps a Hugging Face causal LM and its tokenizer. Special-token ids and
    audio decoding are delegated to ``player`` (a ``NemoAudioPlayer``).
    """

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        """Load the language model and tokenizer named by ``config.model_name``.

        Args:
            config: settings object exposing ``model_name`` and ``device_map``.
            player: decoder providing special-token ids and ``get_waveform``.
            token: Hugging Face access token passed to ``from_pretrained``.
        """
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    @staticmethod
    def _resolve_pad_token_id(pad_token_id, eos_token_id):
        """Return ``pad_token_id`` when defined, otherwise ``eos_token_id``.

        Compared against ``None`` explicitly: 0 is a valid token id and must
        not fall through to the EOS fallback.
        """
        return pad_token_id if pad_token_id is not None else eos_token_id

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model.

        Builds ``[START_OF_HUMAN] + tokenize(text_prompt) + [END_OF_TEXT,
        END_OF_HUMAN]`` as a ``(1, seq_len)`` int64 tensor, plus an all-ones
        attention mask of the same shape.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def model_request(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            t: float,
            top_p: float,
            rp: float,
            max_tok: int) -> torch.Tensor:
        """Generate tokens using the model.

        Samples up to ``max_tok`` new tokens (temperature ``t``, nucleus
        ``top_p``, repetition penalty ``rp``), stopping at the player's
        end-of-speech token. Returns the full sequence (prompt + generated)
        on the CPU.
        """
        # With device_map="auto" the model's placement is decided by
        # from_pretrained, so follow the model rather than the cuda/cpu guess.
        model_device = self.model.device
        input_ids = input_ids.to(model_device)
        attention_mask = attention_mask.to(model_device)
        pad_token_id = self._resolve_pad_token_id(
            self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
        )
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tok,
                do_sample=True,
                temperature=t,
                top_p=top_p,
                repetition_penalty=rp,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=pad_token_id
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1, point_2, point_3):
        """Format the generation / codec / total durations (seconds, 2 d.p.)."""
        generation_time = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"SPEECH TOKENS: {generation_time:.2f}\nCODEC: {player_time:.2f}\nTOTAL: {total_time:.2f}"
        return report

    def run_model(self, text: str, t: float, top_p: float, rp: float, max_tok: int):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            ``(audio, text, report)``: the waveform from the player, the
            input text unchanged, and a timing report string.
        """
        input_ids, attention_mask = self.get_input_ids(text)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        point_1 = time.perf_counter()
        model_output = self.model_request(input_ids, attention_mask, t, top_p, rp, max_tok)
        point_2 = time.perf_counter()
        audio, _ = self.player.get_waveform(model_output)
        point_3 = time.perf_counter()
        return audio, text, self.time_report(point_1, point_2, point_3)