diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequest.java b/server/src/main/java/org/opensearch/action/search/SearchRequest.java index 7462faea6ed8d..7ce52d9395b8b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequest.java @@ -51,6 +51,7 @@ import org.opensearch.search.builder.PointInTimeBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.SearchContext; +import org.opensearch.server.proto.SearchRequestProto; import java.io.IOException; import java.util.Arrays; @@ -274,6 +275,18 @@ public SearchRequest(StreamInput in) throws IOException { } } + public SearchRequest(byte[] in) throws IOException { + this(); + SearchRequestProto.SearchRequest searchRequestProto = SearchRequestProto.SearchRequest.parseFrom(in); + indices = searchRequestProto.getIndicesList().toArray(new String[0]); + routing = searchRequestProto.getRouting(); + preference = searchRequestProto.getPreference(); + batchedReduceSize = searchRequestProto.getBatchedReduceSize(); + source = new SearchSourceBuilder(searchRequestProto.getSourceBuilder()); + searchType = SearchType.QUERY_THEN_FETCH; + batchedReduceSize = DEFAULT_BATCHED_REDUCE_SIZE; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/server/src/main/java/org/opensearch/index/query/AbstractQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/AbstractQueryBuilder.java index 66c6ee115c3f0..4bc6785f92c85 100644 --- a/server/src/main/java/org/opensearch/index/query/AbstractQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/AbstractQueryBuilder.java @@ -86,6 +86,12 @@ protected AbstractQueryBuilder(StreamInput in) throws IOException { queryName = in.readOptionalString(); } + protected AbstractQueryBuilder(String queryName, float boost) throws IOException { + this.boost = boost; + checkNegativeBoost(boost); + this.queryName = queryName; + } + @Override public final void writeTo(StreamOutput out) throws IOException { out.writeFloat(boost); diff --git a/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java index 5e9e6a3660e76..cacee82b02f31 100644 --- a/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/MatchQueryBuilder.java @@ -46,6 +46,7 @@ import org.opensearch.index.query.support.QueryParsers; import org.opensearch.index.search.MatchQuery; import org.opensearch.index.search.MatchQuery.ZeroTermsQuery; +import org.opensearch.server.proto.MatchQueryProto; import java.io.IOException; import java.util.Objects; @@ -129,6 +130,14 @@ public MatchQueryBuilder(String fieldName, Object value) { this.value = value; } + public MatchQueryBuilder(MatchQueryProto.MatchQuery queryProto) throws IOException { + super("", DEFAULT_BOOST); + fieldName = queryProto.getFieldName(); + value = queryProto.getValue(); + boost = DEFAULT_BOOST; + queryName = "match"; + } + /** * Read from a stream. */ diff --git a/server/src/main/java/org/opensearch/index/query/TermQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/TermQueryBuilder.java index 8a834395e2e2d..a77fb1165c9ab 100644 --- a/server/src/main/java/org/opensearch/index/query/TermQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/TermQueryBuilder.java @@ -44,6 +44,8 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.ConstantFieldType; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.server.proto.MatchQueryProto; +import org.opensearch.server.proto.TermQueryProto; import java.io.IOException; import java.util.Objects; @@ -117,6 +119,16 @@ public TermQueryBuilder(StreamInput in) throws IOException { } } + + public TermQueryBuilder(TermQueryProto.TermQuery queryProto) throws IOException { + super(queryProto.getFieldName(), queryProto.getValue()); + boost(DEFAULT_BOOST); + if (queryProto.hasCaseInsensitive()) { + caseInsensitive(queryProto.getCaseInsensitive()); + } + } + + @Override protected void doWriteTo(StreamOutput out) throws IOException { super.doWriteTo(out); diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index 3a6b45013e892..6659a04231de4 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -31,7 +31,7 @@ */ package org.opensearch.rest.action.search; - +import java.util.Map; import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.search.SearchAction; @@ -39,6 +39,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.IndicesOptions; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.Booleans; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; @@ -49,7 +50,6 @@ import org.opensearch.rest.action.RestActions; import org.opensearch.rest.action.RestCancellableNodeClient; import org.opensearch.rest.action.RestStatusToXContentListener; -import org.opensearch.search.Scroll; import org.opensearch.search.SearchService; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.StoredFieldsContext; @@ -60,17 +60,12 @@ import org.opensearch.search.suggest.term.TermSuggestionBuilder.SuggestMode; import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; import java.util.function.IntConsumer; import static java.util.Arrays.asList; import static java.util.Collections.unmodifiableList; import static org.opensearch.action.ValidateActions.addValidationError; -import static org.opensearch.common.unit.TimeValue.parseTimeValue; import static org.opensearch.rest.RestRequest.Method.GET; import static org.opensearch.rest.RestRequest.Method.POST; import static org.opensearch.search.suggest.SuggestBuilders.termSuggestion; @@ -106,17 +101,17 @@ public String getName() { public List routes() { return unmodifiableList( asList( - new Route(GET, "/_search"), - new Route(POST, "/_search"), - new Route(GET, "/{index}/_search"), - new Route(POST, "/{index}/_search") + new Route(GET, "/_search_proto"), + new Route(POST, "/_search_proto"), + new Route(GET, "/{index}/_search_proto"), + new Route(POST, "/{index}/_search_proto") ) ); } @Override public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { - SearchRequest searchRequest = new SearchRequest(); +// SearchRequest searchRequest = new SearchRequest(); /* * We have to pull out the call to `source().size(size)` because * _update_by_query and _delete_by_query uses this same parsing @@ -129,11 +124,12 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC * be null later. If that is confusing to you then you are in good * company. */ - IntConsumer setSize = size -> searchRequest.source().size(size); - request.withContentOrSourceParamParserOrNull( - parser -> parseSearchRequest(searchRequest, request, parser, client.getNamedWriteableRegistry(), setSize) - ); + BytesReference content = request.content(); + byte[] contentBytes = BytesReference.toBytes(content); + SearchRequest searchRequest = new SearchRequest(contentBytes); + IntConsumer setSize = size -> searchRequest.source().size(size); + request.param("index"); return channel -> { RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel()); cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel)); @@ -147,6 +143,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC * parameter * @param setSize how the size url parameter is handled. {@code udpate_by_query} and regular search differ here. */ + // SKIP as we use proto public static void parseSearchRequest( SearchRequest searchRequest, RestRequest request, @@ -155,75 +152,77 @@ public static void parseSearchRequest( IntConsumer setSize ) throws IOException { - if (searchRequest.source() == null) { - searchRequest.source(new SearchSourceBuilder()); - } - searchRequest.indices(Strings.splitStringByCommaToArray(request.param("index"))); - if (requestContentParser != null) { - searchRequest.source().parseXContent(requestContentParser, true); - } - - final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize()); - searchRequest.setBatchedReduceSize(batchedReduceSize); - if (request.hasParam("pre_filter_shard_size")) { - searchRequest.setPreFilterShardSize(request.paramAsInt("pre_filter_shard_size", SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE)); - } - - if (request.hasParam("max_concurrent_shard_requests")) { - // only set if we have the parameter since we auto adjust the max concurrency on the coordinator - // based on the number of nodes in the cluster - final int maxConcurrentShardRequests = request.paramAsInt( - "max_concurrent_shard_requests", - searchRequest.getMaxConcurrentShardRequests() - ); - searchRequest.setMaxConcurrentShardRequests(maxConcurrentShardRequests); - } - - if (request.hasParam("allow_partial_search_results")) { - // only set if we have the parameter passed to override the cluster-level default - searchRequest.allowPartialSearchResults(request.paramAsBoolean("allow_partial_search_results", null)); - } - - if (request.hasParam("phase_took")) { - // only set if we have the parameter passed to override the cluster-level default - // else phaseTook = null - searchRequest.setPhaseTook(request.paramAsBoolean("phase_took", true)); - } - - // do not allow 'query_and_fetch' or 'dfs_query_and_fetch' search types - // from the REST layer. these modes are an internal optimization and should - // not be specified explicitly by the user. - String searchType = request.param("search_type"); - if ("query_and_fetch".equals(searchType) || "dfs_query_and_fetch".equals(searchType)) { - throw new IllegalArgumentException("Unsupported search type [" + searchType + "]"); - } else { - searchRequest.searchType(searchType); - } - parseSearchSource(searchRequest.source(), request, setSize); - searchRequest.requestCache(request.paramAsBoolean("request_cache", searchRequest.requestCache())); - - String scroll = request.param("scroll"); - if (scroll != null) { - searchRequest.scroll(new Scroll(parseTimeValue(scroll, null, "scroll"))); - } - - searchRequest.routing(request.param("routing")); - searchRequest.preference(request.param("preference")); - searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions())); - searchRequest.pipeline(request.param("search_pipeline")); - - checkRestTotalHits(request, searchRequest); - request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); - - if (searchRequest.pointInTimeBuilder() != null) { - preparePointInTime(searchRequest, request, namedWriteableRegistry); - } else { - searchRequest.setCcsMinimizeRoundtrips( - request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips()) - ); - } - - searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null)); +// if (searchRequest.source() == null) { +// searchRequest.source(new SearchSourceBuilder()); +// } +// System.out.println("========= request.param(\"index\") " + request.param("index")); +// searchRequest.indices(Strings.splitStringByCommaToArray(request.param("index"))); +// if (requestContentParser != null) { +// searchRequest.source().parseXContent(requestContentParser, true); +// } +// +// final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize()); +// searchRequest.setBatchedReduceSize(batchedReduceSize); +// if (request.hasParam("pre_filter_shard_size")) { +// searchRequest.setPreFilterShardSize(request.paramAsInt("pre_filter_shard_size", SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE)); +// } +// +// if (request.hasParam("max_concurrent_shard_requests")) { +// // only set if we have the parameter since we auto adjust the max concurrency on the coordinator +// // based on the number of nodes in the cluster +// final int maxConcurrentShardRequests = request.paramAsInt( +// "max_concurrent_shard_requests", +// searchRequest.getMaxConcurrentShardRequests() +// ); +// searchRequest.setMaxConcurrentShardRequests(maxConcurrentShardRequests); +// } +// +// if (request.hasParam("allow_partial_search_results")) { +// // only set if we have the parameter passed to override the cluster-level default +// searchRequest.allowPartialSearchResults(request.paramAsBoolean("allow_partial_search_results", null)); +// } +// +// if (request.hasParam("phase_took")) { +// // only set if we have the parameter passed to override the cluster-level default +// // else phaseTook = null +// searchRequest.setPhaseTook(request.paramAsBoolean("phase_took", true)); +// } +// +// // do not allow 'query_and_fetch' or 'dfs_query_and_fetch' search types +// // from the REST layer. these modes are an internal optimization and should +// // not be specified explicitly by the user. +// String searchType = request.param("search_type"); +// if ("query_and_fetch".equals(searchType) || "dfs_query_and_fetch".equals(searchType)) { +// throw new IllegalArgumentException("Unsupported search type [" + searchType + "]"); +// } else { +// searchRequest.searchType(searchType); +// } +// System.out.println("================ setSize: " + setSize); +// parseSearchSource(searchRequest.source(), request, setSize); +// searchRequest.requestCache(request.paramAsBoolean("request_cache", searchRequest.requestCache())); +// +// String scroll = request.param("scroll"); +// if (scroll != null) { +// searchRequest.scroll(new Scroll(parseTimeValue(scroll, null, "scroll"))); +// } +// +// searchRequest.routing(request.param("routing")); +// searchRequest.preference(request.param("preference")); +// searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions())); +// searchRequest.pipeline(request.param("search_pipeline")); +// +// checkRestTotalHits(request, searchRequest); +// request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); +// +// if (searchRequest.pointInTimeBuilder() != null) { +// preparePointInTime(searchRequest, request, namedWriteableRegistry); +// } else { +// searchRequest.setCcsMinimizeRoundtrips( +// request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips()) +// ); +// } +// +// searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null)); } /** @@ -304,6 +303,7 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil } } + // searchSource sort String sSorts = request.param("sort"); if (sSorts != null) { String[] sorts = Strings.splitStringByCommaToArray(sSorts); @@ -322,7 +322,6 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil } } } - String sStats = request.param("stats"); if (sStats != null) { searchSourceBuilder.stats(Arrays.asList(Strings.splitStringByCommaToArray(sStats))); diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index f65baa7df5cd7..3e591b3cbcfed 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -54,6 +54,14 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.DerivedField; import org.opensearch.index.mapper.DerivedFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.server.proto.BoolQueryProto; +import org.opensearch.server.proto.MatchQueryProto; +import org.opensearch.server.proto.SearchRequestProto; +import org.opensearch.server.proto.TermQueryProto; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.Rewriteable; @@ -230,6 +238,40 @@ public static HighlightBuilder highlight() { */ public SearchSourceBuilder() {} + public SearchSourceBuilder(SearchRequestProto.SearchRequest.SourceBuilder sourceBuilderProto) throws IOException { + this(); + if (sourceBuilderProto.hasFrom()) { + from = sourceBuilderProto.getFrom(); + } + if (sourceBuilderProto.hasSize()) { + size = sourceBuilderProto.getSize(); + } + if (sourceBuilderProto.hasTerminateAfter()) { + terminateAfter = sourceBuilderProto.getTerminateAfter(); + } + if (sourceBuilderProto.hasQuery() && sourceBuilderProto.getQuery().hasBool()) { + BoolQueryProto.BoolQuery boolQueryProto = sourceBuilderProto.getQuery().getBool(); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + + // Handle must clauses + for (MatchQueryProto.MatchQuery matchQueryProto : boolQueryProto.getMustList()) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(matchQueryProto.getFieldName(), matchQueryProto.getValue()); + boolQueryBuilder.must(matchQueryBuilder); + } + + // Handle filter clauses + for (TermQueryProto.TermQuery termQueryProto : boolQueryProto.getFilterList()) { + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(termQueryProto.getFieldName(), termQueryProto.getValue()); + boolQueryBuilder.filter(termQueryBuilder); + } + + queryBuilder = boolQueryBuilder; + } else { + throw new IllegalArgumentException("Expected a BoolQuery in the source builder"); + } + + } + /** * Read from a stream. */ diff --git a/server/src/main/proto/server/BoolQueryProto.proto b/server/src/main/proto/server/BoolQueryProto.proto new file mode 100644 index 0000000000000..c9fb2f9058dbf --- /dev/null +++ b/server/src/main/proto/server/BoolQueryProto.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/MatchQueryProto.proto"; +import "server/TermQueryProto.proto"; + +message BoolQuery { + repeated MatchQuery must = 1; + repeated TermQuery filter = 2; +} diff --git a/server/src/main/proto/server/MatchQueryProto.proto b/server/src/main/proto/server/MatchQueryProto.proto new file mode 100644 index 0000000000000..42071a9ee5bab --- /dev/null +++ b/server/src/main/proto/server/MatchQueryProto.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; +package org.opensearch.server.proto; + +message MatchQuery { + optional string name = 1; + optional string fieldName = 2; + // TODO: change to object + optional string value = 3; + optional string analyze = 4; + optional string minimumShouldMatch = 5; +} diff --git a/server/src/main/proto/server/MultiMatchQueryProto.proto b/server/src/main/proto/server/MultiMatchQueryProto.proto new file mode 100644 index 0000000000000..debbbe8624080 --- /dev/null +++ b/server/src/main/proto/server/MultiMatchQueryProto.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; +package org.opensearch.server.proto; + +message MultiMatchQuery { + optional string type = 1; + optional string analyzer = 2; + optional float cutoffFrequency = 3; + optional FieldAndBoost fieldAndBoost = 4; + // TODO: add more +} + +message FieldAndBoost { + optional string field = 1; + optional float boost = 2; +} diff --git a/server/src/main/proto/server/SearchRequestProto.proto b/server/src/main/proto/server/SearchRequestProto.proto new file mode 100644 index 0000000000000..134521f24476e --- /dev/null +++ b/server/src/main/proto/server/SearchRequestProto.proto @@ -0,0 +1,91 @@ +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/MatchQueryProto.proto"; +import "server/TermQueryProto.proto"; +import "server/BoolQueryProto.proto"; + +message SearchRequest { + optional string localClusterAlias = 1; + optional int64 absoluteStartMillis = 2; + optional bool finalReduce = 3; + optional string routing = 5; + optional SearchType searchType = 6; + optional string preference = 7; + optional bool requestCache = 8; + optional bool allowPartialSearchResults = 9; + optional Scroll scroll = 10; + optional int32 preFilterShardSize = 11; + optional IndicesOptions indicesOptions = 12; + optional string pipeline = 13; + optional bool phaseTook = 14; + optional int32 batchedReduceSize = 15; + optional bool ccsMinimizeRoundtrips = 16; + optional int32 maxConcurrentShardRequests = 17; + optional int32 explain = 18; + repeated string indices = 19; + optional SourceBuilder sourceBuilder = 20; + + message SourceBuilder { + optional int32 from = 1; + optional int32 size = 2; + optional int32 timeout = 3; + optional int32 terminateAfter = 4; + optional string analyzer = 5; + optional int32 trackTotalHitsUpToInt = 6; + optional bool trackTotalHitsUpToBool = 7; + optional FetchSourceContext source = 8; + optional bool explain = 9; + optional bool version = 10; + optional QueryBuilder query = 11; + // optional indicesBoost + // SUGGEST_FIELD + // HIGHLIGHT_FIELD + // SUGGEST_FIELD + // SORT_FIELD + // AGGREGATIONS_FIELD + // SLICE + // COLLAPSE + // TODO: add more... + + message QueryBuilder { + optional float boost = 1; + optional BoolQuery bool = 2; + } + + } + + message FetchSourceContext { + // _source + + } + message StoredField { + + } + + message DocValueField { + optional string field = 1; + optional string format = 2; + } +} + +enum SearchType { + QUERY_THEN_FETCH = 0; + DFS_QUERY_THEN_FETCH = 1; +} + +message Scroll { + string keepAlive = 1; +} + +message IndicesOptions { + bool ignoreUnavailable = 1; + bool allowNoIndices = 2; + bool expandWildcardsOpen = 3; + bool expandWildcardsClosed = 4; + bool expandWildcardsHidden = 5; + bool allowAliasesToMultipleIndices = 6; + bool forbidClosedIndices = 7; + bool ignoreAliases = 8; + bool ignoreThrottled = 9; +} diff --git a/server/src/main/proto/server/TermQueryProto.proto b/server/src/main/proto/server/TermQueryProto.proto new file mode 100644 index 0000000000000..b40941fc81fc6 --- /dev/null +++ b/server/src/main/proto/server/TermQueryProto.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; +package org.opensearch.server.proto; + +message TermQuery { + optional string name = 1; + optional string fieldName = 2; + optional bool caseInsensitive = 3; + // TODO: change to object + optional string value = 4; +}