Skip to content

Commit

Permalink
docs(samples): add feature importance to predict sample (#277)
Browse files Browse the repository at this point in the history
* Add feature importance to predict sample

* Fix license header

* fix: skip tensorflow linkinator - flaky

* Add bens map suggestion

* Fix lint and errors

Co-authored-by: Benjamin E. Coe <bencoe@google.com>
  • Loading branch information
2 people authored and Ace Nassri committed Nov 14, 2022
1 parent 2bec18b commit 4064a41
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
24 changes: 23 additions & 1 deletion automl/tables/predict.v1beta1.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,36 @@ async function main(
// Params is additional domain-specific parameters.
// Currently there is no additional parameters supported.
client
.predict({name: modelFullId, payload: payload, params: {}})
.predict({
name: modelFullId,
payload: payload,
params: {feature_importance: true},
})
.then(responses => {
console.log(responses);
console.log(`Prediction results:`);

for (const result of responses[0].payload) {
console.log(`Predicted class name: ${result.displayName}`);
console.log(`Predicted class score: ${result.tables.score}`);

// Get features of top importance
const featureList = result.tables.tablesModelColumnInfo.map(
columnInfo => {
return {
importance: columnInfo.featureImportance,
displayName: columnInfo.columnDisplayName,
};
}
);
// Sort features by their importance, highest importance first
featureList.sort(function(a, b) {
return b.importance - a.importance;
});

// Print top 10 important features
console.log('Features of top importance');
console.log(featureList.slice(0, 10));
}
})
.catch(err => {
Expand Down
28 changes: 15 additions & 13 deletions automl/test/automlTablesPredict.v1beta1.test.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* Copyright 2019 Google LLC
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

Expand Down Expand Up @@ -42,6 +43,7 @@ describe('Tables PredictionAPI', () => {
// Run single prediction on predictTest.csv in resource folder
const output = exec(`${cmdPredict} predict "${modelId}" "${filePath}"`);
assert.match(output, /Prediction results:/);
assert.match(output, /Features of top importance:/);
});

it.skip(`should perform batch prediction using GCS as source and
Expand Down

0 comments on commit 4064a41

Please sign in to comment.