Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] bug fixes for correlation Alerts #680

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package org.opensearch.commons.alerting.model
import org.opensearch.common.lucene.uid.Versions
import org.opensearch.commons.alerting.util.IndexUtils.Companion.NO_SCHEMA_VERSION
import org.opensearch.commons.alerting.util.instant
import org.opensearch.commons.alerting.util.optionalTimeField
import org.opensearch.commons.alerting.util.optionalUserField
import org.opensearch.commons.authuser.User
import org.opensearch.core.common.io.stream.StreamInput
Expand Down Expand Up @@ -85,17 +84,17 @@ open class BaseAlert(

companion object {
const val ALERT_ID_FIELD = "id"
const val SCHEMA_VERSION_FIELD = "schemaVersion"
const val SCHEMA_VERSION_FIELD = "schema_version"
const val ALERT_VERSION_FIELD = "version"
const val USER_FIELD = "user"
const val TRIGGER_NAME_FIELD = "triggerName"
const val TRIGGER_NAME_FIELD = "trigger_name"
const val STATE_FIELD = "state"
const val START_TIME_FIELD = "startTime"
const val END_TIME_FIELD = "endTime"
const val ACKNOWLEDGED_TIME_FIELD = "acknowledgedTime"
const val ERROR_MESSAGE_FIELD = "errorMessage"
const val START_TIME_FIELD = "start_time"
const val END_TIME_FIELD = "end_time"
const val ACKNOWLEDGED_TIME_FIELD = "acknowledged_time"
const val ERROR_MESSAGE_FIELD = "error_message"
const val SEVERITY_FIELD = "severity"
const val ACTION_EXECUTION_RESULTS_FIELD = "actionExecutionResults"
const val ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results"
const val NO_ID = ""
const val NO_VERSION = Versions.NOT_FOUND

Expand Down Expand Up @@ -138,7 +137,7 @@ open class BaseAlert(
}
}
START_TIME_FIELD -> startTime = requireNotNull(xcp.instant())
END_TIME_FIELD -> endTime = xcp.instant()
END_TIME_FIELD -> endTime = requireNotNull(xcp.instant())
ACKNOWLEDGED_TIME_FIELD -> acknowledgedTime = xcp.instant()
}
}
Expand Down Expand Up @@ -178,17 +177,18 @@ open class BaseAlert(
if (!secure) {
builder.optionalUserField(USER_FIELD, user)
}
builder.field(ALERT_ID_FIELD, id)
builder
.field(ALERT_ID_FIELD, id)
.field(ALERT_VERSION_FIELD, version)
.field(SCHEMA_VERSION_FIELD, schemaVersion)
.field(TRIGGER_NAME_FIELD, triggerName)
.field(STATE_FIELD, state)
.field(ERROR_MESSAGE_FIELD, errorMessage)
.field(SEVERITY_FIELD, severity)
.field(ACTION_EXECUTION_RESULTS_FIELD, actionExecutionResults.toTypedArray())
.optionalTimeField(START_TIME_FIELD, startTime)
.optionalTimeField(END_TIME_FIELD, endTime)
.optionalTimeField(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime)
.field(START_TIME_FIELD, startTime)
.field(END_TIME_FIELD, endTime)
.field(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime)
return builder
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.opensearch.commons.alerting.model
import org.opensearch.commons.authuser.User
import org.opensearch.core.common.io.stream.StreamInput
import org.opensearch.core.common.io.stream.StreamOutput
import org.opensearch.core.xcontent.ToXContent
import org.opensearch.core.xcontent.XContentBuilder
import org.opensearch.core.xcontent.XContentParser
import org.opensearch.core.xcontent.XContentParserUtils
Expand Down Expand Up @@ -59,7 +60,7 @@ class CorrelationAlert : BaseAlert {
}

// Override to include CorrelationAlert specific fields
fun toXContent(builder: XContentBuilder): XContentBuilder {
override fun toXContent(builder: XContentBuilder, params: ToXContent.Params): XContentBuilder {
builder.startObject()
.startArray(CORRELATED_FINDING_IDS)
correlatedFindingIds.forEach { id ->
Expand Down Expand Up @@ -90,9 +91,9 @@ class CorrelationAlert : BaseAlert {
return superTemplateArgs + correlationSpecificArgs
}
companion object {
const val CORRELATED_FINDING_IDS = "correlatedFindingIds"
const val CORRELATION_RULE_ID = "correlationRuleId"
const val CORRELATION_RULE_NAME = "correlationRuleName"
const val CORRELATED_FINDING_IDS = "correlated_finding_ids"
const val CORRELATION_RULE_ID = "correlation_rule_id"
const val CORRELATION_RULE_NAME = "correlation_rule_name"

@JvmStatic
@Throws(IOException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@ package org.opensearch.commons.alerting
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.opensearch.common.xcontent.LoggingDeprecationHandler
import org.opensearch.common.xcontent.XContentHelper
import org.opensearch.commons.alerting.model.Alert
import org.opensearch.commons.alerting.model.CorrelationAlert
import org.opensearch.commons.utils.getJsonString
import org.opensearch.commons.utils.recreateObject
import org.opensearch.core.common.bytes.BytesArray
import org.opensearch.core.common.bytes.BytesReference
import org.opensearch.core.common.io.stream.InputStreamStreamInput
import org.opensearch.core.xcontent.NamedXContentRegistry
import java.time.temporal.ChronoUnit

class CorrelationAlertTests {
Expand All @@ -26,17 +19,17 @@ class CorrelationAlertTests {
val templateArgs = createCorrelationAlertTemplateArgs(correlationAlert)

assertEquals(
templateArgs["correlatedFindingIds"],
templateArgs["correlated_finding_ids"],
correlationAlert.correlatedFindingIds,
"Template args correlatedFindingIds does not match"
)
assertEquals(
templateArgs["correlationRuleId"],
templateArgs["correlation_rule_id"],
correlationAlert.correlationRuleId,
"Template args correlationRuleId does not match"
)
assertEquals(
templateArgs["correlationRuleName"],
templateArgs["correlation_rule_name"],
correlationAlert.correlationRuleName,
"Template args correlationRuleName does not match"
)
Expand All @@ -46,26 +39,26 @@ class CorrelationAlertTests {
assertEquals(templateArgs["version"], correlationAlert.version, "Template args version does not match")
assertEquals(templateArgs["user"], correlationAlert.user, "Template args user does not match")
assertEquals(
templateArgs["triggerName"],
templateArgs["trigger_name"],
correlationAlert.triggerName,
"Template args triggerName does not match"
)
assertEquals(templateArgs["state"], correlationAlert.state, "Template args state does not match")
assertEquals(templateArgs["startTime"], correlationAlert.startTime, "Template args startTime does not match")
assertEquals(templateArgs["endTime"], correlationAlert.endTime, "Template args endTime does not match")
assertEquals(templateArgs["start_time"], correlationAlert.startTime, "Template args startTime does not match")
assertEquals(templateArgs["end_time"], correlationAlert.endTime, "Template args endTime does not match")
assertEquals(
templateArgs["acknowledgedTime"],
templateArgs["acknowledged_time"],
correlationAlert.acknowledgedTime,
"Template args acknowledgedTime does not match"
)
assertEquals(
templateArgs["errorMessage"],
templateArgs["error_message"],
correlationAlert.errorMessage,
"Template args errorMessage does not match"
)
assertEquals(templateArgs["severity"], correlationAlert.severity, "Template args severity does not match")
assertEquals(
templateArgs["actionExecutionResults"],
templateArgs["action_execution_results"],
correlationAlert.actionExecutionResults,
"Template args actionExecutionResults does not match"
)
Expand All @@ -80,37 +73,6 @@ class CorrelationAlertTests {
Assertions.assertFalse(activeCorrelationAlert.isAcknowledged(), "Alert is acknowledged")
}

@Test
fun `test correlation parse function`() {
// Generate a random CorrelationAlert object
val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE)
val correlationAlertString = getJsonString(correlationAlert)

// Convert the JSON string to a BytesReference
val serializedBytes: BytesReference = BytesArray(correlationAlertString.toByteArray(Charsets.UTF_8))

// Deserialize the BytesReference into a CorrelationAlert object using the parse function
val recreatedAlert: CorrelationAlert = InputStreamStreamInput(serializedBytes.streamInput()).use { streamInput ->
XContentHelper.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, serializedBytes).use { parser ->
parser.nextToken() // Move to the start of the content
CorrelationAlert.parse(parser)
}
}

// Assert that the deserialized object matches the original object
assertEquals(correlationAlert.correlatedFindingIds, recreatedAlert.correlatedFindingIds)
assertEquals(correlationAlert.correlationRuleId, recreatedAlert.correlationRuleId)
assertEquals(correlationAlert.correlationRuleName, recreatedAlert.correlationRuleName)
assertEquals(correlationAlert.triggerName, recreatedAlert.triggerName)
assertEquals(correlationAlert.state, recreatedAlert.state)
val expectedStartTime = correlationAlert.startTime.truncatedTo(ChronoUnit.MILLIS)
val actualStartTime = recreatedAlert.startTime.truncatedTo(ChronoUnit.MILLIS)
assertEquals(expectedStartTime, actualStartTime)
assertEquals(correlationAlert.severity, recreatedAlert.severity)
assertEquals(correlationAlert.id, recreatedAlert.id)
assertEquals(correlationAlert.actionExecutionResults, recreatedAlert.actionExecutionResults)
}

@Test
fun `Feature Correlation Alert serialize and deserialize should be equal`() {
val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ fun createUnifiedAlertTemplateArgs(unifiedAlert: BaseAlert): Map<String, Any?> {
fun createCorrelationAlertTemplateArgs(correlationAlert: CorrelationAlert): Map<String, Any?> {
val unifiedAlertTemplateArgs = createUnifiedAlertTemplateArgs(correlationAlert)
return unifiedAlertTemplateArgs + mapOf(
"correlatedFindingIds" to correlationAlert.correlatedFindingIds,
"correlationRuleId" to correlationAlert.correlationRuleId,
"correlationRuleName" to correlationAlert.correlationRuleName
CorrelationAlert.CORRELATED_FINDING_IDS to correlationAlert.correlatedFindingIds,
CorrelationAlert.CORRELATION_RULE_ID to correlationAlert.correlationRuleId,
CorrelationAlert.CORRELATION_RULE_NAME to correlationAlert.correlationRuleName
)
}
7 changes: 1 addition & 6 deletions src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.opensearch.commons.utils

import org.opensearch.common.xcontent.XContentFactory
import org.opensearch.common.xcontent.XContentType
import org.opensearch.commons.alerting.model.CorrelationAlert
import org.opensearch.core.xcontent.DeprecationHandler
import org.opensearch.core.xcontent.NamedXContentRegistry
import org.opensearch.core.xcontent.ToXContent
Expand All @@ -17,11 +16,7 @@ import java.io.ByteArrayOutputStream
fun getJsonString(xContent: ToXContent): String {
ByteArrayOutputStream().use { byteArrayOutputStream ->
val builder = XContentFactory.jsonBuilder(byteArrayOutputStream)
if (xContent is CorrelationAlert) {
xContent.toXContent(builder)
} else {
xContent.toXContent(builder, ToXContent.EMPTY_PARAMS)
}
xContent.toXContent(builder, ToXContent.EMPTY_PARAMS)
builder.close()
return byteArrayOutputStream.toString("UTF8")
}
Expand Down
Loading