diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts index 78af0a862d302d..93f8152efa19bd 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts @@ -8,6 +8,7 @@ import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types'; import { BUILT_IN_MODEL_TAG, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils'; +import { MlModel, MlModelDeploymentState } from '../types/ml'; import { MlInferencePipeline, TrainedModelState } from '../types/pipelines'; import { @@ -19,7 +20,7 @@ import { parseModelStateReasonFromStats, } from '.'; -const mockModel: MlTrainedModelConfig = { +const mockTrainedModel: MlTrainedModelConfig = { inference_config: { ner: {}, }, @@ -32,8 +33,27 @@ const mockModel: MlTrainedModelConfig = { version: '1', }; +const mockModel: MlModel = { + modelId: 'model_1', + type: 'ner', + title: 'Model 1', + description: 'Model 1 description', + licenseType: 'elastic', + modelDetailsPageUrl: 'https://my-model.ai', + deploymentState: MlModelDeploymentState.NotDeployed, + startTime: 0, + targetAllocationCount: 0, + nodeAllocationCount: 0, + threadsPerAllocation: 0, + isPlaceholder: false, + hasStats: false, + types: ['pytorch', 'ner'], + inputFieldNames: ['title'], + version: '1', +}; + describe('getMlModelTypesForModelConfig lib function', () => { - const builtInMockModel: MlTrainedModelConfig = { + const builtInMockTrainedModel: MlTrainedModelConfig = { inference_config: { text_classification: {}, }, @@ -47,13 +67,13 @@ describe('getMlModelTypesForModelConfig lib function', () => { it('should return the model type and inference config type', () => { const expected = ['pytorch', 'ner']; - const response = getMlModelTypesForModelConfig(mockModel); + const response = getMlModelTypesForModelConfig(mockTrainedModel); expect(response.sort()).toEqual(expected.sort()); }); it('should include the built in type', () => { const expected = ['lang_ident', 'text_classification', BUILT_IN_MODEL_TAG]; - const response = getMlModelTypesForModelConfig(builtInMockModel); + const response = getMlModelTypesForModelConfig(builtInMockTrainedModel); expect(response.sort()).toEqual(expected.sort()); }); }); @@ -71,9 +91,9 @@ describe('generateMlInferencePipelineBody lib function', () => { { inference: { field_map: { - 'my-source-field': 'MODEL_INPUT_FIELD', + 'my-source-field': 'title', }, - model_id: 'test_id', + model_id: 'model_1', on_failure: [ { append: { @@ -154,21 +174,21 @@ describe('generateMlInferencePipelineBody lib function', () => { { inference: expect.objectContaining({ field_map: { - 'my-source-field1': 'MODEL_INPUT_FIELD', + 'my-source-field1': 'title', }, }), }, { inference: expect.objectContaining({ field_map: { - 'my-source-field2': 'MODEL_INPUT_FIELD', + 'my-source-field2': 'title', }, }), }, { inference: expect.objectContaining({ field_map: { - 'my-source-field3': 'MODEL_INPUT_FIELD', + 'my-source-field3': 'title', }, }), }, diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts index 5f56c1105b297c..fa16dd29f83b1b 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts @@ -18,6 +18,8 @@ import { BUILT_IN_MODEL_TAG, } from '@kbn/ml-trained-models-utils'; +import { MlModel } from '../types/ml'; + import { MlInferencePipeline, CreateMLInferencePipeline, @@ -33,7 +35,7 @@ export interface MlInferencePipelineParams { description?: string; fieldMappings: FieldMapping[]; inferenceConfig?: InferencePipelineInferenceConfig; - model: MlTrainedModelConfig; + model: MlModel; pipelineName: string; } @@ -90,7 +92,7 @@ export const generateMlInferencePipelineBody = ({ model_version: model.version, pipeline: pipelineName, processed_timestamp: '{{{ _ingest.timestamp }}}', - types: getMlModelTypesForModelConfig(model), + types: model.types, }, ], }, @@ -104,19 +106,19 @@ export const getInferenceProcessor = ( sourceField: string, targetField: string, inferenceConfig: InferencePipelineInferenceConfig | undefined, - model: MlTrainedModelConfig, + model: MlModel, pipelineName: string ): IngestInferenceProcessor => { // If model returned no input field, insert a placeholder const modelInputField = - model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD'; + model.inputFieldNames.length > 0 ? model.inputFieldNames[0] : 'MODEL_INPUT_FIELD'; return { field_map: { [sourceField]: modelInputField, }, inference_config: inferenceConfig, - model_id: model.model_id, + model_id: model.modelId, on_failure: [ { append: { diff --git a/x-pack/plugins/enterprise_search/common/types/ml.ts b/x-pack/plugins/enterprise_search/common/types/ml.ts index 894ffa6f0726bb..2f40475535107d 100644 --- a/x-pack/plugins/enterprise_search/common/types/ml.ts +++ b/x-pack/plugins/enterprise_search/common/types/ml.ts @@ -27,6 +27,10 @@ export interface MlModel { modelId: string; /** Model inference type, e.g. ner, text_classification */ type: string; + /** Type-related tags: model type (e.g. pytorch), inference type, built-in tag */ + types: string[]; + /** Field names in inference input configuration */ + inputFieldNames: string[]; title: string; description?: string; licenseType?: string; @@ -44,4 +48,5 @@ export interface MlModel { isPlaceholder: boolean; /** Does this model have deployment stats? */ hasStats: boolean; + version?: string; } diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.test.ts index 6d66ed5704721a..869bd9273ac09f 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.test.ts @@ -30,8 +30,11 @@ const DEFAULT_VALUES: CachedFetchModelsApiLogicValues = { const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [ { modelId: 'model_1', - title: 'Model 1', type: 'ner', + title: 'Model 1', + description: 'Model 1 description', + licenseType: 'elastic', + modelDetailsPageUrl: 'https://my-model.ai', deploymentState: MlModelDeploymentState.NotDeployed, startTime: 0, targetAllocationCount: 0, @@ -39,6 +42,9 @@ const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [ threadsPerAllocation: 0, isPlaceholder: false, hasStats: false, + types: ['pytorch', 'ner'], + inputFieldNames: ['title'], + version: '1', }, ]; const FETCH_MODELS_API_ERROR_RESPONSE = { diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.ts index d65af6ec2fcf41..a26dbada96c08c 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.ts @@ -18,6 +18,8 @@ import { FetchModelsApiLogic, FetchModelsApiResponse } from './fetch_models_api_ const FETCH_MODELS_POLLING_DURATION = 5000; // 5 seconds const FETCH_MODELS_POLLING_DURATION_ON_FAILURE = 30000; // 30 seconds +export type { FetchModelsApiResponse } from './fetch_models_api_logic'; + export interface CachedFetchModlesApiLogicActions { apiError: Actions<{}, FetchModelsApiResponse>['apiError']; apiReset: Actions<{}, FetchModelsApiResponse>['apiReset']; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.test.ts deleted file mode 100644 index 4bcea4ac4e83ae..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.test.ts +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { mockHttpValues } from '../../../__mocks__/kea_logic'; -import { mlModelStats } from '../../__mocks__/ml_models.mock'; - -import { getMLModelsStats } from './ml_model_stats_logic'; - -describe('MLModelsApiLogic', () => { - const { http } = mockHttpValues; - beforeEach(() => { - jest.clearAllMocks(); - }); - describe('getMLModelsStats', () => { - it('calls the ml api', async () => { - http.get.mockResolvedValue(mlModelStats); - const result = await getMLModelsStats(); - expect(http.get).toHaveBeenCalledWith('/internal/ml/trained_models/_stats', { version: '1' }); - expect(result).toEqual(mlModelStats); - }); - }); -}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.ts deleted file mode 100644 index d8bc341fcb6c31..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_model_stats_logic.ts +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types'; - -import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic'; -import { HttpLogic } from '../../../shared/http'; - -export type GetMlModelsStatsArgs = undefined; - -export interface GetMlModelsStatsResponse { - count: number; - trained_model_stats: MlTrainedModelStats[]; -} - -export const getMLModelsStats = async () => { - return await HttpLogic.values.http.get( - '/internal/ml/trained_models/_stats', - { version: '1' } - ); -}; - -export const MLModelsStatsApiLogic = createApiLogic( - ['ml_models_stats_api_logic'], - getMLModelsStats, - { - clearFlashMessagesOnMakeRequest: false, - showErrorFlash: false, - } -); - -export type MLModelsStatsApiLogicActions = Actions; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.test.ts deleted file mode 100644 index cffdc08dfd2efd..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.test.ts +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { mockHttpValues } from '../../../__mocks__/kea_logic'; -import { mlModels } from '../../__mocks__/ml_models.mock'; - -import { getMLModels } from './ml_models_logic'; - -describe('MLModelsApiLogic', () => { - const { http } = mockHttpValues; - beforeEach(() => { - jest.clearAllMocks(); - }); - describe('getMLModels', () => { - it('calls the ml api', async () => { - http.get.mockResolvedValue(mlModels); - const result = await getMLModels(); - expect(http.get).toHaveBeenCalledWith('/internal/ml/trained_models', { - query: { size: 1000, with_pipelines: true }, - version: '1', - }); - expect(result).toEqual(mlModels); - }); - }); -}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.ts deleted file mode 100644 index 9cec020b8d7826..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_models_logic.ts +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; - -import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic'; -import { HttpLogic } from '../../../shared/http'; - -export type GetMlModelsArgs = number | undefined; - -export type GetMlModelsResponse = TrainedModelConfigResponse[]; - -export const getMLModels = async (size: GetMlModelsArgs = 1000) => { - return await HttpLogic.values.http.get( - '/internal/ml/trained_models', - { - query: { size, with_pipelines: true }, - version: '1', - } - ); -}; - -export const MLModelsApiLogic = createApiLogic(['ml_models_api_logic'], getMLModels, { - clearFlashMessagesOnMakeRequest: false, - showErrorFlash: false, -}); - -export type MLModelsApiLogicActions = Actions; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.test.ts deleted file mode 100644 index 6a1a4e0e512dfa..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.test.ts +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { LogicMounter } from '../../../__mocks__/kea_logic'; -import { mlModels, mlModelStats } from '../../__mocks__/ml_models.mock'; - -import { HttpError, Status } from '../../../../../common/types/api'; - -import { MLModelsStatsApiLogic } from './ml_model_stats_logic'; -import { MLModelsApiLogic } from './ml_models_logic'; -import { TrainedModelsApiLogic, TrainedModelsApiLogicValues } from './ml_trained_models_logic'; - -const DEFAULT_VALUES: TrainedModelsApiLogicValues = { - error: null, - status: Status.IDLE, - data: null, - // models - modelsApiStatus: { - status: Status.IDLE, - }, - modelsData: undefined, - modelsApiError: undefined, - modelsStatus: Status.IDLE, - // stats - modelStatsApiStatus: { - status: Status.IDLE, - }, - modelStatsData: undefined, - modelsStatsApiError: undefined, - modelStatsStatus: Status.IDLE, -}; - -describe('TrainedModelsApiLogic', () => { - const { mount } = new LogicMounter(TrainedModelsApiLogic); - const { mount: mountMLModelsApiLogic } = new LogicMounter(MLModelsApiLogic); - const { mount: mountMLModelsStatsApiLogic } = new LogicMounter(MLModelsStatsApiLogic); - - beforeEach(() => { - jest.clearAllMocks(); - - mountMLModelsApiLogic(); - mountMLModelsStatsApiLogic(); - mount(); - }); - - it('has default values', () => { - expect(TrainedModelsApiLogic.values).toEqual(DEFAULT_VALUES); - }); - describe('selectors', () => { - describe('data', () => { - it('returns combined trained models', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - MLModelsStatsApiLogic.actions.apiSuccess(mlModelStats); - - expect(TrainedModelsApiLogic.values.data).toEqual([ - { - ...mlModels[0], - ...mlModelStats.trained_model_stats[0], - }, - { - ...mlModels[1], - ...mlModelStats.trained_model_stats[1], - }, - { - ...mlModels[2], - ...mlModelStats.trained_model_stats[2], - }, - ]); - }); - it('returns just models if stats not available', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - - expect(TrainedModelsApiLogic.values.data).toEqual(mlModels); - }); - it('returns null trained models even with stats if models missing', () => { - MLModelsStatsApiLogic.actions.apiSuccess(mlModelStats); - - expect(TrainedModelsApiLogic.values.data).toEqual(null); - }); - }); - describe('error', () => { - const modelError: HttpError = { - body: { - error: 'Model Error', - statusCode: 400, - }, - fetchOptions: {}, - request: {}, - } as HttpError; - const statsError: HttpError = { - body: { - error: 'Stats Error', - statusCode: 500, - }, - fetchOptions: {}, - request: {}, - } as HttpError; - - it('returns null with no errors', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - MLModelsStatsApiLogic.actions.apiSuccess(mlModelStats); - - expect(TrainedModelsApiLogic.values.error).toBeNull(); - }); - it('models error', () => { - MLModelsApiLogic.actions.apiError(modelError); - - expect(TrainedModelsApiLogic.values.error).toBe(modelError); - }); - it('stats error', () => { - MLModelsStatsApiLogic.actions.apiError(statsError); - - expect(TrainedModelsApiLogic.values.error).toBe(statsError); - }); - it('prefers models error if both api calls fail', () => { - MLModelsApiLogic.actions.apiError(modelError); - MLModelsStatsApiLogic.actions.apiError(statsError); - - expect(TrainedModelsApiLogic.values.error).toBe(modelError); - }); - }); - describe('status', () => { - it('returns matching status for both calls', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - MLModelsStatsApiLogic.actions.apiSuccess(mlModelStats); - - expect(TrainedModelsApiLogic.values.status).toEqual(Status.SUCCESS); - }); - it('returns models status when its lower', () => { - MLModelsStatsApiLogic.actions.apiSuccess(mlModelStats); - - expect(TrainedModelsApiLogic.values.status).toEqual(Status.IDLE); - }); - it('returns stats status when its lower', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - - expect(TrainedModelsApiLogic.values.status).toEqual(Status.IDLE); - }); - it('returns error status if one api call fails', () => { - MLModelsApiLogic.actions.apiSuccess(mlModels); - MLModelsStatsApiLogic.actions.apiError({ - body: { - error: 'Stats Error', - statusCode: 500, - }, - fetchOptions: {}, - request: {}, - } as HttpError); - - expect(TrainedModelsApiLogic.values.status).toEqual(Status.ERROR); - }); - }); - }); - describe('actions', () => { - it('makeRequest fetches models and stats', () => { - jest.spyOn(TrainedModelsApiLogic.actions, 'makeGetModelsRequest'); - jest.spyOn(TrainedModelsApiLogic.actions, 'makeGetModelsStatsRequest'); - - TrainedModelsApiLogic.actions.makeRequest(undefined); - - expect(TrainedModelsApiLogic.actions.makeGetModelsRequest).toHaveBeenCalledTimes(1); - expect(TrainedModelsApiLogic.actions.makeGetModelsStatsRequest).toHaveBeenCalledTimes(1); - }); - it('apiReset resets both api logics', () => { - jest.spyOn(TrainedModelsApiLogic.actions, 'getModelsApiReset'); - jest.spyOn(TrainedModelsApiLogic.actions, 'getModelsStatsApiReset'); - - TrainedModelsApiLogic.actions.apiReset(); - - expect(TrainedModelsApiLogic.actions.getModelsApiReset).toHaveBeenCalledTimes(1); - expect(TrainedModelsApiLogic.actions.getModelsStatsApiReset).toHaveBeenCalledTimes(1); - }); - }); -}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.ts deleted file mode 100644 index d36a80df6af6ab..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/ml_trained_models_logic.ts +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { kea, MakeLogicType } from 'kea'; - -import { MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types'; -import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; - -import { ApiStatus, Status, HttpError } from '../../../../../common/types/api'; -import { Actions } from '../../../shared/api_logic/create_api_logic'; - -import { - GetMlModelsStatsResponse, - MLModelsStatsApiLogic, - MLModelsStatsApiLogicActions, -} from './ml_model_stats_logic'; -import { GetMlModelsResponse, MLModelsApiLogic, MLModelsApiLogicActions } from './ml_models_logic'; - -export type TrainedModel = TrainedModelConfigResponse & Partial; - -export type TrainedModelsApiLogicActions = Actions & { - getModelsApiError: MLModelsApiLogicActions['apiError']; - getModelsApiReset: MLModelsApiLogicActions['apiReset']; - getModelsApiSuccess: MLModelsApiLogicActions['apiSuccess']; - getModelsStatsApiError: MLModelsStatsApiLogicActions['apiError']; - getModelsStatsApiReset: MLModelsStatsApiLogicActions['apiReset']; - getModelsStatsApiSuccess: MLModelsStatsApiLogicActions['apiSuccess']; - makeGetModelsRequest: MLModelsApiLogicActions['makeRequest']; - makeGetModelsStatsRequest: MLModelsStatsApiLogicActions['makeRequest']; -}; -export interface TrainedModelsApiLogicValues { - error: HttpError | null; - status: Status; - data: TrainedModel[] | null; - // models - modelsApiStatus: ApiStatus; - modelsData: GetMlModelsResponse | undefined; - modelsApiError?: HttpError; - modelsStatus: Status; - // stats - modelStatsApiStatus: ApiStatus; - modelStatsData: GetMlModelsStatsResponse | undefined; - modelsStatsApiError?: HttpError; - modelStatsStatus: Status; -} - -export const TrainedModelsApiLogic = kea< - MakeLogicType ->({ - actions: { - apiError: (error) => error, - apiReset: true, - apiSuccess: (result) => result, - makeRequest: () => undefined, - }, - connect: { - actions: [ - MLModelsApiLogic, - [ - 'apiError as getModelsApiError', - 'apiReset as getModelsApiReset', - 'apiSuccess as getModelsApiSuccess', - 'makeRequest as makeGetModelsRequest', - ], - MLModelsStatsApiLogic, - [ - 'apiError as getModelsStatsApiError', - 'apiReset as getModelsStatsApiReset', - 'apiSuccess as getModelsStatsApiSuccess', - 'makeRequest as makeGetModelsStatsRequest', - ], - ], - values: [ - MLModelsApiLogic, - [ - 'apiStatus as modelsApiStatus', - 'error as modelsApiError', - 'status as modelsStatus', - 'data as modelsData', - ], - MLModelsStatsApiLogic, - [ - 'apiStatus as modelStatsApiStatus', - 'error as modelsStatsApiError', - 'status as modelStatsStatus', - 'data as modelStatsData', - ], - ], - }, - listeners: ({ actions, values }) => ({ - getModelsApiError: (error) => { - actions.apiError(error); - }, - getModelsApiSuccess: () => { - if (!values.data) return; - actions.apiSuccess(values.data); - }, - getModelsStatsApiError: (error) => { - if (values.modelsApiError) return; - actions.apiError(error); - }, - getModelsStatsApiSuccess: () => { - if (!values.data) return; - actions.apiSuccess(values.data); - }, - apiReset: () => { - actions.getModelsApiReset(); - actions.getModelsStatsApiReset(); - }, - makeRequest: () => { - actions.makeGetModelsRequest(undefined); - actions.makeGetModelsStatsRequest(undefined); - }, - }), - path: ['enterprise_search', 'api', 'ml_trained_models_api_logic'], - selectors: ({ selectors }) => ({ - data: [ - () => [selectors.modelsData, selectors.modelStatsData], - ( - modelsData: TrainedModelsApiLogicValues['modelsData'], - modelStatsData: TrainedModelsApiLogicValues['modelStatsData'] - ): TrainedModel[] | null => { - if (!modelsData) return null; - if (!modelStatsData) return modelsData; - const statsMap: Record = - modelStatsData.trained_model_stats.reduce((map, value) => { - if (value.model_id) { - map[value.model_id] = value; - } - return map; - }, {} as Record); - return modelsData.map((modelConfig) => { - const modelStats = statsMap[modelConfig.model_id]; - return { - ...modelConfig, - ...(modelStats ?? {}), - }; - }); - }, - ], - error: [ - () => [selectors.modelsApiStatus, selectors.modelStatsApiStatus], - ( - modelsApiStatus: TrainedModelsApiLogicValues['modelsApiStatus'], - modelStatsApiStatus: TrainedModelsApiLogicValues['modelStatsApiStatus'] - ) => { - if (modelsApiStatus.error) return modelsApiStatus.error; - if (modelStatsApiStatus.error) return modelStatsApiStatus.error; - return null; - }, - ], - status: [ - () => [selectors.modelsApiStatus, selectors.modelStatsApiStatus], - ( - modelsApiStatus: TrainedModelsApiLogicValues['modelsApiStatus'], - modelStatsApiStatus: TrainedModelsApiLogicValues['modelStatsApiStatus'] - ) => { - if (modelsApiStatus.status === modelStatsApiStatus.status) return modelsApiStatus.status; - if (modelsApiStatus.status === Status.ERROR || modelStatsApiStatus.status === Status.ERROR) - return Status.ERROR; - if (modelsApiStatus.status < modelStatsApiStatus.status) return modelsApiStatus.status; - return modelStatsApiStatus.status; - }, - ], - }), -}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.test.tsx index 68bf7fc48e7dd8..09789e34c963bd 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.test.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.test.tsx @@ -29,7 +29,6 @@ import { import { ConfigureFields } from './configure_fields'; import { ConfigurePipeline } from './configure_pipeline'; import { EMPTY_PIPELINE_CONFIGURATION } from './ml_inference_logic'; -import { NoModelsPanel } from './no_models'; import { ReviewPipeline } from './review_pipeline'; import { TestPipeline } from './test_pipeline'; import { AddInferencePipelineSteps } from './types'; @@ -82,11 +81,6 @@ describe('AddInferencePipelineFlyout', () => { const wrapper = shallow(); expect(wrapper.find(EuiLoadingSpinner)).toHaveLength(1); }); - it('renders no models panel when there are no models', () => { - setMockValues({ ...DEFAULT_VALUES, supportedMLModels: [] }); - const wrapper = shallow(); - expect(wrapper.find(NoModelsPanel)).toHaveLength(1); - }); it('renders AddInferencePipelineHorizontalSteps', () => { const wrapper = shallow(); expect(wrapper.find(AddInferencePipelineHorizontalSteps)).toHaveLength(1); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.tsx index 654cdafad35eb8..57aa7ab4674887 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_inference_pipeline_flyout.tsx @@ -41,7 +41,6 @@ import { IndexViewLogic } from '../../index_view_logic'; import { ConfigureFields } from './configure_fields'; import { ConfigurePipeline } from './configure_pipeline'; import { MLInferenceLogic } from './ml_inference_logic'; -import { NoModelsPanel } from './no_models'; import { ReviewPipeline } from './review_pipeline'; import { TestPipeline } from './test_pipeline'; import { AddInferencePipelineSteps } from './types'; @@ -54,9 +53,15 @@ export interface AddInferencePipelineFlyoutProps { export const AddInferencePipelineFlyout = (props: AddInferencePipelineFlyoutProps) => { const { indexName } = useValues(IndexNameLogic); - const { setIndexName } = useActions(MLInferenceLogic); + const { setIndexName, makeMlInferencePipelinesRequest, startPollingModels, makeMappingRequest } = + useActions(MLInferenceLogic); useEffect(() => { setIndexName(indexName); + + // Trigger fetching of initial data: existing ML pipelines, available models, index mapping + makeMlInferencePipelinesRequest(undefined); + startPollingModels(); + makeMappingRequest({ indexName }); }, [indexName]); return ( @@ -82,7 +87,6 @@ export const AddInferencePipelineContent = ({ onClose }: AddInferencePipelineFly const { ingestionMethod } = useValues(IndexViewLogic); const { createErrors, - supportedMLModels, isLoading, addInferencePipelineModal: { step }, } = useValues(MLInferenceLogic); @@ -103,9 +107,6 @@ export const AddInferencePipelineContent = ({ onClose }: AddInferencePipelineFly ); } - if (supportedMLModels.length === 0) { - return ; - } return ( <> diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx index 7b3fdf4353d99b..244f7a3e8a200c 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/inference_config.tsx @@ -14,7 +14,6 @@ import { i18n } from '@kbn/i18n'; import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils'; -import { getMlModelTypesForModelConfig } from '../../../../../../../common/ml_inference_pipeline'; import { getMLType } from '../../../shared/ml_inference/utils'; import { MLInferenceLogic } from './ml_inference_logic'; @@ -23,10 +22,10 @@ import { ZeroShotClassificationInferenceConfiguration } from './zero_shot_infere export const InferenceConfiguration: React.FC = () => { const { addInferencePipelineModal: { configuration }, - selectedMLModel, + selectedModel, } = useValues(MLInferenceLogic); - if (!selectedMLModel || configuration.existingPipeline) return null; - const modelType = getMLType(getMlModelTypesForModelConfig(selectedMLModel)); + if (!selectedModel || configuration.existingPipeline) return null; + const modelType = getMLType(selectedModel.types); switch (modelType) { case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION: return ( diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts index e12366f42f3ef0..ae3cc237c67a5b 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts @@ -6,16 +6,16 @@ */ import { LogicMounter } from '../../../../../__mocks__/kea_logic'; -import { nerModel, textExpansionModel } from '../../../../__mocks__/ml_models.mock'; import { HttpResponse } from '@kbn/core/public'; -import { ErrorResponse } from '../../../../../../../common/types/api'; +import { ErrorResponse, Status } from '../../../../../../../common/types/api'; +import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml'; import { TrainedModelState } from '../../../../../../../common/types/pipelines'; import { GetDocumentsApiLogic } from '../../../../api/documents/get_document_logic'; import { MappingsApiLogic } from '../../../../api/mappings/mappings_logic'; -import { MLModelsApiLogic } from '../../../../api/ml_models/ml_models_logic'; +import { CachedFetchModelsApiLogic } from '../../../../api/ml_models/cached_fetch_models_api_logic'; import { StartTextExpansionModelApiLogic } from '../../../../api/ml_models/text_expansion/start_text_expansion_model_api_logic'; import { AttachMlInferencePipelineApiLogic } from '../../../../api/pipelines/attach_ml_inference_pipeline'; import { CreateMlInferencePipelineApiLogic } from '../../../../api/pipelines/create_ml_inference_pipeline'; @@ -50,6 +50,7 @@ const DEFAULT_VALUES: MLInferenceProcessorsValues = { index: null, isConfigureStepValid: false, isLoading: true, + isModelsInitialLoading: false, isPipelineDataValid: false, isTextExpansionModelSelected: false, mappingData: undefined, @@ -57,17 +58,38 @@ const DEFAULT_VALUES: MLInferenceProcessorsValues = { mlInferencePipeline: undefined, mlInferencePipelineProcessors: undefined, mlInferencePipelinesData: undefined, - mlModelsData: null, - mlModelsStatus: 0, - selectedMLModel: null, + modelsData: undefined, + modelsStatus: 0, + selectableModels: [], + selectedModel: undefined, sourceFields: undefined, - supportedMLModels: [], }; +const MODELS: MlModel[] = [ + { + modelId: 'model_1', + type: 'ner', + title: 'Model 1', + description: 'Model 1 description', + licenseType: 'elastic', + modelDetailsPageUrl: 'https://my-model.ai', + deploymentState: MlModelDeploymentState.NotDeployed, + startTime: 0, + targetAllocationCount: 0, + nodeAllocationCount: 0, + threadsPerAllocation: 0, + isPlaceholder: false, + hasStats: false, + types: ['pytorch', 'ner'], + inputFieldNames: ['title'], + version: '1', + }, +]; + describe('MlInferenceLogic', () => { const { mount } = new LogicMounter(MLInferenceLogic); const { mount: mountMappingApiLogic } = new LogicMounter(MappingsApiLogic); - const { mount: mountMLModelsApiLogic } = new LogicMounter(MLModelsApiLogic); + const { mount: mountCachedFetchModelsApiLogic } = new LogicMounter(CachedFetchModelsApiLogic); const { mount: mountSimulateExistingMlInterfacePipelineApiLogic } = new LogicMounter( SimulateExistingMlInterfacePipelineApiLogic ); @@ -92,7 +114,7 @@ describe('MlInferenceLogic', () => { beforeEach(() => { jest.clearAllMocks(); mountMappingApiLogic(); - mountMLModelsApiLogic(); + mountCachedFetchModelsApiLogic(); mountFetchMlInferencePipelineProcessorsApiLogic(); mountFetchMlInferencePipelinesApiLogic(); mountSimulateExistingMlInterfacePipelineApiLogic(); @@ -105,7 +127,13 @@ describe('MlInferenceLogic', () => { }); it('has expected default values', () => { - expect(MLInferenceLogic.values).toEqual(DEFAULT_VALUES); + CachedFetchModelsApiLogic.actions.apiSuccess(MODELS); + expect(MLInferenceLogic.values).toEqual({ + ...DEFAULT_VALUES, + modelsData: MODELS, // Populated by afterMount hook + modelsStatus: Status.SUCCESS, + selectableModels: MODELS, + }); }); describe('actions', () => { @@ -184,6 +212,7 @@ describe('MlInferenceLogic', () => { describe('selectors', () => { describe('existingInferencePipelines', () => { beforeEach(() => { + CachedFetchModelsApiLogic.actions.apiSuccess(MODELS); MappingsApiLogic.actions.apiSuccess({ mappings: { properties: { @@ -206,7 +235,7 @@ describe('MlInferenceLogic', () => { field_map: { body: 'text_field', }, - model_id: 'test-model', + model_id: MODELS[0].modelId, target_field: 'ml.inference.test-field', }, }, @@ -218,8 +247,8 @@ describe('MlInferenceLogic', () => { expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([ { disabled: false, - modelId: 'test-model', - modelType: '', + modelId: MODELS[0].modelId, + modelType: 'ner', pipelineName: 'unit-test', sourceFields: ['body'], indexFields: ['body'], @@ -235,7 +264,7 @@ describe('MlInferenceLogic', () => { field_map: { title: 'text_field', // Does not exist in index }, - model_id: 'test-model', + model_id: MODELS[0].modelId, target_field: 'ml.inference.title', }, }, @@ -244,7 +273,7 @@ describe('MlInferenceLogic', () => { field_map: { body: 'text_field', // Exists in index }, - model_id: 'test-model', + model_id: MODELS[0].modelId, target_field: 'ml.inference.body', }, }, @@ -253,7 +282,7 @@ describe('MlInferenceLogic', () => { field_map: { body_content: 'text_field', // Does not exist in index }, - model_id: 'test-model', + model_id: MODELS[0].modelId, target_field: 'ml.inference.body_content', }, }, @@ -266,8 +295,8 @@ describe('MlInferenceLogic', () => { { disabled: true, disabledReason: expect.stringContaining('title, body_content'), - modelId: 'test-model', - modelType: '', + modelId: MODELS[0].modelId, + modelType: 'ner', pipelineName: 'unit-test', sourceFields: ['title', 'body', 'body_content'], indexFields: ['body'], @@ -306,7 +335,7 @@ describe('MlInferenceLogic', () => { it('filters pipeline if pipeline already attached', () => { FetchMlInferencePipelineProcessorsApiLogic.actions.apiSuccess([ { - modelId: 'test-model', + modelId: MODELS[0].modelId, modelState: TrainedModelState.Started, pipelineName: 'unit-test', pipelineReferences: ['test@ml-inference'], @@ -321,7 +350,7 @@ describe('MlInferenceLogic', () => { field_map: { body: 'text_field', }, - model_id: 'test-model', + model_id: MODELS[0].modelId, target_field: 'ml.inference.test-field', }, }, @@ -333,165 +362,6 @@ describe('MlInferenceLogic', () => { expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([]); }); }); - describe('mlInferencePipeline', () => { - it('returns undefined when configuration is invalid', () => { - MLInferenceLogic.actions.setInferencePipelineConfiguration({ - modelID: '', - pipelineName: '', // Invalid - fieldMappings: [], // Invalid - targetField: '', - }); - - expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); - }); - it('generates inference pipeline', () => { - MLModelsApiLogic.actions.apiSuccess([nerModel]); - MLInferenceLogic.actions.setInferencePipelineConfiguration({ - modelID: nerModel.model_id, - pipelineName: 'unit-test', - fieldMappings: [ - { - sourceField: 'body', - targetField: 'ml.inference.body', - }, - ], - targetField: '', - }); - - expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); - }); - it('returns undefined when existing pipeline not yet selected', () => { - MLInferenceLogic.actions.setInferencePipelineConfiguration({ - existingPipeline: true, - modelID: '', - pipelineName: '', - fieldMappings: [], - targetField: '', - }); - expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); - }); - it('return existing pipeline when selected', () => { - const existingPipeline = { - description: 'this is a test', - processors: [], - version: 1, - }; - FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ - 'unit-test': existingPipeline, - }); - MLInferenceLogic.actions.setInferencePipelineConfiguration({ - existingPipeline: true, - modelID: '', - pipelineName: 'unit-test', - fieldMappings: [ - { - sourceField: 'body', - targetField: 'ml.inference.body', - }, - ], - targetField: '', - }); - expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); - expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline); - }); - }); - describe('supportedMLModels', () => { - it('filters unsupported ML models', () => { - MLModelsApiLogic.actions.apiSuccess([ - { - inference_config: { - ner: {}, - }, - input: { - field_names: ['text_field'], - }, - model_id: 'ner-mocked-model', - model_type: 'pytorch', - tags: [], - version: '1', - }, - { - inference_config: { - some_unsupported_task_type: {}, - }, - input: { - field_names: ['text_field'], - }, - model_id: 'unsupported-mocked-model', - model_type: 'pytorch', - tags: [], - version: '1', - }, - ]); - - expect(MLInferenceLogic.values.supportedMLModels).toEqual([ - expect.objectContaining({ - inference_config: { - ner: {}, - }, - }), - ]); - }); - - it('promotes text_expansion ML models and sorts others by ID', () => { - MLModelsApiLogic.actions.apiSuccess([ - { - inference_config: { - ner: {}, - }, - input: { - field_names: ['text_field'], - }, - model_id: 'ner-mocked-model', - model_type: 'pytorch', - tags: [], - version: '1', - }, - { - inference_config: { - text_expansion: {}, - }, - input: { - field_names: ['text_field'], - }, - model_id: 'text-expansion-mocked-model', - model_type: 'pytorch', - tags: [], - version: '1', - }, - { - inference_config: { - text_embedding: {}, - }, - input: { - field_names: ['text_field'], - }, - model_id: 'text-embedding-mocked-model', - model_type: 'pytorch', - tags: [], - version: '1', - }, - ]); - - expect(MLInferenceLogic.values.supportedMLModels).toEqual([ - expect.objectContaining({ - inference_config: { - text_expansion: {}, - }, - }), - expect.objectContaining({ - inference_config: { - ner: {}, - }, - }), - expect.objectContaining({ - inference_config: { - text_embedding: {}, - }, - }), - ]); - }); - }); describe('formErrors', () => { it('has errors when configuration is empty', () => { expect(MLInferenceLogic.values.formErrors).toEqual({ @@ -570,6 +440,75 @@ describe('MlInferenceLogic', () => { }); }); }); + describe('mlInferencePipeline', () => { + it('returns undefined when configuration is invalid', () => { + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + modelID: '', + pipelineName: '', // Invalid + fieldMappings: [], // Invalid + targetField: '', + }); + + expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); + }); + it('generates inference pipeline', () => { + CachedFetchModelsApiLogic.actions.apiSuccess(MODELS); + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + modelID: MODELS[0].modelId, + pipelineName: 'unit-test', + fieldMappings: [ + { + sourceField: 'body', + targetField: 'ml.inference.body', + }, + ], + targetField: '', + }); + + expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); + }); + it('returns undefined when existing pipeline not yet selected', () => { + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + existingPipeline: true, + modelID: '', + pipelineName: '', + fieldMappings: [], + targetField: '', + }); + expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); + }); + it('return existing pipeline when selected', () => { + const existingPipeline = { + description: 'this is a test', + processors: [], + version: 1, + }; + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': existingPipeline, + }); + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + existingPipeline: true, + modelID: '', + pipelineName: 'unit-test', + fieldMappings: [ + { + sourceField: 'body', + targetField: 'ml.inference.body', + }, + ], + targetField: '', + }); + expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); + expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline); + }); + }); + describe('selectableModels', () => { + it('makes fetch models request', () => { + MLInferenceLogic.actions.fetchModelsApiSuccess(MODELS); + + expect(MLInferenceLogic.values.selectableModels).toBe(MODELS); + }); + }); }); describe('listeners', () => { @@ -615,14 +554,14 @@ describe('MlInferenceLogic', () => { ...mockModelConfiguration, configuration: { ...mockModelConfiguration.configuration, - modelID: textExpansionModel.model_id, + modelID: MODELS[0].modelId, fieldMappings: [], }, }, }); jest.spyOn(MLInferenceLogic.actions, 'makeCreatePipelineRequest'); - MLModelsApiLogic.actions.apiSuccess([textExpansionModel]); + CachedFetchModelsApiLogic.actions.apiSuccess(MODELS); MLInferenceLogic.actions.selectFields(['my_source_field1', 'my_source_field2']); MLInferenceLogic.actions.addSelectedFieldsToMapping(true); MLInferenceLogic.actions.createPipeline(); @@ -630,7 +569,7 @@ describe('MlInferenceLogic', () => { expect(MLInferenceLogic.actions.makeCreatePipelineRequest).toHaveBeenCalledWith({ indexName: mockModelConfiguration.indexName, inferenceConfig: undefined, - modelId: textExpansionModel.model_id, + modelId: MODELS[0].modelId, fieldMappings: [ { sourceField: 'my_source_field1', @@ -648,13 +587,13 @@ describe('MlInferenceLogic', () => { }); describe('startTextExpansionModelSuccess', () => { it('fetches ml models', () => { - jest.spyOn(MLInferenceLogic.actions, 'makeMLModelsRequest'); + jest.spyOn(MLInferenceLogic.actions, 'startPollingModels'); StartTextExpansionModelApiLogic.actions.apiSuccess({ deploymentState: 'started', modelId: 'foo', }); - expect(MLInferenceLogic.actions.makeMLModelsRequest).toHaveBeenCalledWith(undefined); + expect(MLInferenceLogic.actions.startPollingModels).toHaveBeenCalled(); }); }); describe('onAddInferencePipelineStepChange', () => { @@ -673,12 +612,12 @@ describe('MlInferenceLogic', () => { existingPipeline: false, }); jest.spyOn(MLInferenceLogic.actions, 'fetchPipelineByName'); - jest.spyOn(MLInferenceLogic.actions, 'makeMLModelsRequest'); + jest.spyOn(MLInferenceLogic.actions, 'startPollingModels'); MLInferenceLogic.actions.onAddInferencePipelineStepChange(AddInferencePipelineSteps.Fields); expect(MLInferenceLogic.actions.fetchPipelineByName).toHaveBeenCalledWith({ pipelineName: 'ml-inference-unit-test-pipeline', }); - expect(MLInferenceLogic.actions.makeMLModelsRequest).toHaveBeenCalledWith(undefined); + expect(MLInferenceLogic.actions.startPollingModels).toHaveBeenCalled(); }); it('does not trigger pipeline and model fetch existing pipeline is selected', () => { MLInferenceLogic.actions.setInferencePipelineConfiguration({ @@ -688,10 +627,10 @@ describe('MlInferenceLogic', () => { existingPipeline: true, }); jest.spyOn(MLInferenceLogic.actions, 'fetchPipelineByName'); - jest.spyOn(MLInferenceLogic.actions, 'makeMLModelsRequest'); + jest.spyOn(MLInferenceLogic.actions, 'startPollingModels'); MLInferenceLogic.actions.onAddInferencePipelineStepChange(AddInferencePipelineSteps.Fields); expect(MLInferenceLogic.actions.fetchPipelineByName).not.toHaveBeenCalled(); - expect(MLInferenceLogic.actions.makeMLModelsRequest).not.toHaveBeenCalled(); + expect(MLInferenceLogic.actions.startPollingModels).not.toHaveBeenCalled(); }); }); describe('fetchPipelineSuccess', () => { diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts index bdcf23d71a743d..3383a8772f3ce3 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts @@ -14,11 +14,11 @@ import { formatPipelineName, generateMlInferencePipelineBody, getMlInferencePrefixedFieldName, - getMlModelTypesForModelConfig, ML_INFERENCE_PREFIX, parseMlInferenceParametersFromPipeline, } from '../../../../../../../common/ml_inference_pipeline'; import { Status } from '../../../../../../../common/types/api'; +import { MlModel } from '../../../../../../../common/types/ml'; import { MlInferencePipeline } from '../../../../../../../common/types/pipelines'; import { Actions } from '../../../../../shared/api_logic/create_api_logic'; @@ -34,10 +34,10 @@ import { MappingsApiLogic, } from '../../../../api/mappings/mappings_logic'; import { - TrainedModel, - TrainedModelsApiLogicActions, - TrainedModelsApiLogic, -} from '../../../../api/ml_models/ml_trained_models_logic'; + CachedFetchModelsApiLogic, + CachedFetchModlesApiLogicActions, + FetchModelsApiResponse, +} from '../../../../api/ml_models/cached_fetch_models_api_logic'; import { StartTextExpansionModelApiLogic, StartTextExpansionModelApiLogicActions, @@ -68,12 +68,7 @@ import { } from '../../../../api/pipelines/fetch_pipeline'; import { isConnectorIndex } from '../../../../utils/indices'; -import { - getMLType, - isSupportedMLModel, - sortModels, - sortSourceFields, -} from '../../../shared/ml_inference/utils'; +import { getMLType, sortSourceFields } from '../../../shared/ml_inference/utils'; import { PipelinesLogic } from '../pipelines_logic'; import { @@ -143,6 +138,7 @@ export interface MLInferenceProcessorsActions { CreateMlInferencePipelineResponse >['apiSuccess']; createPipeline: () => void; + fetchModelsApiSuccess: CachedFetchModlesApiLogicActions['apiSuccess']; fetchPipelineByName: FetchPipelineApiLogicActions['makeRequest']; fetchPipelineSuccess: FetchPipelineApiLogicActions['apiSuccess']; makeAttachPipelineRequest: Actions< @@ -153,7 +149,6 @@ export interface MLInferenceProcessorsActions { CreateMlInferencePipelineApiLogicArgs, CreateMlInferencePipelineResponse >['makeRequest']; - makeMLModelsRequest: TrainedModelsApiLogicActions['makeRequest']; makeMappingRequest: Actions['makeRequest']; makeMlInferencePipelinesRequest: Actions< FetchMlInferencePipelinesArgs, @@ -164,7 +159,6 @@ export interface MLInferenceProcessorsActions { FetchMlInferencePipelinesArgs, FetchMlInferencePipelinesResponse >['apiSuccess']; - mlModelsApiError: TrainedModelsApiLogicActions['apiError']; onAddInferencePipelineStepChange: (step: AddInferencePipelineSteps) => { step: AddInferencePipelineSteps; }; @@ -181,6 +175,7 @@ export interface MLInferenceProcessorsActions { configuration: InferencePipelineConfiguration; }; setTargetField: (targetFieldName: string) => { targetFieldName: string }; + startPollingModels: CachedFetchModlesApiLogicActions['startPolling']; startTextExpansionModelSuccess: StartTextExpansionModelApiLogicActions['apiSuccess']; } @@ -200,6 +195,7 @@ export interface MLInferenceProcessorsValues { index: CachedFetchIndexApiLogicValues['indexData']; isConfigureStepValid: boolean; isLoading: boolean; + isModelsInitialLoading: boolean; isPipelineDataValid: boolean; isTextExpansionModelSelected: boolean; mappingData: typeof MappingsApiLogic.values.data; @@ -207,11 +203,11 @@ export interface MLInferenceProcessorsValues { mlInferencePipeline: MlInferencePipeline | undefined; mlInferencePipelineProcessors: FetchMlInferencePipelineProcessorsResponse | undefined; mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined; - mlModelsData: TrainedModel[] | null; - mlModelsStatus: Status; - selectedMLModel: TrainedModel | null; + modelsData: FetchModelsApiResponse | undefined; + modelsStatus: Status; + selectableModels: MlModel[]; + selectedModel: MlModel | undefined; sourceFields: string[] | undefined; - supportedMLModels: TrainedModel[]; } export const MLInferenceLogic = kea< @@ -238,6 +234,8 @@ export const MLInferenceLogic = kea< }, connect: { actions: [ + CachedFetchModelsApiLogic, + ['apiSuccess as fetchModelsApiSuccess', 'startPolling as startPollingModels'], FetchMlInferencePipelinesApiLogic, [ 'makeRequest as makeMlInferencePipelinesRequest', @@ -245,8 +243,6 @@ export const MLInferenceLogic = kea< ], MappingsApiLogic, ['makeRequest as makeMappingRequest', 'apiError as mappingsApiError'], - TrainedModelsApiLogic, - ['makeRequest as makeMLModelsRequest', 'apiError as mlModelsApiError'], CreateMlInferencePipelineApiLogic, [ 'apiError as createApiError', @@ -260,7 +256,7 @@ export const MLInferenceLogic = kea< 'makeRequest as makeAttachPipelineRequest', ], PipelinesLogic, - ['closeAddMlInferencePipelineModal as closeAddMlInferencePipelineModal'], + ['closeAddMlInferencePipelineModal'], StartTextExpansionModelApiLogic, ['apiSuccess as startTextExpansionModelSuccess'], FetchPipelineApiLogic, @@ -271,21 +267,20 @@ export const MLInferenceLogic = kea< ], ], values: [ + CachedFetchModelsApiLogic, + ['modelsData', 'status as modelsStatus', 'isInitialLoading as isModelsInitialLoading'], CachedFetchIndexApiLogic, ['indexData as index'], FetchMlInferencePipelinesApiLogic, ['data as mlInferencePipelinesData'], MappingsApiLogic, ['data as mappingData', 'status as mappingStatus'], - TrainedModelsApiLogic, - ['data as mlModelsData', 'status as mlModelsStatus'], FetchMlInferencePipelineProcessorsApiLogic, ['data as mlInferencePipelineProcessors'], FetchPipelineApiLogic, ['data as existingPipeline'], ], }, - events: {}, listeners: ({ values, actions }) => ({ attachPipeline: () => { const { @@ -340,11 +335,6 @@ export const MLInferenceLogic = kea< targetField: '', }); }, - setIndexName: ({ indexName }) => { - actions.makeMlInferencePipelinesRequest(undefined); - actions.makeMLModelsRequest(undefined); - actions.makeMappingRequest({ indexName }); - }, mlInferencePipelinesSuccess: (data) => { if ( (data?.length ?? 0) === 0 && @@ -359,7 +349,7 @@ export const MLInferenceLogic = kea< }, startTextExpansionModelSuccess: () => { // Refresh ML models list when the text expansion model is started - actions.makeMLModelsRequest(undefined); + actions.startPollingModels(); }, onAddInferencePipelineStepChange: ({ step }) => { const { @@ -377,12 +367,12 @@ export const MLInferenceLogic = kea< // back to the Configuration step if we find a pipeline with the same name // Re-fetch ML model list to include those that were deployed in this step - actions.makeMLModelsRequest(undefined); + actions.startPollingModels(); } actions.setAddInferencePipelineStep(step); }, fetchPipelineSuccess: () => { - // We found a pipeline with the name go back to configuration step + // We found a pipeline with the name, go back to configuration step actions.setAddInferencePipelineStep(AddInferencePipelineSteps.Configuration); }, }), @@ -509,30 +499,28 @@ export const MLInferenceLogic = kea< }, ], isLoading: [ - () => [selectors.mlModelsStatus, selectors.mappingStatus], - (mlModelsStatus, mappingStatus) => - !API_REQUEST_COMPLETE_STATUSES.includes(mlModelsStatus) || - !API_REQUEST_COMPLETE_STATUSES.includes(mappingStatus), + () => [selectors.mappingStatus], + (mappingStatus: Status) => !API_REQUEST_COMPLETE_STATUSES.includes(mappingStatus), ], isPipelineDataValid: [ () => [selectors.formErrors], (errors: AddInferencePipelineFormErrors) => Object.keys(errors).length === 0, ], isTextExpansionModelSelected: [ - () => [selectors.selectedMLModel], - (model: TrainedModel | null) => !!model?.inference_config?.text_expansion, + () => [selectors.selectedModel], + (model: MlModel | null) => model?.type === 'text_expansion', ], mlInferencePipeline: [ () => [ selectors.isPipelineDataValid, selectors.addInferencePipelineModal, - selectors.mlModelsData, + selectors.modelsData, selectors.mlInferencePipelinesData, ], ( isPipelineDataValid: MLInferenceProcessorsValues['isPipelineDataValid'], { configuration }: MLInferenceProcessorsValues['addInferencePipelineModal'], - models: MLInferenceProcessorsValues['mlModelsData'], + models: MLInferenceProcessorsValues['modelsData'], mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData'] ) => { if (configuration.existingPipeline) { @@ -546,7 +534,7 @@ export const MLInferenceLogic = kea< return pipeline as MlInferencePipeline; } if (!isPipelineDataValid) return undefined; - const model = models?.find((mlModel) => mlModel.model_id === configuration.modelID); + const model = models?.find((mlModel) => mlModel.modelId === configuration.modelID); if (!model) return undefined; return generateMlInferencePipelineBody({ @@ -581,23 +569,28 @@ export const MLInferenceLogic = kea< .sort(sortSourceFields); }, ], - supportedMLModels: [ - () => [selectors.mlModelsData], - (mlModelsData: MLInferenceProcessorsValues['mlModelsData']) => { - return (mlModelsData?.filter(isSupportedMLModel) ?? []).sort(sortModels); - }, + selectableModels: [ + () => [selectors.modelsData], + (response: FetchModelsApiResponse) => response ?? [], + ], + selectedModel: [ + () => [selectors.selectableModels, selectors.addInferencePipelineModal], + ( + models: MlModel[], + addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal'] + ) => models.find((m) => m.modelId === addInferencePipelineModal.configuration.modelID), ], existingInferencePipelines: [ () => [ selectors.mlInferencePipelinesData, selectors.sourceFields, - selectors.supportedMLModels, + selectors.selectableModels, selectors.mlInferencePipelineProcessors, ], ( mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData'], indexFields: MLInferenceProcessorsValues['sourceFields'], - supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'], + selectableModels: MLInferenceProcessorsValues['selectableModels'], mlInferencePipelineProcessors: MLInferenceProcessorsValues['mlInferencePipelineProcessors'] ) => { if (!mlInferencePipelinesData) { @@ -619,8 +612,8 @@ export const MLInferenceLogic = kea< const sourceFields = fieldMappings?.map((m) => m.sourceField) ?? []; const missingSourceFields = sourceFields.filter((f) => !indexFields?.includes(f)) ?? []; - const mlModel = supportedMLModels.find((model) => model.model_id === modelId); - const modelType = mlModel ? getMLType(getMlModelTypesForModelConfig(mlModel)) : ''; + const mlModel = selectableModels.find((model) => model.modelId === modelId); + const modelType = mlModel ? getMLType(mlModel.types) : ''; const disabledReason = missingSourceFields.length > 0 ? EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELDS(missingSourceFields.join(', ')) @@ -641,18 +634,5 @@ export const MLInferenceLogic = kea< return existingPipelines; }, ], - selectedMLModel: [ - () => [selectors.supportedMLModels, selectors.addInferencePipelineModal], - ( - supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'], - addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal'] - ) => { - return ( - supportedMLModels.find( - (model) => model.model_id === addInferencePipelineModal.configuration.modelID - ) ?? null - ); - }, - ], }), }); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx index c8a970751643a2..b08b4697e6cfcb 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx @@ -52,6 +52,9 @@ const DEFAULT_MODEL: MlModel = { threadsPerAllocation: 0, isPlaceholder: false, hasStats: false, + types: ['pytorch', 'ner'], + inputFieldNames: ['title'], + version: '1', }; const MOCK_ACTIONS = { diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts index b0c26aaf8be8c1..8af7d59a1ceec0 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts @@ -8,7 +8,7 @@ import { LogicMounter } from '../../../../../__mocks__/kea_logic'; import { HttpError } from '../../../../../../../common/types/api'; -import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml'; +import { MlModelDeploymentState } from '../../../../../../../common/types/ml'; import { CachedFetchModelsApiLogic } from '../../../../api/ml_models/cached_fetch_models_api_logic'; import { CreateModelApiLogic, @@ -22,31 +22,15 @@ const CREATE_MODEL_API_RESPONSE: CreateModelResponse = { modelId: 'model_1', deploymentState: MlModelDeploymentState.NotDeployed, }; -const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [ - { - modelId: 'model_1', - title: 'Model 1', - type: 'ner', - deploymentState: MlModelDeploymentState.NotDeployed, - startTime: 0, - targetAllocationCount: 0, - nodeAllocationCount: 0, - threadsPerAllocation: 0, - isPlaceholder: false, - hasStats: false, - }, -]; describe('ModelSelectLogic', () => { const { mount } = new LogicMounter(ModelSelectLogic); const { mount: mountCreateModelApiLogic } = new LogicMounter(CreateModelApiLogic); - const { mount: mountCachedFetchModelsApiLogic } = new LogicMounter(CachedFetchModelsApiLogic); const { mount: mountStartModelApiLogic } = new LogicMounter(StartModelApiLogic); beforeEach(() => { jest.clearAllMocks(); mountCreateModelApiLogic(); - mountCachedFetchModelsApiLogic(); mountStartModelApiLogic(); mount(); }); @@ -82,16 +66,6 @@ describe('ModelSelectLogic', () => { }); }); - describe('fetchModels', () => { - it('makes fetch models request', () => { - jest.spyOn(ModelSelectLogic.actions, 'fetchModelsMakeRequest'); - - ModelSelectLogic.actions.fetchModels(); - - expect(ModelSelectLogic.actions.fetchModelsMakeRequest).toHaveBeenCalled(); - }); - }); - describe('startModel', () => { it('makes start model request', () => { const modelId = 'model_1'; @@ -150,14 +124,6 @@ describe('ModelSelectLogic', () => { }); }); - describe('selectableModels', () => { - it('gets models data from API response', () => { - CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE); - - expect(ModelSelectLogic.values.selectableModels).toEqual(FETCH_MODELS_API_DATA_RESPONSE); - }); - }); - describe('isLoading', () => { it('is set to true if the fetch API is loading the first time', () => { CachedFetchModelsApiLogic.actions.apiReset(); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts index 4074ffac92f6b4..09fd2e0ae8f544 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts @@ -10,15 +10,10 @@ import { kea, MakeLogicType } from 'kea'; import { HttpError, Status } from '../../../../../../../common/types/api'; import { MlModel } from '../../../../../../../common/types/ml'; import { getErrorsFromHttpResponse } from '../../../../../shared/flash_messages/handle_api_errors'; -import { - CachedFetchModelsApiLogic, - CachedFetchModlesApiLogicActions, -} from '../../../../api/ml_models/cached_fetch_models_api_logic'; import { CreateModelApiLogic, CreateModelApiLogicActions, } from '../../../../api/ml_models/create_model_api_logic'; -import { FetchModelsApiResponse } from '../../../../api/ml_models/fetch_models_api_logic'; import { StartModelApiLogic, StartModelApiLogicActions, @@ -37,17 +32,13 @@ export interface ModelSelectActions { createModelError: CreateModelApiLogicActions['apiError']; createModelMakeRequest: CreateModelApiLogicActions['makeRequest']; createModelSuccess: CreateModelApiLogicActions['apiSuccess']; - fetchModels: () => void; - fetchModelsError: CachedFetchModlesApiLogicActions['apiError']; - fetchModelsMakeRequest: CachedFetchModlesApiLogicActions['makeRequest']; - fetchModelsSuccess: CachedFetchModlesApiLogicActions['apiSuccess']; setInferencePipelineConfiguration: MLInferenceProcessorsActions['setInferencePipelineConfiguration']; setInferencePipelineConfigurationFromMLInferenceLogic: MLInferenceProcessorsActions['setInferencePipelineConfiguration']; startModel: (modelId: string) => { modelId: string }; startModelError: CreateModelApiLogicActions['apiError']; startModelMakeRequest: StartModelApiLogicActions['makeRequest']; startModelSuccess: StartModelApiLogicActions['apiSuccess']; - startPollingModels: CachedFetchModlesApiLogicActions['startPolling']; + startPollingModels: MLInferenceProcessorsActions['startPollingModels']; } export interface ModelSelectValues { @@ -59,12 +50,12 @@ export interface ModelSelectValues { ingestionMethod: string; ingestionMethodFromIndexViewLogic: string; isLoading: boolean; - isInitialLoading: boolean; + isModelsInitialLoadingFromMLInferenceLogic: boolean; modelStateChangeError: string | undefined; - modelsData: FetchModelsApiResponse | undefined; - modelsStatus: Status; selectableModels: MlModel[]; + selectableModelsFromMLInferenceLogic: MlModel[]; selectedModel: MlModel | undefined; + selectedModelFromMLInferenceLogic: MlModel | undefined; startModelError: HttpError | undefined; startModelStatus: Status; } @@ -72,19 +63,11 @@ export interface ModelSelectValues { export const ModelSelectLogic = kea>({ actions: { createModel: (modelId: string) => ({ modelId }), - fetchModels: true, setInferencePipelineConfiguration: (configuration) => ({ configuration }), startModel: (modelId: string) => ({ modelId }), }, connect: { actions: [ - CachedFetchModelsApiLogic, - [ - 'makeRequest as fetchModelsMakeRequest', - 'apiSuccess as fetchModelsSuccess', - 'apiError as fetchModelsError', - 'startPolling as startPollingModels', - ], CreateModelApiLogic, [ 'makeRequest as createModelMakeRequest', @@ -93,8 +76,9 @@ export const ModelSelectLogic = kea ({ - afterMount: () => { - actions.startPollingModels(); - }, - }), listeners: ({ actions }) => ({ createModel: ({ modelId }) => { actions.createModelMakeRequest({ modelId }); @@ -130,9 +112,6 @@ export const ModelSelectLogic = kea { - actions.fetchModelsMakeRequest({}); - }, setInferencePipelineConfiguration: ({ configuration }) => { actions.setInferencePipelineConfigurationFromMLInferenceLogic(configuration); }, @@ -167,16 +146,16 @@ export const ModelSelectLogic = kea [selectors.modelsData], - (response: FetchModelsApiResponse) => response ?? [], + () => [selectors.selectableModelsFromMLInferenceLogic], + (selectableModels) => selectableModels, // Pass-through ], selectedModel: [ - () => [selectors.selectableModels, selectors.addInferencePipelineModal], - ( - models: MlModel[], - addInferencePipelineModal: MLInferenceProcessorsValues['addInferencePipelineModal'] - ) => models.find((m) => m.modelId === addInferencePipelineModal.configuration.modelID), + () => [selectors.selectedModelFromMLInferenceLogic], + (selectedModel) => selectedModel, // Pass-through + ], + isLoading: [ + () => [selectors.isModelsInitialLoadingFromMLInferenceLogic], + (isModelsInitialLoading) => isModelsInitialLoading, // Pass-through ], - isLoading: [() => [selectors.isInitialLoading], (isInitialLoading) => isInitialLoading], }), }); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx index bcf4eb8342db13..6c4f1f4bbabb87 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx @@ -34,6 +34,9 @@ const DEFAULT_PROPS: EuiSelectableOption = { threadsPerAllocation: 0, isPlaceholder: false, hasStats: false, + types: ['pytorch', 'ner'], + inputFieldNames: ['title'], + version: '1', }; describe('ModelSelectOption', () => { diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/no_models.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/no_models.tsx deleted file mode 100644 index 4670b00e939274..00000000000000 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/no_models.tsx +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import React from 'react'; - -import { EuiEmptyPrompt, EuiImage, EuiLink, EuiText, useEuiTheme } from '@elastic/eui'; - -import { i18n } from '@kbn/i18n'; -import { FormattedMessage } from '@kbn/i18n-react'; - -import noMlModelsGraphicDark from '../../../../../../assets/images/no_ml_models_dark.svg'; -import noMlModelsGraphicLight from '../../../../../../assets/images/no_ml_models_light.svg'; - -import { docLinks } from '../../../../../shared/doc_links'; - -export const NoModelsPanel: React.FC = () => { - const { colorMode } = useEuiTheme(); - - return ( - - - -

- - {i18n.translate( - 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.esDocs.link', - { defaultMessage: 'Learn how to add a trained model' } - )} - - ), - }} - /> -

-
- - } - /> - ); -}; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.test.ts index 4ddf5b1c4b77ae..1e4eb177445172 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.test.ts @@ -4,89 +4,10 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import { nerModel, textClassificationModel } from '../../../__mocks__/ml_models.mock'; -import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; - -import { - getMLType, - getModelDisplayTitle, - isSupportedMLModel, - sortSourceFields, - NLP_CONFIG_KEYS, -} from './utils'; +import { getMLType, getModelDisplayTitle, sortSourceFields, NLP_CONFIG_KEYS } from './utils'; describe('ml inference utils', () => { - describe('isSupportedMLModel', () => { - const makeFakeModel = ( - config: Partial - ): TrainedModelConfigResponse => { - const { inference_config: _throwAway, ...base } = nerModel; - return { - inference_config: {}, - ...base, - ...config, - }; - }; - it('returns true for expected models', () => { - const models: TrainedModelConfigResponse[] = [ - nerModel, - textClassificationModel, - makeFakeModel({ - inference_config: { - text_embedding: {}, - }, - model_id: 'mock-text_embedding', - }), - makeFakeModel({ - inference_config: { - zero_shot_classification: { - classification_labels: [], - }, - }, - model_id: 'mock-zero_shot_classification', - }), - makeFakeModel({ - inference_config: { - question_answering: {}, - }, - model_id: 'mock-question_answering', - }), - makeFakeModel({ - inference_config: { - fill_mask: {}, - }, - model_id: 'mock-fill_mask', - }), - makeFakeModel({ - inference_config: { - classification: {}, - }, - model_id: 'lang_ident_model_1', - model_type: 'lang_ident', - }), - ]; - - for (const model of models) { - expect(isSupportedMLModel(model)).toBe(true); - } - }); - - it('returns false for unexpected models', () => { - const models: TrainedModelConfigResponse[] = [ - makeFakeModel({}), - makeFakeModel({ - inference_config: { - fakething: {}, - }, - }), - ]; - - for (const model of models) { - expect(isSupportedMLModel(model)).toBe(false); - } - }); - }); describe('sortSourceFields', () => { it('promotes fields', () => { let fields: string[] = ['id', 'body', 'url']; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts index 3d97f52c659c14..88d273fcef135d 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts @@ -6,12 +6,9 @@ */ import { i18n } from '@kbn/i18n'; -import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; import { TRAINED_MODEL_TYPE, SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils'; -import { TrainedModel } from '../../../api/ml_models/ml_trained_models_logic'; - export const NLP_CONFIG_KEYS: string[] = Object.values(SUPPORTED_PYTORCH_TASKS); export const RECOMMENDED_FIELDS = ['body', 'body_content', 'title']; @@ -51,13 +48,6 @@ export const NLP_DISPLAY_TITLES: Record = { ), }; -export const isSupportedMLModel = (model: TrainedModelConfigResponse): boolean => { - return ( - Object.keys(model.inference_config || {}).some((key) => NLP_CONFIG_KEYS.includes(key)) || - model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT - ); -}; - export const sortSourceFields = (a: string, b: string): number => { const promoteA = RECOMMENDED_FIELDS.includes(a); const promoteB = RECOMMENDED_FIELDS.includes(b); @@ -82,16 +72,3 @@ export const getMLType = (modelTypes: string[]): string => { }; export const getModelDisplayTitle = (type: string): string | undefined => NLP_DISPLAY_TITLES[type]; - -export const isTextExpansionModel = (model: TrainedModel): boolean => - Boolean(model.inference_config?.text_expansion); - -/** - * Sort function for displaying a list of models. Promotes text_expansion models and sorts the rest by model ID. - */ -export const sortModels = (m1: TrainedModel, m2: TrainedModel) => - isTextExpansionModel(m1) - ? -1 - : isTextExpansionModel(m2) - ? 1 - : m1.model_id.localeCompare(m2.model_id); diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.test.ts b/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.test.ts index bd02af095fe04d..b96bfe6de8e7f5 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.test.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.test.ts @@ -72,12 +72,18 @@ describe('fetchMlModels', () => { inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: 'model_1', inference_config: { text_classification: {}, }, + input: { + fields: ['text_field'], + }, }, ], }; @@ -160,24 +166,36 @@ describe('fetchMlModels', () => { inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: E5_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: ELSER_MODEL_ID, inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, ], }; @@ -210,24 +228,36 @@ describe('fetchMlModels', () => { inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: E5_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: ELSER_MODEL_ID, inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: ELSER_LINUX_OPTIMIZED_MODEL_ID, inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, ], }; @@ -265,18 +295,27 @@ describe('fetchMlModels', () => { inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: E5_MODEL_ID, inference_config: { text_embedding: {}, }, + input: { + fields: ['text_field'], + }, }, { model_id: 'model_1', inference_config: { ner: {}, }, + input: { + fields: ['text_field'], + }, }, ], }; @@ -337,6 +376,9 @@ describe('fetchMlModels', () => { inference_config: { text_expansion: {}, }, + input: { + fields: ['text_field'], + }, }, ], }; @@ -385,18 +427,27 @@ describe('fetchMlModels', () => { inference_config: { ner: {}, // "Named Entity Recognition" }, + input: { + fields: ['text_field'], + }, }, { model_id: 'model_2', inference_config: { text_embedding: {}, // "Dense Vector Text Embedding" }, + input: { + fields: ['text_field'], + }, }, { model_id: 'model_3', inference_config: { text_classification: {}, // "Text Classification" }, + input: { + fields: ['text_field'], + }, }, ], }; diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.ts b/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.ts index c1af4ab69c0bc4..c39b6ec146ea10 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/fetch_ml_models.ts @@ -6,9 +6,12 @@ */ import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types'; + import { i18n } from '@kbn/i18n'; import { MlTrainedModels } from '@kbn/ml-plugin/server'; +import { getMlModelTypesForModelConfig } from '../../../common/ml_inference_pipeline'; + import { MlModelDeploymentState, MlModel } from '../../../common/types/ml'; import { @@ -109,39 +112,39 @@ export const fetchCompatiblePromotedModelIds = async (trainedModelsProvider: MlT }; const getModel = (modelConfig: MlTrainedModelConfig, modelStats?: MlTrainedModelStats): MlModel => { - { - const modelId = modelConfig.model_id; - const type = modelConfig.inference_config ? Object.keys(modelConfig.inference_config)[0] : ''; - const model = { - ...BASE_MODEL, - modelId, - type, - title: getUserFriendlyTitle(modelId, type), - isPromoted: [ - ELSER_MODEL_ID, - ELSER_LINUX_OPTIMIZED_MODEL_ID, - E5_MODEL_ID, - E5_LINUX_OPTIMIZED_MODEL_ID, - ].includes(modelId), - }; - - // Enrich deployment stats - if (modelStats && modelStats.deployment_stats) { - model.hasStats = true; - model.deploymentState = getDeploymentState( - modelStats.deployment_stats.allocation_status.state - ); - model.nodeAllocationCount = modelStats.deployment_stats.allocation_status.allocation_count; - model.targetAllocationCount = - modelStats.deployment_stats.allocation_status.target_allocation_count; - model.threadsPerAllocation = modelStats.deployment_stats.threads_per_allocation; - model.startTime = modelStats.deployment_stats.start_time; - } else if (model.modelId === LANG_IDENT_MODEL_ID) { - model.deploymentState = MlModelDeploymentState.FullyAllocated; - } - - return model; + const modelId = modelConfig.model_id; + const type = modelConfig.inference_config ? Object.keys(modelConfig.inference_config)[0] : ''; + const model = { + ...BASE_MODEL, + modelId, + type, + title: getUserFriendlyTitle(modelId, type), + description: modelConfig.description, + types: getMlModelTypesForModelConfig(modelConfig), + inputFieldNames: modelConfig.input.field_names, + version: modelConfig.version, + isPromoted: [ + ELSER_MODEL_ID, + ELSER_LINUX_OPTIMIZED_MODEL_ID, + E5_MODEL_ID, + E5_LINUX_OPTIMIZED_MODEL_ID, + ].includes(modelId), + }; + + // Enrich deployment stats + if (modelStats && modelStats.deployment_stats) { + model.hasStats = true; + model.deploymentState = getDeploymentState(modelStats.deployment_stats.allocation_status.state); + model.nodeAllocationCount = modelStats.deployment_stats.allocation_status.allocation_count; + model.targetAllocationCount = + modelStats.deployment_stats.allocation_status.target_allocation_count; + model.threadsPerAllocation = modelStats.deployment_stats.threads_per_allocation; + model.startTime = modelStats.deployment_stats.start_time; + } else if (model.modelId === LANG_IDENT_MODEL_ID) { + model.deploymentState = MlModelDeploymentState.FullyAllocated; } + + return model; }; const enrichModelWithDownloadStatus = async ( diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts b/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts index 19a43059d3a086..d063fd158385a2 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts @@ -60,6 +60,8 @@ export const BASE_MODEL = { threadsPerAllocation: 0, isPlaceholder: false, hasStats: false, + types: [], + inputFieldNames: [], }; export const ELSER_MODEL_PLACEHOLDER: MlModel = { diff --git a/x-pack/plugins/translations/translations/fr-FR.json b/x-pack/plugins/translations/translations/fr-FR.json index a9b62b8622cdc3..da642ec3e02ffe 100644 --- a/x-pack/plugins/translations/translations/fr-FR.json +++ b/x-pack/plugins/translations/translations/fr-FR.json @@ -12766,7 +12766,6 @@ "xpack.enterpriseSearch.content.indices.connectorScheduling.page.description": "Votre connecteur est désormais déployé. Vous pouvez planifier du contenu récurrent et accéder aux synchronisations de contrôle ici. Si vous souhaitez exécuter un test rapide, lancez une synchronisation unique à l’aide du bouton {sync}.", "xpack.enterpriseSearch.content.indices.connectorScheduling.schedulePanel.documentLevelSecurity.dlsDisabledCallout.text": "{link} pour ce connecteur afin d'activer ces options.", "xpack.enterpriseSearch.content.indices.deleteIndex.successToast.title": "Votre index {indexName} et toute configuration d'ingestion associée ont été supprimés avec succès", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.description": "Aucun de vos modèles entraînés de Machine Learning ne peut être utilisé par un pipeline d'inférence. {documentationLink}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.missingSourceFieldsDescription": "Champs manquants dans cet index : {commaSeparatedMissingSourceFields}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText": "Les noms de pipeline sont uniques dans un déploiement, et ils peuvent uniquement contenir des lettres, des chiffres, des traits de soulignement et des traits d'union. Cela créera un pipeline nommé {pipelineName}.", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.fields.descriptionReview": "Vérifiez les mappings des champs du pipeline que vous avez choisi afin de vous assurer que les champs source et cible correspondent à votre cas d'utilisation spécifique. {notEditable}", @@ -14190,8 +14189,6 @@ "xpack.enterpriseSearch.content.indices.extractionRules.editRule.url.urlFiltersLink": "En savoir plus sur les filtres d'URL", "xpack.enterpriseSearch.content.indices.extractionRules.editRule.urlLabel": "URL", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.createErrors": "Erreur lors de la création d'un pipeline", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.esDocs.link": "Découvrir comment ajouter un modèle entraîné", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.imageAlt": "Illustration d'absence de modèles de Machine Learning", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.description": "Créez ou réutilisez un pipeline enfant qui servira de processeur dans votre pipeline principal.", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.emptyValueError": "Champ obligatoire.", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipelineLabel": "Sélectionner un pipeline d'inférence existant", diff --git a/x-pack/plugins/translations/translations/ja-JP.json b/x-pack/plugins/translations/translations/ja-JP.json index 77335897bd2080..a754e686648769 100644 --- a/x-pack/plugins/translations/translations/ja-JP.json +++ b/x-pack/plugins/translations/translations/ja-JP.json @@ -12779,7 +12779,6 @@ "xpack.enterpriseSearch.content.indices.connectorScheduling.page.description": "コネクターがデプロイされました。ここで、繰り返しコンテンツとアクセス制御同期をスケジュールします。簡易テストを実行する場合は、{sync}ボタンを使用してワンタイム同期を実行します。", "xpack.enterpriseSearch.content.indices.connectorScheduling.schedulePanel.documentLevelSecurity.dlsDisabledCallout.text": "これらのオプションを有効にするには、このコネクターの{link}。", "xpack.enterpriseSearch.content.indices.deleteIndex.successToast.title": "インデックス{indexName}と関連付けられたすべての統合構成が正常に削除されました", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.description": "推論パイプラインで使用できる学習済み機械学習モデルがありません。{documentationLink}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.missingSourceFieldsDescription": "このインデックスで欠落しているフィールド:{commaSeparatedMissingSourceFields}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText": "パイプライン名はデプロイ内で一意であり、文字、数字、アンダースコア、ハイフンのみを使用できます。これにより、{pipelineName}という名前のパイプラインが作成されます。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.fields.descriptionReview": "選択したパイプラインのフィールドマッピングを調べ、ソースフィールドとターゲットフィールドが特定のユースケースに適合していることを確認します。{notEditable}", @@ -14203,8 +14202,6 @@ "xpack.enterpriseSearch.content.indices.extractionRules.editRule.url.urlFiltersLink": "URLフィルターの詳細をご覧ください", "xpack.enterpriseSearch.content.indices.extractionRules.editRule.urlLabel": "URL", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.createErrors": "パイプラインの作成エラー", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.esDocs.link": "学習されたモデルの追加方法の詳細", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.imageAlt": "機械学習モデル例がありません", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.description": "メインパイプラインでプロセッサーとして使用される子パイプラインを作成または再利用します。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.emptyValueError": "フィールドが必要です。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipelineLabel": "既存の推論パイプラインを選択", diff --git a/x-pack/plugins/translations/translations/zh-CN.json b/x-pack/plugins/translations/translations/zh-CN.json index 70fe8f8c912983..ab8c10ed3ed97f 100644 --- a/x-pack/plugins/translations/translations/zh-CN.json +++ b/x-pack/plugins/translations/translations/zh-CN.json @@ -12873,7 +12873,6 @@ "xpack.enterpriseSearch.content.indices.connectorScheduling.page.description": "现已部署您的连接器。在此处计划重复内容和访问控制同步。如果要运行快速测试,请使用 {sync} 按钮启动一次性同步。", "xpack.enterpriseSearch.content.indices.connectorScheduling.schedulePanel.documentLevelSecurity.dlsDisabledCallout.text": "此连接器的 {link},用于激活这些选项。", "xpack.enterpriseSearch.content.indices.deleteIndex.successToast.title": "您的索引 {indexName} 和任何关联的采集配置已成功删除", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.description": "您没有可供推理管道使用的已训练 Machine Learning 模型。{documentationLink}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.missingSourceFieldsDescription": "此索引中缺少字段:{commaSeparatedMissingSourceFields}", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText": "管道名称在部署内唯一,并且只能包含字母、数字、下划线和连字符。这会创建名为 {pipelineName} 的管道。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.fields.descriptionReview": "检查您选择的管道的字段映射,确保源和目标字段符合您的特定用例。{notEditable}", @@ -14297,8 +14296,6 @@ "xpack.enterpriseSearch.content.indices.extractionRules.editRule.url.urlFiltersLink": "详细了解 URL 筛选", "xpack.enterpriseSearch.content.indices.extractionRules.editRule.urlLabel": "URL", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.createErrors": "创建管道时出错", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.esDocs.link": "了解如何添加已训练模型", - "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.noModels.imageAlt": "无 Machine Learning 模型图示", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.description": "构建或重复使用将在您的主管道中用作处理器的子管道。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.emptyValueError": "“字段”必填。", "xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipelineLabel": "选择现有推理管道",