diff --git a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/SelectQueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/SelectQueryGenerator.kt index d4dd397fa23..32a5c9f4046 100644 --- a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/SelectQueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/SelectQueryGenerator.kt @@ -15,6 +15,8 @@ */ package com.squareup.sqldelight.core.compiler +import com.alecstrong.sql.psi.core.psi.SqlColumnDef +import com.alecstrong.sql.psi.core.psi.SqlCreateTableStmt import com.squareup.kotlinpoet.ANY import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec @@ -32,14 +34,18 @@ import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.joinToCode +import com.squareup.sqldelight.core.compiler.SqlDelightCompiler.allocateName import com.squareup.sqldelight.core.compiler.model.NamedQuery +import com.squareup.sqldelight.core.lang.ADAPTER_NAME import com.squareup.sqldelight.core.lang.CURSOR_NAME import com.squareup.sqldelight.core.lang.CURSOR_TYPE +import com.squareup.sqldelight.core.lang.CUSTOM_DATABASE_NAME import com.squareup.sqldelight.core.lang.DRIVER_NAME import com.squareup.sqldelight.core.lang.EXECUTE_METHOD import com.squareup.sqldelight.core.lang.MAPPER_NAME import com.squareup.sqldelight.core.lang.QUERY_LIST_TYPE import com.squareup.sqldelight.core.lang.QUERY_TYPE +import com.squareup.sqldelight.core.lang.psi.ColumnTypeMixin import com.squareup.sqldelight.core.lang.util.rawSqlText class SelectQueryGenerator(private val query: NamedQuery) : QueryGenerator(query) { @@ -166,6 +172,26 @@ class SelectQueryGenerator(private val query: NamedQuery) : QueryGenerator(query val function = customResultTypeFunctionInterface() .addModifiers(OVERRIDE) + query.resultColumns.forEach { resultColumn -> + (listOf(resultColumn) + resultColumn.assumedCompatibleTypes) + .takeIf { it.size > 1 } + ?.map { assumedCompatibleType -> + (assumedCompatibleType.column?.columnType as ColumnTypeMixin?)?.let { columnTypeMixin -> + val tableAdapterName = "${(assumedCompatibleType.column!!.parent as SqlCreateTableStmt).name()}$ADAPTER_NAME" + val columnAdapterName = "${allocateName((columnTypeMixin.parent as SqlColumnDef).columnName)}$ADAPTER_NAME" + "$CUSTOM_DATABASE_NAME.$tableAdapterName.$columnAdapterName" + } + } + ?.let { adapterNames -> + function.addStatement( + """%M(%M(%L).size == 1) { "Adapter·types·are·expected·to·be·identical." }""", + MemberName("kotlin", "check"), + MemberName("kotlin.collections", "setOf"), + adapterNames.joinToString() + ) + } + } + // Assemble the actual mapper lambda: // { resultSet -> // mapper( diff --git a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/model/NamedQuery.kt b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/model/NamedQuery.kt index a1ea3302684..012592017dd 100644 --- a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/model/NamedQuery.kt +++ b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/compiler/model/NamedQuery.kt @@ -162,7 +162,11 @@ data class NamedQuery( typeOne.column != null && typeTwo.column != null ) { // Incompatible adapters. Revert to unadapted java type. - return IntermediateType(dialectType = typeOne.dialectType, name = typeOne.name).nullableIf(nullable) + return if (typeOne.javaType.copy(nullable = false) == typeTwo.javaType.copy(nullable = false)) { + typeOne.copy(assumedCompatibleTypes = typeOne.assumedCompatibleTypes + typeTwo).nullableIf(nullable) + } else { + IntermediateType(dialectType = typeOne.dialectType, name = typeOne.name).nullableIf(nullable) + } } return typeOne.nullableIf(nullable) diff --git a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/lang/IntermediateType.kt b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/lang/IntermediateType.kt index 64e926a51a5..45a554bf5b0 100644 --- a/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/lang/IntermediateType.kt +++ b/sqldelight-compiler/src/main/kotlin/com/squareup/sqldelight/core/lang/IntermediateType.kt @@ -55,7 +55,11 @@ internal data class IntermediateType( /** * Whether or not this argument is extracted from a different type */ - val extracted: Boolean = false + val extracted: Boolean = false, + /** + * The types assumed to be compatible with this type. Validated at runtime. + */ + val assumedCompatibleTypes: List = emptyList(), ) { fun asNullable() = copy(javaType = javaType.copy(nullable = true)) diff --git a/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/InterfaceGeneration.kt index b9084be7cca..f1133d05c4e 100644 --- a/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/InterfaceGeneration.kt @@ -101,7 +101,7 @@ class InterfaceGeneration { |); | |CREATE TABLE B( - | value TEXT AS kotlin.collections.List + | value TEXT AS kotlin.collections.Set |); | |unionOfBoth: @@ -167,6 +167,45 @@ class InterfaceGeneration { ) } + @Test fun `compatible adapter types from different columns merges nullability`() { + val file = FixtureCompiler.parseSql( + """ + |CREATE TABLE A( + | value TEXT AS kotlin.collections.List NOT NULL + |); + | + |CREATE TABLE B( + | value TEXT AS kotlin.collections.List NOT NULL + |); + | + |unionOfBoth: + |SELECT value, value + |FROM A + |UNION + |SELECT value, nullif(value, 1 == 1) + |FROM B; + """.trimMargin(), + temporaryFolder + ) + + val query = file.namedQueries.first() + assertThat(QueryInterfaceGenerator(query).kotlinImplementationSpec().toString()).isEqualTo( + """ + |public data class UnionOfBoth( + | public val value: kotlin.collections.List, + | public val value_: kotlin.collections.List? + |) { + | public override fun toString(): kotlin.String = ""${'"'} + | |UnionOfBoth [ + | | value: ${"$"}value + | | value_: ${"$"}value_ + | |] + | ""${'"'}.trimMargin() + |} + |""".trimMargin() + ) + } + @Test fun `null type uses the other column in a union`() { val file = FixtureCompiler.parseSql( """ diff --git a/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/SelectQueryTypeTest.kt b/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/SelectQueryTypeTest.kt index c714a5ab0b0..e000d5eb627 100644 --- a/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/SelectQueryTypeTest.kt +++ b/sqldelight-compiler/src/test/kotlin/com/squareup/sqldelight/core/queries/SelectQueryTypeTest.kt @@ -679,6 +679,61 @@ class SelectQueryTypeTest { ) } + @Test + fun `compatible java types from different columns checks for adapter equivalence`(dialect: DialectPreset) { + val file = FixtureCompiler.parseSql( + """ + |CREATE TABLE children( + | birthday ${dialect.textType} AS java.time.LocalDate NOT NULL + |); + | + |CREATE TABLE teenagers( + | birthday ${dialect.textType} AS java.time.LocalDate NOT NULL + |); + | + |CREATE TABLE adults( + | birthday ${dialect.textType} AS java.time.LocalDate + |); + | + |birthdays: + |SELECT birthday + |FROM children + |UNION + |SELECT birthday + |FROM teenagers + |UNION + |SELECT birthday + |FROM adults; + |""".trimMargin(), + tempFolder, dialectPreset = dialect + ) + + val query = file.namedQueries.first() + val generator = SelectQueryGenerator(query) + + assertThat(generator.customResultTypeFunction().toString()).isEqualTo( + """ + |public override fun birthdays(mapper: (birthday: java.time.LocalDate?) -> T): com.squareup.sqldelight.Query { + | kotlin.check(kotlin.collections.setOf(database.childrenAdapter.birthdayAdapter, database.teenagersAdapter.birthdayAdapter, database.adultsAdapter.birthdayAdapter).size == 1) { "Adapter types are expected to be identical." } + | return com.squareup.sqldelight.Query(${query.id}, birthdays, driver, "Test.sq", "birthdays", ""${'"'} + | |SELECT birthday + | |FROM children + | |UNION + | |SELECT birthday + | |FROM teenagers + | |UNION + | |SELECT birthday + | |FROM adults + | ""${'"'}.trimMargin()) { cursor -> + | mapper( + | cursor.getString(0)?.let { database.childrenAdapter.birthdayAdapter.decode(it) } + | ) + | } + |} + |""".trimMargin() + ) + } + @Test fun `proper exposure of month and year functions`(dialect: DialectPreset) { assumeTrue(dialect in listOf(DialectPreset.MYSQL))