Skip to content

Commit

Permalink
fix: override fit in Classifier and Regressor (#1176)
Browse files Browse the repository at this point in the history
### Summary of Changes

Override `fit` in `Classifier` and `Regressor`, so metrics can be used
on the fitted models.
  • Loading branch information
lars-reimann authored May 17, 2024
1 parent 5186b2b commit 7d79314
Show file tree
Hide file tree
Showing 24 changed files with 174 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ Learn a transformation for a set of columns in a `Table` and transform another `
**Inheritors:**

- [`Discretizer`][safeds.data.tabular.transformation.Discretizer]
- `#!sds Imputer`
- [`InvertibleTableTransformer`][safeds.data.tabular.transformation.InvertibleTableTransformer]
- [`SimpleImputer`][safeds.data.tabular.transformation.SimpleImputer]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -141,7 +141,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -281,7 +281,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -340,7 +340,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -369,7 +369,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
34 changes: 23 additions & 11 deletions docs/api/safeds/ml/classical/classification/Classifier.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,27 @@ A model for classification tasks.
- [`GradientBoostingClassifier`][safeds.ml.classical.classification.GradientBoostingClassifier]
- [`KNearestNeighborsClassifier`][safeds.ml.classical.classification.KNearestNeighborsClassifier]
- [`LogisticClassifier`][safeds.ml.classical.classification.LogisticClassifier]
- `#!sds LogisticRegressionClassifier`
- [`RandomForestClassifier`][safeds.ml.classical.classification.RandomForestClassifier]
- [`SupportVectorClassifier`][safeds.ml.classical.classification.SupportVectorClassifier]
- `#!sds SupportVectorMachineClassifier`

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="10"
class Classifier sub SupervisedModel {
/**
* Create a copy of this model and fit it with the given training data.
*
* **Note:** This model is not modified.
*
* @param trainingSet The training data containing the features and target.
*
* @result fittedModel The fitted model.
*/
@Pure
fun fit(
@PythonName("training_set") trainingSet: TabularDataset
) -> fittedModel: Classifier

/**
* Summarize the classifier's metrics on the given data.
*
Expand Down Expand Up @@ -148,7 +160,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -179,7 +191,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand All @@ -204,15 +216,15 @@ Create a copy of this model and fit it with the given training data.

| Name | Type | Description |
|------|------|-------------|
| `fittedModel` | [`SupervisedModel`][safeds.ml.classical.SupervisedModel] | The fitted model. |
| `fittedModel` | [`Classifier`][safeds.ml.classical.classification.Classifier] | The fitted model. |

??? quote "Stub code in `SupervisedModel.sdsstub`"
??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="26"
```sds linenums="20"
@Pure
fun fit(
@PythonName("training_set") trainingSet: TabularDataset
) -> fittedModel: SupervisedModel
) -> fittedModel: Classifier
```

## `#!sds fun` getFeatureNames {#safeds.ml.classical.classification.Classifier.getFeatureNames data-toc-label='getFeatureNames'}
Expand Down Expand Up @@ -319,7 +331,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -378,7 +390,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -407,7 +419,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -128,7 +128,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -268,7 +268,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -327,7 +327,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -356,7 +356,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -129,7 +129,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -269,7 +269,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -328,7 +328,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -357,7 +357,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -116,7 +116,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -256,7 +256,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -315,7 +315,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -344,7 +344,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -95,7 +95,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -235,7 +235,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -294,7 +294,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -323,7 +323,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ better. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="40"
```sds linenums="54"
@Pure
fun accuracy(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>
Expand Down Expand Up @@ -141,7 +141,7 @@ classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="58"
```sds linenums="72"
@Pure
@PythonName("f1_score")
fun f1Score(
Expand Down Expand Up @@ -281,7 +281,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="78"
```sds linenums="92"
@Pure
fun precision(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -340,7 +340,7 @@ better the classifier. Results range from 0.0 to 1.0.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="97"
```sds linenums="111"
@Pure
fun recall(
@PythonName("validation_or_test_set") validationOrTestSet: union<Table, TabularDataset>,
Expand Down Expand Up @@ -369,7 +369,7 @@ Summarize the classifier's metrics on the given data.

??? quote "Stub code in `Classifier.sdsstub`"

```sds linenums="21"
```sds linenums="35"
@Pure
@PythonName("summarize_metrics")
fun summarizeMetrics(
Expand Down
Loading

0 comments on commit 7d79314

Please sign in to comment.