Skip to content

Commit

Permalink
add batch processing
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <ihailong@amazon.com>
  • Loading branch information
Hailong-am committed Jan 19, 2024
1 parent d3f3dfb commit 4d41090
Showing 1 changed file with 31 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.common.collect.Lists;
import com.google.common.hash.Hashing;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
Expand Down Expand Up @@ -125,23 +126,26 @@ public void run() {

for (String modelId : embeddingModelIds) {
try {
List<String> embeddingDocs = indexMetaAndSamples
.stream()
.map(sample -> (String) sample.get(INDEX_SUMMARY))
.collect(Collectors.toList());

List<ModelTensors> mlModelOutputs = mlClients.getEmbeddingResult(modelId, embeddingDocs, true, mlTaskResponse -> {
ModelTensorOutput output = (ModelTensorOutput) mlTaskResponse.getOutput();
return output.getMlModelOutputs();
});

for (int i = 0; i < mlModelOutputs.size(); i++) {
Number[] vector = mlModelOutputs.get(i).getMlModelTensors().get(0).getData();
indexMetaAndSamples.get(i).put(INDEX_EMBEDDING, vector);
List<List<Map<String, Object>>> partitions = Lists.partition(indexMetaAndSamples, 1000);

Check warning on line 129 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L129

Added line #L129 was not covered by tests
for (List<Map<String, Object>> partition : partitions) {
List<String> embeddingDocs = partition
.stream()
.map(sample -> (String) sample.get(INDEX_SUMMARY))
.collect(Collectors.toList());

Check warning on line 134 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L131-L134

Added lines #L131 - L134 were not covered by tests

List<ModelTensors> mlModelOutputs = mlClients.getEmbeddingResult(modelId, embeddingDocs, true, mlTaskResponse -> {
ModelTensorOutput output = (ModelTensorOutput) mlTaskResponse.getOutput();
return output.getMlModelOutputs();

Check warning on line 138 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L136-L138

Added lines #L136 - L138 were not covered by tests
});

for (int i = 0; i < mlModelOutputs.size(); i++) {
Number[] vector = mlModelOutputs.get(i).getMlModelTensors().get(0).getData();
partition.get(i).put(INDEX_EMBEDDING, vector);

Check warning on line 143 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L142-L143

Added lines #L142 - L143 were not covered by tests
}

// write to k-NN index
indexSummaryVector(INDEX_SUMMARY_EMBEDDING_INDEX, partition, modelId);
}

// write to k-NN index
indexSummaryVector(INDEX_SUMMARY_EMBEDDING_INDEX, indexMetaAndSamples, modelId);
} catch (Exception e) {
log.error("Failed to embedding index summary for model {}", modelId);
}
Expand Down Expand Up @@ -179,6 +183,7 @@ private List<Map<String, Object>> getAllIndexMappingAndSampleData() {
for (ComposableIndexTemplate composableIndexTemplate : metadata.templatesV2().values()) {
patterns.addAll(composableIndexTemplate.indexPatterns());
}

Check warning on line 185 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L184-L185

Added lines #L184 - L185 were not covered by tests
// TODO leverage index-pattern in OSD

List<Map<String, Object>> indexSummaryList = new ArrayList<>();

Check warning on line 188 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L188

Added line #L188 was not covered by tests

Expand Down Expand Up @@ -227,6 +232,7 @@ private List<Map<String, Object>> getAllIndexMappingAndSampleData() {

for (SearchHit hit : searchResponse.getHits()) {
String docContent = Strings.toString(MediaTypeRegistry.JSON, hit);

Check warning on line 234 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L234

Added line #L234 was not covered by tests
// TODO Remove long content field and knn field
sampleDataList.add(docContent);
}

Check warning on line 237 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L236-L237

Added lines #L236 - L237 were not covered by tests

Expand All @@ -251,6 +257,9 @@ private List<Map<String, Object>> getAllIndexMappingAndSampleData() {
String indexSummary = String
.format(Locale.ROOT, "Index Mappings:%s\\nSample data:\\n%s", mapping, sampleDataList.subList(0, nSample));
map.put(INDEX_SUMMARY, indexSummary);

Check warning on line 259 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L255-L259

Added lines #L255 - L259 were not covered by tests
// remove keys are not used
map.remove(MAPPING);
map.remove(SAMPLE_DATA);
}

Check warning on line 263 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L261-L263

Added lines #L261 - L263 were not covered by tests

return indexSummaryList;

Check warning on line 265 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L265

Added line #L265 was not covered by tests
Expand Down Expand Up @@ -324,9 +333,7 @@ private void bulkWrite(String writeIndex, List<Map<String, Object>> docs, String
docMap.put(INDEX_SUMMARY_FIELD, doc.get(INDEX_SUMMARY));
docMap.put(INDEX_SUMMARY_EMBEDDING_FIELD_PREFIX + "_" + modelId, embedding);

Check warning on line 334 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L327-L334

Added lines #L327 - L334 were not covered by tests

String docId = Hashing.sha256().hashString(indexName, StandardCharsets.UTF_8).toString();

bulkRequest.add(new UpdateRequest(writeIndex, docId).doc(docMap, MediaTypeRegistry.JSON).docAsUpsert(true));
bulkRequest.add(new UpdateRequest(writeIndex, generateDocId(indexName)).doc(docMap, MediaTypeRegistry.JSON).docAsUpsert(true));
}
client.bulk(bulkRequest, ActionListener.wrap(r -> {

Check warning on line 338 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L336-L338

Added lines #L336 - L338 were not covered by tests
if (r.hasFailures()) {
Expand All @@ -340,8 +347,7 @@ private void bulkWrite(String writeIndex, List<Map<String, Object>> docs, String
public void bulkDelete(String writeIndex, List<String> indexNames) {
BulkRequest bulkRequest = Requests.bulkRequest();

Check warning on line 348 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L348

Added line #L348 was not covered by tests
for (String indexName : indexNames) {
String docId = String.valueOf(indexName.hashCode());
bulkRequest.add(new DeleteRequest(writeIndex, docId));
bulkRequest.add(new DeleteRequest(writeIndex, generateDocId(indexName)));
}
client.bulk(bulkRequest, ActionListener.wrap(r -> {

Check warning on line 352 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L350-L352

Added lines #L350 - L352 were not covered by tests
if (r.hasFailures()) {
Expand All @@ -351,4 +357,8 @@ public void bulkDelete(String writeIndex, List<String> indexNames) {
}
}, exception -> log.error("Bulk delete index summary embedding failed", exception)));
}

Check warning on line 359 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L358-L359

Added lines #L358 - L359 were not covered by tests

private String generateDocId(String indexName) {
return Hashing.sha256().hashString(indexName, StandardCharsets.UTF_8).toString();

Check warning on line 362 in src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/agent/job/IndexSummaryEmbeddingJob.java#L362

Added line #L362 was not covered by tests
}
}

0 comments on commit 4d41090

Please sign in to comment.