Skip to content

Commit

Permalink
fix(core): relative path
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Aug 25, 2024
1 parent 8a503fd commit a79d297
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def load(
use_flash_attn=use_flash_attn,
use_vllm=use_vllm,
experimental=experimental,
**{
k: os.path.join(download_path, v)
for k, v in asdict(self.config.path).items()
},
)

def unload(self):
Expand Down Expand Up @@ -221,6 +225,12 @@ def interrupt(self):
@torch.no_grad()
def _load(
self,
vocos_ckpt_path: str = None,
dvae_ckpt_path: str = None,
gpt_ckpt_path: str = None,
embed_path: str = None,
decoder_ckpt_path: str = None,
tokenizer_path: str = None,
device: Optional[torch.device] = None,
compile: bool = False,
coef: Optional[str] = None,
Expand Down Expand Up @@ -250,8 +260,8 @@ def _load(
)
.eval()
)
assert self.config.path.vocos_ckpt_path, "vocos_ckpt_path should not be None"
vocos.load_state_dict(torch.load(self.config.path.vocos_ckpt_path, weights_only=True, mmap=True))
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
self.vocos = vocos
self.logger.log(logging.INFO, "vocos loaded.")

Expand All @@ -267,8 +277,8 @@ def _load(
.eval()
)
coef = str(dvae)
assert self.config.path.dvae_ckpt_path, "dvae_ckpt_path should not be None"
dvae.load_state_dict(torch.load(self.config.path.dvae_ckpt_path, weights_only=True, mmap=True))
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
self.dvae = dvae
self.logger.log(logging.INFO, "dvae loaded.")

Expand All @@ -278,7 +288,7 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(self.config.path.embed_path)
embed.from_pretrained(embed_path)
self.embed = embed
self.logger.log(logging.INFO, "embed loaded.")

Expand All @@ -291,8 +301,8 @@ def _load(
device_gpt=self.device_gpt,
logger=self.logger,
).eval()
assert self.config.path.gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.from_pretrained(self.config.path.gpt_ckpt_path, self.config.path.embed_path, experimental=experimental)
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt
self.logger.log(logging.INFO, "gpt loaded.")
Expand All @@ -312,15 +322,15 @@ def _load(
.eval()
)
coef = str(decoder)
assert self.config.path.decoder_ckpt_path, "decoder_ckpt_path should not be None"
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
decoder.load_state_dict(
torch.load(self.config.path.decoder_ckpt_path, weights_only=True, mmap=True)
torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
)
self.decoder = decoder
self.logger.log(logging.INFO, "decoder loaded.")

if self.config.path.tokenizer_path:
self.tokenizer = Tokenizer(self.config.path.tokenizer_path)
if tokenizer_path:
self.tokenizer = Tokenizer(tokenizer_path)
self.logger.log(logging.INFO, "tokenizer loaded.")

self.coef = coef
Expand Down

0 comments on commit a79d297

Please sign in to comment.