Skip to content

Commit

Permalink
add proto for boolean, match, searchrequest, term, multimatch
Browse files Browse the repository at this point in the history
  • Loading branch information
zshuyi committed Aug 1, 2024
1 parent 5b93f2e commit f64ab14
Show file tree
Hide file tree
Showing 11 changed files with 305 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.search.builder.PointInTimeBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.server.proto.SearchRequestProto;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -274,6 +275,18 @@ public SearchRequest(StreamInput in) throws IOException {
}
}

/**
 * Deserializes a search request from its protobuf wire representation.
 *
 * @param in serialized {@code SearchRequestProto.SearchRequest} bytes
 * @throws IOException if the bytes are not a valid protobuf message
 */
public SearchRequest(byte[] in) throws IOException {
    this();
    SearchRequestProto.SearchRequest searchRequestProto = SearchRequestProto.SearchRequest.parseFrom(in);
    indices = searchRequestProto.getIndicesList().toArray(new String[0]);
    routing = searchRequestProto.getRouting();
    preference = searchRequestProto.getPreference();
    // Honor the batched reduce size carried on the wire; fall back to the
    // default when the field is unset (proto3 int fields default to 0).
    // Previously the proto value was read and then unconditionally
    // overwritten with DEFAULT_BATCHED_REDUCE_SIZE, discarding it.
    int protoBatchedReduceSize = searchRequestProto.getBatchedReduceSize();
    batchedReduceSize = protoBatchedReduceSize > 0 ? protoBatchedReduceSize : DEFAULT_BATCHED_REDUCE_SIZE;
    source = new SearchSourceBuilder(searchRequestProto.getSourceBuilder());
    // The proto path currently supports only the default search type.
    searchType = SearchType.QUERY_THEN_FETCH;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ protected AbstractQueryBuilder(StreamInput in) throws IOException {
queryName = in.readOptionalString();
}

/**
 * Constructs a query builder directly from its components, bypassing stream
 * deserialization (used by the protobuf-based request path).
 *
 * @param queryName optional query name, may be null
 * @param boost boost factor; must not be negative
 * @throws IOException declared for parity with the stream-based constructor
 */
protected AbstractQueryBuilder(String queryName, float boost) throws IOException {
    // Validate before mutating state so an invalid boost never leaves a
    // partially initialized builder behind.
    checkNegativeBoost(boost);
    this.boost = boost;
    this.queryName = queryName;
}

@Override
public final void writeTo(StreamOutput out) throws IOException {
out.writeFloat(boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.index.query.support.QueryParsers;
import org.opensearch.index.search.MatchQuery;
import org.opensearch.index.search.MatchQuery.ZeroTermsQuery;
import org.opensearch.server.proto.MatchQueryProto;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -129,6 +130,14 @@ public MatchQueryBuilder(String fieldName, Object value) {
this.value = value;
}

/**
 * Builds a match query from its protobuf representation.
 *
 * @param queryProto protobuf message carrying the field name and value
 * @throws IOException declared for parity with the stream-based constructor
 */
public MatchQueryBuilder(MatchQueryProto.MatchQuery queryProto) throws IOException {
    // Pass the final values straight to the parent instead of seeding
    // placeholders ("" / DEFAULT_BOOST) and overwriting them afterwards.
    // NOTE(review): "match" here looks like the query *type*, not a
    // user-supplied _name for named-query matching — confirm this is intended.
    super("match", DEFAULT_BOOST);
    fieldName = queryProto.getFieldName();
    value = queryProto.getValue();
}

/**
* Read from a stream.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.ConstantFieldType;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.server.proto.MatchQueryProto;
import org.opensearch.server.proto.TermQueryProto;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -117,6 +119,16 @@ public TermQueryBuilder(StreamInput in) throws IOException {
}
}


/**
 * Builds a term query from its protobuf representation.
 *
 * @param queryProto protobuf message carrying the field name, value, and an
 *     optional case-insensitivity flag
 * @throws IOException declared for parity with the stream-based constructor
 */
public TermQueryBuilder(TermQueryProto.TermQuery queryProto) throws IOException {
    super(queryProto.getFieldName(), queryProto.getValue());
    // Explicitly pin the default boost; presumably redundant if the parent
    // already initializes boost to DEFAULT_BOOST — harmless either way.
    boost(DEFAULT_BOOST);
    // proto3 optional: only apply case sensitivity when the field was set,
    // so an absent flag keeps the builder's default behavior.
    if (queryProto.hasCaseInsensitive()) {
        caseInsensitive(queryProto.getCaseInsensitive());
    }
}


@Override
protected void doWriteTo(StreamOutput out) throws IOException {
super.doWriteTo(out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
*/

package org.opensearch.rest.action.search;

import java.util.Map;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchContextId;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.Booleans;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
Expand All @@ -49,7 +50,6 @@
import org.opensearch.rest.action.RestActions;
import org.opensearch.rest.action.RestCancellableNodeClient;
import org.opensearch.rest.action.RestStatusToXContentListener;
import org.opensearch.search.Scroll;
import org.opensearch.search.SearchService;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.StoredFieldsContext;
Expand All @@ -60,17 +60,12 @@
import org.opensearch.search.suggest.term.TermSuggestionBuilder.SuggestMode;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.function.IntConsumer;

import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;
import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.common.unit.TimeValue.parseTimeValue;
import static org.opensearch.rest.RestRequest.Method.GET;
import static org.opensearch.rest.RestRequest.Method.POST;
import static org.opensearch.search.suggest.SuggestBuilders.termSuggestion;
Expand Down Expand Up @@ -106,17 +101,17 @@ public String getName() {
public List<Route> routes() {
return unmodifiableList(
asList(
new Route(GET, "/_search"),
new Route(POST, "/_search"),
new Route(GET, "/{index}/_search"),
new Route(POST, "/{index}/_search")
new Route(GET, "/_search_proto"),
new Route(POST, "/_search_proto"),
new Route(GET, "/{index}/_search_proto"),
new Route(POST, "/{index}/_search_proto")
)
);
}

@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
SearchRequest searchRequest = new SearchRequest();
// SearchRequest searchRequest = new SearchRequest();
/*
* We have to pull out the call to `source().size(size)` because
* _update_by_query and _delete_by_query uses this same parsing
Expand All @@ -129,11 +124,12 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
* be null later. If that is confusing to you then you are in good
* company.
*/
IntConsumer setSize = size -> searchRequest.source().size(size);
request.withContentOrSourceParamParserOrNull(
parser -> parseSearchRequest(searchRequest, request, parser, client.getNamedWriteableRegistry(), setSize)
);
BytesReference content = request.content();
byte[] contentBytes = BytesReference.toBytes(content);
SearchRequest searchRequest = new SearchRequest(contentBytes);

IntConsumer setSize = size -> searchRequest.source().size(size);
request.param("index");
return channel -> {
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
Expand All @@ -147,6 +143,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
* parameter
* @param setSize how the size url parameter is handled. {@code udpate_by_query} and regular search differ here.
*/
// SKIP as we use proto
public static void parseSearchRequest(
SearchRequest searchRequest,
RestRequest request,
Expand All @@ -155,75 +152,77 @@ public static void parseSearchRequest(
IntConsumer setSize
) throws IOException {

if (searchRequest.source() == null) {
searchRequest.source(new SearchSourceBuilder());
}
searchRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
if (requestContentParser != null) {
searchRequest.source().parseXContent(requestContentParser, true);
}

final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize());
searchRequest.setBatchedReduceSize(batchedReduceSize);
if (request.hasParam("pre_filter_shard_size")) {
searchRequest.setPreFilterShardSize(request.paramAsInt("pre_filter_shard_size", SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE));
}

if (request.hasParam("max_concurrent_shard_requests")) {
// only set if we have the parameter since we auto adjust the max concurrency on the coordinator
// based on the number of nodes in the cluster
final int maxConcurrentShardRequests = request.paramAsInt(
"max_concurrent_shard_requests",
searchRequest.getMaxConcurrentShardRequests()
);
searchRequest.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
}

if (request.hasParam("allow_partial_search_results")) {
// only set if we have the parameter passed to override the cluster-level default
searchRequest.allowPartialSearchResults(request.paramAsBoolean("allow_partial_search_results", null));
}

if (request.hasParam("phase_took")) {
// only set if we have the parameter passed to override the cluster-level default
// else phaseTook = null
searchRequest.setPhaseTook(request.paramAsBoolean("phase_took", true));
}

// do not allow 'query_and_fetch' or 'dfs_query_and_fetch' search types
// from the REST layer. these modes are an internal optimization and should
// not be specified explicitly by the user.
String searchType = request.param("search_type");
if ("query_and_fetch".equals(searchType) || "dfs_query_and_fetch".equals(searchType)) {
throw new IllegalArgumentException("Unsupported search type [" + searchType + "]");
} else {
searchRequest.searchType(searchType);
}
parseSearchSource(searchRequest.source(), request, setSize);
searchRequest.requestCache(request.paramAsBoolean("request_cache", searchRequest.requestCache()));

String scroll = request.param("scroll");
if (scroll != null) {
searchRequest.scroll(new Scroll(parseTimeValue(scroll, null, "scroll")));
}

searchRequest.routing(request.param("routing"));
searchRequest.preference(request.param("preference"));
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
searchRequest.pipeline(request.param("search_pipeline"));

checkRestTotalHits(request, searchRequest);
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);

if (searchRequest.pointInTimeBuilder() != null) {
preparePointInTime(searchRequest, request, namedWriteableRegistry);
} else {
searchRequest.setCcsMinimizeRoundtrips(
request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips())
);
}

searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));
// if (searchRequest.source() == null) {
// searchRequest.source(new SearchSourceBuilder());
// }
// System.out.println("========= request.param(\"index\") " + request.param("index"));
// searchRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
// if (requestContentParser != null) {
// searchRequest.source().parseXContent(requestContentParser, true);
// }
//
// final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize());
// searchRequest.setBatchedReduceSize(batchedReduceSize);
// if (request.hasParam("pre_filter_shard_size")) {
// searchRequest.setPreFilterShardSize(request.paramAsInt("pre_filter_shard_size", SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE));
// }
//
// if (request.hasParam("max_concurrent_shard_requests")) {
// // only set if we have the parameter since we auto adjust the max concurrency on the coordinator
// // based on the number of nodes in the cluster
// final int maxConcurrentShardRequests = request.paramAsInt(
// "max_concurrent_shard_requests",
// searchRequest.getMaxConcurrentShardRequests()
// );
// searchRequest.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
// }
//
// if (request.hasParam("allow_partial_search_results")) {
// // only set if we have the parameter passed to override the cluster-level default
// searchRequest.allowPartialSearchResults(request.paramAsBoolean("allow_partial_search_results", null));
// }
//
// if (request.hasParam("phase_took")) {
// // only set if we have the parameter passed to override the cluster-level default
// // else phaseTook = null
// searchRequest.setPhaseTook(request.paramAsBoolean("phase_took", true));
// }
//
// // do not allow 'query_and_fetch' or 'dfs_query_and_fetch' search types
// // from the REST layer. these modes are an internal optimization and should
// // not be specified explicitly by the user.
// String searchType = request.param("search_type");
// if ("query_and_fetch".equals(searchType) || "dfs_query_and_fetch".equals(searchType)) {
// throw new IllegalArgumentException("Unsupported search type [" + searchType + "]");
// } else {
// searchRequest.searchType(searchType);
// }
// System.out.println("================ setSize: " + setSize);
// parseSearchSource(searchRequest.source(), request, setSize);
// searchRequest.requestCache(request.paramAsBoolean("request_cache", searchRequest.requestCache()));
//
// String scroll = request.param("scroll");
// if (scroll != null) {
// searchRequest.scroll(new Scroll(parseTimeValue(scroll, null, "scroll")));
// }
//
// searchRequest.routing(request.param("routing"));
// searchRequest.preference(request.param("preference"));
// searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
// searchRequest.pipeline(request.param("search_pipeline"));
//
// checkRestTotalHits(request, searchRequest);
// request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);
//
// if (searchRequest.pointInTimeBuilder() != null) {
// preparePointInTime(searchRequest, request, namedWriteableRegistry);
// } else {
// searchRequest.setCcsMinimizeRoundtrips(
// request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips())
// );
// }
//
// searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));
}

/**
Expand Down Expand Up @@ -304,6 +303,7 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil
}
}

// searchSource sort
String sSorts = request.param("sort");
if (sSorts != null) {
String[] sorts = Strings.splitStringByCommaToArray(sSorts);
Expand All @@ -322,7 +322,6 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil
}
}
}

String sStats = request.param("stats");
if (sStats != null) {
searchSourceBuilder.stats(Arrays.asList(Strings.splitStringByCommaToArray(sStats)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.DerivedField;
import org.opensearch.index.mapper.DerivedFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.server.proto.BoolQueryProto;
import org.opensearch.server.proto.MatchQueryProto;
import org.opensearch.server.proto.SearchRequestProto;
import org.opensearch.server.proto.TermQueryProto;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.Rewriteable;
Expand Down Expand Up @@ -230,6 +238,40 @@ public static HighlightBuilder highlight() {
*/
public SearchSourceBuilder() {}

/**
 * Builds a search source from its protobuf representation.
 *
 * <p>Supports {@code from}, {@code size}, {@code terminate_after}, and a
 * top-level bool query whose {@code must} clauses are match queries and whose
 * {@code filter} clauses are term queries. A source without a query is
 * accepted (previously it was rejected, which broke valid from/size-only
 * requests); a query of any other shape is rejected.
 *
 * @param sourceBuilderProto protobuf source section of the search request
 * @throws IllegalArgumentException if a query is present but is not a bool query
 * @throws IOException declared for parity with the stream-based constructor
 */
public SearchSourceBuilder(SearchRequestProto.SearchRequest.SourceBuilder sourceBuilderProto) throws IOException {
    this();
    // proto3 optionals: only override defaults for fields actually set.
    if (sourceBuilderProto.hasFrom()) {
        from = sourceBuilderProto.getFrom();
    }
    if (sourceBuilderProto.hasSize()) {
        size = sourceBuilderProto.getSize();
    }
    if (sourceBuilderProto.hasTerminateAfter()) {
        terminateAfter = sourceBuilderProto.getTerminateAfter();
    }
    if (sourceBuilderProto.hasQuery()) {
        if (sourceBuilderProto.getQuery().hasBool() == false) {
            // Only bool queries are modeled on the proto path so far.
            throw new IllegalArgumentException("Expected a BoolQuery in the source builder");
        }
        BoolQueryProto.BoolQuery boolQueryProto = sourceBuilderProto.getQuery().getBool();
        BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

        // must clauses map to match queries
        for (MatchQueryProto.MatchQuery matchQueryProto : boolQueryProto.getMustList()) {
            MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(matchQueryProto.getFieldName(), matchQueryProto.getValue());
            boolQueryBuilder.must(matchQueryBuilder);
        }

        // filter clauses map to term queries
        for (TermQueryProto.TermQuery termQueryProto : boolQueryProto.getFilterList()) {
            TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(termQueryProto.getFieldName(), termQueryProto.getValue());
            boolQueryBuilder.filter(termQueryBuilder);
        }

        queryBuilder = boolQueryBuilder;
    }
    // No query set: leave queryBuilder null, matching the default
    // SearchSourceBuilder() behavior (match-all semantics downstream).
}

/**
* Read from a stream.
*/
Expand Down
Loading

0 comments on commit f64ab14

Please sign in to comment.