diff --git a/flair/__init__.py b/flair/__init__.py index 1e70de006..2e984a74d 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -12,7 +12,10 @@ # global variable: device if torch.cuda.is_available(): - device = torch.device("cuda:0") + device_id = os.environ.get("FLAIR_DEVICE") + + # No need for correctness checks, torch is doing it + device = torch.device(f"cuda:{device_id}") if device_id else torch.device("cuda:0") else: device = torch.device("cpu")