diff --git a/templates/titanic/.meta.yml b/templates/titanic/.meta.yml index dbe94d793..bee360391 100644 --- a/templates/titanic/.meta.yml +++ b/templates/titanic/.meta.yml @@ -1,14 +1,16 @@ title: Solving Titanic dataset with Lightning Flash author: PL team created: 2021-10-15 -updated: 2021-12-10 +updated: 2022-04-10 license: CC build: 0 description: | This is a template to show how to contribute a tutorial. requirements: - - https://github.com/PyTorchLightning/lightning-flash/archive/refs/tags/0.5.2.zip#egg=lightning-flash[tabular] - - matplotlib + - lightning-flash[tabular]>=0.7 + - torchmetrics<0.8 # collision with `pytorch-tabular` which requires PL 1.3.6 + - pandas>=1.0 + - matplotlib>=3.0 - seaborn accelerator: - CPU diff --git a/templates/titanic/tutorial.py b/templates/titanic/tutorial.py index a82976fda..4eb0d691d 100644 --- a/templates/titanic/tutorial.py +++ b/templates/titanic/tutorial.py @@ -39,7 +39,7 @@ target_fields="Survived", train_file=csv_train, val_split=0.1, - batch_size=8, + batch_size=32, ) # %% [markdown] @@ -49,7 +49,7 @@ model = TabularClassifier.from_data( datamodule, learning_rate=0.1, - optimizer="Adam", + optimizer="AdamW", n_a=8, gamma=0.3, ) @@ -81,6 +81,7 @@ sns.relplot(data=metrics, kind="line") plt.gca().set_ylim([0, 1.25]) plt.gcf().set_size_inches(10, 5) +plt.grid() # %% [markdown] # ## 4. 
Generate predictions from a CSV @@ -88,13 +89,21 @@ # %% df_test = pd.read_csv(csv_test) -predictions = model.predict(csv_test) -print(predictions[0]) +dm = TabularClassificationData.from_data_frame( + predict_data_frame=df_test, + parameters=datamodule.parameters, + batch_size=datamodule.batch_size, +) +preds = trainer.predict(model, datamodule=dm, output="classes") +print(preds[0][:10]) # %% +import itertools # noqa: E402 + import numpy as np # noqa: E402 -assert len(df_test) == len(predictions) +predictions = list(itertools.chain(*preds)) +# assert len(df_test) == len(predictions) df_test["Survived"] = np.argmax(predictions, axis=-1) df_test.set_index("PassengerId", inplace=True)