From a79d297fe4cbe7df12238e4c990d22f65bea6b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 25 Aug 2024 11:16:03 +0800 Subject: [PATCH] fix(core): relative path --- ChatTTS/core.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index ad049b488..3606adff5 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -139,6 +139,10 @@ def load( use_flash_attn=use_flash_attn, use_vllm=use_vllm, experimental=experimental, + **{ + k: os.path.join(download_path, v) + for k, v in asdict(self.config.path).items() + }, ) def unload(self): @@ -221,6 +225,12 @@ def interrupt(self): @torch.no_grad() def _load( self, + vocos_ckpt_path: str = None, + dvae_ckpt_path: str = None, + gpt_ckpt_path: str = None, + embed_path: str = None, + decoder_ckpt_path: str = None, + tokenizer_path: str = None, device: Optional[torch.device] = None, compile: bool = False, coef: Optional[str] = None, @@ -250,8 +260,8 @@ def _load( ) .eval() ) - assert self.config.path.vocos_ckpt_path, "vocos_ckpt_path should not be None" - vocos.load_state_dict(torch.load(self.config.path.vocos_ckpt_path, weights_only=True, mmap=True)) + assert vocos_ckpt_path, "vocos_ckpt_path should not be None" + vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True)) self.vocos = vocos self.logger.log(logging.INFO, "vocos loaded.") @@ -267,8 +277,8 @@ def _load( .eval() ) coef = str(dvae) - assert self.config.path.dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_state_dict(torch.load(self.config.path.dvae_ckpt_path, weights_only=True, mmap=True)) + assert dvae_ckpt_path, "dvae_ckpt_path should not be None" + dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True)) self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") @@ -278,7 +288,7 @@ def _load( self.config.embed.num_text_tokens, self.config.embed.num_vq, ) - embed.from_pretrained(self.config.path.embed_path) + embed.from_pretrained(embed_path) self.embed = embed self.logger.log(logging.INFO, "embed loaded.") @@ -291,8 +301,8 @@ def _load( device_gpt=self.device_gpt, logger=self.logger, ).eval() - assert self.config.path.gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.from_pretrained(self.config.path.gpt_ckpt_path, self.config.path.embed_path, experimental=experimental) + assert gpt_ckpt_path, "gpt_ckpt_path should not be None" + gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental) gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt self.logger.log(logging.INFO, "gpt loaded.") @@ -312,15 +322,15 @@ def _load( .eval() ) coef = str(decoder) - assert self.config.path.decoder_ckpt_path, "decoder_ckpt_path should not be None" + assert decoder_ckpt_path, "decoder_ckpt_path should not be None" decoder.load_state_dict( - torch.load(self.config.path.decoder_ckpt_path, weights_only=True, mmap=True) + torch.load(decoder_ckpt_path, weights_only=True, mmap=True) ) self.decoder = decoder self.logger.log(logging.INFO, "decoder loaded.") - if self.config.path.tokenizer_path: - self.tokenizer = Tokenizer(self.config.path.tokenizer_path) + if tokenizer_path: + self.tokenizer = Tokenizer(tokenizer_path) self.logger.log(logging.INFO, "tokenizer loaded.") self.coef = coef