Skip to content

Commit

Permalink
feat(core): add experimental argument to load() (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph authored Aug 14, 2024
1 parent d93ed8d commit a27357f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def load(
coef: Optional[torch.Tensor] = None,
use_flash_attn=False,
use_vllm=False,
experimental: bool = False,
) -> bool:
download_path = self.download_models(source, force_redownload, custom_path)
if download_path is None:
Expand All @@ -137,6 +138,7 @@ def load(
coef=coef,
use_flash_attn=use_flash_attn,
use_vllm=use_vllm,
experimental=experimental,
**{
k: os.path.join(download_path, v)
for k, v in asdict(self.config.path).items()
Expand Down Expand Up @@ -233,9 +235,10 @@ def _load(
coef: Optional[str] = None,
use_flash_attn=False,
use_vllm=False,
experimental: bool = False,
):
if device is None:
device = select_device()
device = select_device(experimental=experimental)
self.logger.info("use device %s", str(device))
self.device = device
self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
Expand Down Expand Up @@ -287,7 +290,7 @@ def _load(
logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.from_pretrained(gpt_ckpt_path)
gpt.from_pretrained(gpt_ckpt_path, experimental=experimental)
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt

Expand Down

0 comments on commit a27357f

Please sign in to comment.