From 521e580b3f00ce498bd2c7c4797d0807b5c2cf4a Mon Sep 17 00:00:00 2001 From: Alec Strong Date: Sat, 16 Apr 2022 10:30:02 -0400 Subject: [PATCH] Common tables do not generate data classes so dont return them --- .../core/compiler/model/SelectQueryable.kt | 4 +- .../core/queries/InterfaceGeneration.kt | 115 ++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt index cd029c4a583..ff4690fd965 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt @@ -6,6 +6,7 @@ import com.alecstrong.sql.psi.core.psi.QueryElement.QueryColumn import com.alecstrong.sql.psi.core.psi.Queryable import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt +import com.alecstrong.sql.psi.core.psi.SqlCteTableName import com.intellij.psi.util.PsiTreeUtil class SelectQueryable( @@ -44,7 +45,8 @@ class SelectQueryable( } return@lazy select.tablesAvailable(select).firstOrNull { - it.query.columns.flattenCompounded() == pureColumns + (it.tableName.parent !is SqlCteTableName) && + it.query.columns.flattenCompounded() == pureColumns }?.tableName } } diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt index c9d888a88bc..d820036fc71 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt @@ -915,6 +915,121 @@ class InterfaceGeneration { ) } + @Test fun `value types correctly generated`() { + val result = FixtureCompiler.compileSql( + """ + |CREATE TABLE item ( + | id INTEGER PRIMARY KEY AUTOINCREMENT, + | parent_id INTEGER, + | children INTEGER NOT NULL + |); + | + |recursiveQuery: + |WITH RECURSIVE + |descendants AS ( + | SELECT id, parent_id + | FROM item + | WHERE item.id = :id + | UNION ALL + | SELECT item.id, item.parent_id + | FROM item, descendants + | WHERE item.id = descendants.parent_id + |) + |SELECT descendants.id, descendants.parent_id + |FROM descendants; + |""".trimMargin(), + temporaryFolder, fileName = "Recursive.sq" + ) + + val query = result.compiledFile.namedQueries[0] + + assertThat(result.errors).isEmpty() + val generatedInterface = result.compilerOutput.get(File(result.outputDirectory, "com/example/RecursiveQuery.kt")) + assertThat(generatedInterface).isNotNull() + assertThat(generatedInterface.toString()).isEqualTo( + """ + |package com.example + | + |import kotlin.Long + | + |public data class RecursiveQuery( + | public val id: Long, + | public val parent_id: Long?, + |) + |""".trimMargin() + ) + + val generatedQueries = result.compilerOutput.get(File(result.outputDirectory, "com/example/RecursiveQueries.kt")) + assertThat(generatedQueries).isNotNull() + assertThat(generatedQueries.toString()).isEqualTo( + """ + |package com.example + | + |import app.cash.sqldelight.Query + |import app.cash.sqldelight.TransacterImpl + |import app.cash.sqldelight.db.SqlCursor + |import app.cash.sqldelight.db.SqlDriver + |import kotlin.Any + |import kotlin.Long + |import kotlin.String + |import kotlin.Unit + | + |public class RecursiveQueries( + | private val driver: SqlDriver, + |) : TransacterImpl(driver) { + | public fun recursiveQuery(id: Long, mapper: (id: Long, parent_id: Long?) -> T): Query + | = RecursiveQueryQuery(id) { cursor -> + | mapper( + | cursor.getLong(0)!!, + | cursor.getLong(1) + | ) + | } + | + | public fun recursiveQuery(id: Long): Query = recursiveQuery(id) { id_, + | parent_id -> + | RecursiveQuery( + | id_, + | parent_id + | ) + | } + | + | private inner class RecursiveQueryQuery( + | public val id: Long, + | mapper: (SqlCursor) -> T, + | ) : Query(mapper) { + | public override fun addListener(listener: Query.Listener): Unit { + | driver.addListener(listener, arrayOf("item")) + | } + | + | public override fun removeListener(listener: Query.Listener): Unit { + | driver.removeListener(listener, arrayOf("item")) + | } + | + | public override fun execute(mapper: (SqlCursor) -> R): R = driver.executeQuery(${query.id}, + | ""${'"'} + | |WITH RECURSIVE + | |descendants AS ( + | | SELECT id, parent_id + | | FROM item + | | WHERE item.id = ? + | | UNION ALL + | | SELECT item.id, item.parent_id + | | FROM item, descendants + | | WHERE item.id = descendants.parent_id + | |) + | |SELECT descendants.id, descendants.parent_id + | |FROM descendants + | ""${'"'}.trimMargin(), mapper, 1) { + | bindLong(1, id) + | } + | + | public override fun toString(): String = "Recursive.sq:recursiveQuery" + | } + |} + |""".trimMargin() + ) + } + private fun checkFixtureCompiles(fixtureRoot: String) { val result = FixtureCompiler.compileFixture( fixtureRoot = "src/test/query-interface-fixtures/$fixtureRoot",