3-2-3 DSL-自定义

29 阅读 · 约 5 分钟读完

Kotlin 自定义DSL深度解析

一、DSL基础概念与设计原则

1.1 DSL分类

// Internal (embedded) DSL: built entirely from host-language constructs.
class HtmlDsl {
    /** Runs [block] against a fresh [HtmlBuilder] and returns whatever that builder's `build()` produces. */
    fun html(block: HtmlBuilder.() -> Unit): String {
        val builder = HtmlBuilder()
        builder.block()
        return builder.build()
    }
}

// External DSL: a standalone language that needs its own parser,
// e.g. SQL, regular expressions, Gradle scripts.

1.2 DSL设计原则

// 1. Fluency (fluent API)
// Each call hands back the receiver so calls can be chained.
class QueryBuilder {
    /** Placeholder SELECT step; returns this builder for chaining. */
    fun select(vararg columns: String): QueryBuilder {
        /* ... */
        return this
    }

    /** Placeholder FROM step; returns this builder for chaining. */
    fun from(table: String): QueryBuilder {
        /* ... */
        return this
    }

    /** Placeholder WHERE step; returns this builder for chaining. */
    fun where(condition: String): QueryBuilder {
        /* ... */
        return this
    }
}

// 2. Context
// Lambdas with receivers let the block body run "inside" the builder.
class TableDsl {
    /**
     * Configures a table named [name] by running [init] against a new [TableBuilder].
     *
     * @return the configured builder, so the caller can actually use what [init] set up
     *   (callers that ignore the result are unaffected).
     */
    fun table(name: String, init: TableBuilder.() -> Unit): TableBuilder = TableBuilder(name).apply(init)
}

// 3. Type safety
// Use Kotlin's type system (and construction-time checks) so invalid definitions fail early.
sealed class ColumnType {
    data object VARCHAR : ColumnType()
    data object INT : ColumnType()

    /**
     * SQL DECIMAL(precision, scale) column type.
     *
     * @throws IllegalArgumentException if [precision] is not positive or [scale] is outside `0..precision`.
     */
    data class DECIMAL(val precision: Int, val scale: Int) : ColumnType() {
        init {
            require(precision > 0) { "precision must be positive: $precision" }
            require(scale in 0..precision) { "scale must be in 0..$precision: $scale" }
        }
    }
}

二、核心构建技术

2.1 带接收者的Lambda表达式

// Basic example
class Database {
    /** Creates a [Query], lets [init] populate it, then returns it. */
    fun query(init: Query.() -> Unit): Query {
        val created = Query()
        created.init()
        return created
    }

    /** Holds a SQL string: `select` replaces it, `from` appends to it. */
    inner class Query {
        var sql = ""

        infix fun select(columns: String) {
            sql = "SELECT " + columns
        }

        infix fun from(table: String) {
            sql = sql + " FROM " + table
        }
    }
}

// Nested receivers
class Container {
    /**
     * DSL marker for the scopes below: inside an `item { }` lambda, members of the enclosing
     * [GroupScope] (such as `item`) can no longer be called implicitly by mistake.
     */
    @DslMarker
    @Target(AnnotationTarget.CLASS)
    annotation class ContainerDsl

    /** Builds a group with [block] and returns it (previously the group was discarded). */
    fun group(block: GroupScope.() -> Unit): GroupScope = GroupScope().apply(block)

    @ContainerDsl
    class GroupScope {
        private val _items = mutableListOf<ItemScope>()

        /** Items created via [item], in creation order. */
        val items: List<ItemScope>
            get() = _items

        /** Builds an item with [block], records it in [items], and returns it. */
        fun item(block: ItemScope.() -> Unit): ItemScope =
            ItemScope().apply(block).also { _items.add(it) }

        @ContainerDsl
        class ItemScope {
            var name: String = ""

            /** Function-style alternative to assigning [name]. */
            fun name(value: String) {
                name = value
            }
        }
    }
}

// Usage
Container().group {
    item { name("Item1") }
}

2.2 中缀函数与操作符重载

/**
 * Toy router DSL demonstrating an infix extension (`"/p" path { }`) and an operator
 * `invoke` extension (`"/p" { ... }`) on String.
 */
class RouteBuilder {
    private val routes = mutableListOf<String>()
    private val configs = mutableMapOf<String, RouteConfig>()

    /** Snapshot of the route descriptions registered so far, in registration order. */
    val registeredRoutes: List<String>
        get() = routes.toList()

    /** Configuration recorded for [path] through the `"path" { ... }` syntax, or null if none. */
    fun configOf(path: String): RouteConfig? = configs[path]

    // Infix function: registers the receiver path with a description derived from the handler's class.
    infix fun String.path(handler: () -> Unit) {
        routes.add("$this -> ${handler.javaClass.simpleName}")
    }

    // Operator overload: actually runs [block] against a fresh RouteConfig and keeps the result
    // (previously the block was never executed, so method/header calls were silently lost).
    operator fun String.invoke(block: RouteConfig.() -> Unit) {
        configs[this] = RouteConfig().apply(block)
        routes.add("$this configured")
    }

    /** Per-route settings collected by the `"path" { ... }` block. */
    class RouteConfig {
        /** HTTP method set via [method], or null if never set. */
        var httpMethod: String? = null
            private set

        private val _headers = mutableMapOf<String, String>()

        /** Headers set via [header]; a repeated key keeps the last value. */
        val headers: Map<String, String>
            get() = _headers

        fun method(method: String) {
            httpMethod = method
        }

        fun header(key: String, value: String) {
            _headers[key] = value
        }
    }
}

// Usage
val router = RouteBuilder()
with(router) {
    "/api/users" path { println("get users") }
    "/api/posts" {
        method("POST")
        header("Content-Type", "application/json")
    }
}

三、类型安全的DSL构建

3.1 构建类型安全的SQL DSL

// Type-safe columns: T is the Kotlin type of the values stored in the column.
interface Column<T> {
    val name: String
}

class StringColumn(override val name: String) : Column<String>
class IntColumn(override val name: String) : Column<Int>
class DateColumn(override val name: String) : Column<java.time.LocalDate>

/**
 * Table definition exposing its columns as properties.
 *
 * The constructor parameter is deliberately not a property. `name` is already taken by the
 * `name` column, and two properties called `name` would not compile. The SQL table name is
 * exposed as [tableName] instead.
 */
class Table(name: String) {
    val tableName: String = name
    val id = IntColumn("id")
    val name = StringColumn("name")
    val email = StringColumn("email")
    val createdAt = DateColumn("created_at")
}

/**
 * Fluent SELECT builder with type-checked conditions.
 *
 * The comparison operators are member extensions, so they resolve only inside a [where] block.
 * They also require the compared value to match the column's type (`users.id eq "x"` does not compile).
 *
 * Values are rendered as SQL literals:
 * - numbers and booleans verbatim;
 * - everything else single-quoted, with embedded quotes doubled.
 *
 * This escaping is illustrative. Prefer parameterized statements for untrusted input.
 */
class TypedQueryBuilder {
    private val selectColumns = mutableListOf<Column<*>>()
    private var fromTable: Table? = null
    private val conditions = mutableListOf<String>()

    /** Adds columns to the SELECT list; columns with different value types may be mixed. */
    fun select(vararg columns: Column<*>): TypedQueryBuilder {
        selectColumns.addAll(columns)
        return this
    }

    /** Sets the table to select from; a later call replaces an earlier one. */
    fun from(table: Table): TypedQueryBuilder {
        fromTable = table
        return this
    }

    /**
     * Adds a condition; multiple calls are joined with AND.
     * [condition] runs with this builder as receiver so `eq` / `like` / `gt` are in scope.
     */
    fun where(condition: TypedQueryBuilder.() -> String): TypedQueryBuilder {
        conditions.add(condition())
        return this
    }

    infix fun <T> Column<T>.eq(value: T): String = "$name = ${sqlLiteral(value)}"
    infix fun Column<String>.like(value: String): String = "$name LIKE ${sqlLiteral(value)}"
    infix fun <T> Column<T>.gt(value: T): String = "$name > ${sqlLiteral(value)}"

    /**
     * Renders the query. An empty SELECT list renders as `*`.
     *
     * @throws IllegalStateException if [from] was never called.
     */
    fun build(): String {
        val table = checkNotNull(fromTable) { "from(table) must be called before build()" }
        val columns = if (selectColumns.isEmpty()) "*" else selectColumns.joinToString(", ") { it.name }
        val whereClause = if (conditions.isEmpty()) "" else " WHERE ${conditions.joinToString(" AND ")}"
        return "SELECT $columns FROM ${table.tableName}$whereClause"
    }

    // Numbers/booleans verbatim, null as NULL, everything else quoted with ' doubled.
    private fun sqlLiteral(value: Any?): String = when (value) {
        null -> "NULL"
        is Number, is Boolean -> value.toString()
        else -> "'" + value.toString().replace("'", "''") + "'"
    }
}

// Usage
val users = Table("users")
val query = TypedQueryBuilder()
    .select(users.id, users.name, users.email)
    .from(users)
    .where { users.name like "%john%" }
    .where { users.createdAt gt java.time.LocalDate.of(2023, 1, 1) }
    .build()

3.2 构建类型安全的配置DSL

/**
 * Configuration DSL that stores each top-level section under a fixed key.
 *
 * Sub-sections (`ssl { }`, `pool { }`) are kept on their parent section object. Previously the
 * nested (non-inner) classes tried to write to the outer `config` map, which does not compile.
 */
class ConfigurationDsl {
    private val config = mutableMapOf<String, Any>()

    /** Read-only snapshot of the configured sections, keyed by "server" / "database". */
    fun build(): Map<String, Any> = config.toMap()

    /** Configures the "server" section; a later call replaces an earlier one. */
    fun server(block: ServerConfig.() -> Unit) {
        config["server"] = ServerConfig().apply(block)
    }

    /** Configures the "database" section; a later call replaces an earlier one. */
    fun database(block: DatabaseConfig.() -> Unit) {
        config["database"] = DatabaseConfig().apply(block)
    }

    class ServerConfig {
        var host: String = "localhost"
        var port: Int = 8080
        var ssl: Boolean = false

        /** SSL settings from the last `ssl { }` block, or null if SSL was never configured. */
        var sslConfig: SSLConfig? = null
            private set

        /** Enables SSL and records its settings on this server section. */
        fun ssl(block: SSLConfig.() -> Unit) {
            ssl = true
            sslConfig = SSLConfig().apply(block)
        }

        class SSLConfig {
            var keystore: String = ""
            var password: String = ""
        }
    }

    class DatabaseConfig {
        var url: String = ""
        var driver: String = ""
        var username: String = ""
        var password: String = ""

        /** Connection-pool settings; holds the class defaults until `pool { }` is called. */
        var poolConfig: ConnectionPool = ConnectionPool()
            private set

        /** Replaces the pool settings with a freshly configured [ConnectionPool]. */
        fun pool(block: ConnectionPool.() -> Unit) {
            poolConfig = ConnectionPool().apply(block)
        }

        class ConnectionPool {
            var maxSize: Int = 10
            var minIdle: Int = 2
            var timeout: Long = 30000
        }
    }
}

四、领域特定DSL实战

4.1 HTTP路由DSL

/**
 * HTTP routing DSL: `route(path) { METHOD handle { ... }; middleware { request, next -> ... } }`.
 */
class HttpRouter {
    private val routes = mutableListOf<Route>()

    /** Snapshot of every registered route, in registration order. */
    val registeredRoutes: List<Route>
        get() = routes.toList()

    /**
     * One (path, method) endpoint.
     *
     * [middlewares] are applied around [handler] by [dispatch], first-registered outermost.
     */
    data class Route(
        val path: String,
        val method: HttpMethod,
        val handler: (HttpRequest) -> HttpResponse,
        val middlewares: List<Middleware> = emptyList(),
    )

    sealed class HttpMethod {
        data object GET : HttpMethod()
        data object POST : HttpMethod()
        data object PUT : HttpMethod()
        data object DELETE : HttpMethod()
        data object PATCH : HttpMethod()
    }

    /**
     * Wraps a handler. It receives the request and the next handler in the chain, and decides
     * whether and how to call it.
     *
     * This is a `fun interface` because Kotlin does not allow a `typealias` nested in a class.
     * SAM conversion keeps the `middleware { request, next -> ... }` lambda syntax, and
     * `operator invoke` keeps the `middleware(request, next)` call syntax.
     */
    fun interface Middleware {
        operator fun invoke(request: HttpRequest, next: (HttpRequest) -> HttpResponse): HttpResponse
    }

    data class HttpRequest(val path: String, val method: HttpMethod, val headers: Map<String, String>, val body: String)
    data class HttpResponse(val status: Int, val body: String, val headers: Map<String, String> = emptyMap())

    /**
     * Registers one route per method handler declared in [block], all under [path].
     *
     * @throws IllegalArgumentException if [block] declares no handler.
     */
    fun route(path: String, block: RouteBuilder.() -> Unit) {
        val builder = RouteBuilder(path)
        builder.block()
        routes.addAll(builder.build())
    }

    /**
     * Runs the first route whose path and method match [request] through its middleware chain.
     * Returns a 404 response when nothing matches.
     */
    fun dispatch(request: HttpRequest): HttpResponse {
        val route = routes.firstOrNull { it.path == request.path && it.method == request.method }
            ?: return HttpResponse(404, "Not Found")
        // Fold from the right so the first registered middleware ends up as the outermost wrapper.
        val chain = route.middlewares.foldRight(route.handler) { middleware, next ->
            { req: HttpRequest -> middleware(req, next) }
        }
        return chain(request)
    }

    /**
     * Collects per-method handlers and shared middlewares for one path.
     *
     * Each `METHOD handle { ... }` registers a separate handler. Previously a second `handle`
     * silently overwrote the first; now only re-registering the same method replaces its handler.
     */
    inner class RouteBuilder(private val path: String) {
        private val handlers = LinkedHashMap<HttpMethod, (HttpRequest) -> HttpResponse>()
        private val middlewares = mutableListOf<Middleware>()

        infix fun HttpMethod.handle(handler: (HttpRequest) -> HttpResponse) {
            handlers[this] = handler
        }

        fun middleware(middleware: Middleware) {
            middlewares.add(middleware)
        }

        /**
         * @return one [Route] per registered method, all sharing this path's middlewares.
         * @throws IllegalArgumentException if no handler was registered.
         */
        fun build(): List<Route> {
            require(handlers.isNotEmpty()) { "Handler must be specified for route $path" }
            val sharedMiddlewares = middlewares.toList()
            return handlers.map { (method, handler) -> Route(path, method, handler, sharedMiddlewares) }
        }
    }

    // DSL usage
    fun setupRoutes() {
        route("/api/users") {
            HttpMethod.GET handle { _ ->
                HttpResponse(200, "User list")
            }

            HttpMethod.POST handle { _ ->
                HttpResponse(201, "User created")
            }

            middleware { request, next ->
                println("Logging middleware")
                next(request)
            }
        }
    }
}

4.2 构建类 Gradle 的任务 DSL

/**
 * Gradle-like task DSL: tasks declare dependencies and a `doLast` action.
 * [execute] runs a task after its transitive dependencies.
 */
class TaskSystem {
    private val tasks = mutableMapOf<String, Task>()

    data class Task(
        val name: String,
        val dependsOn: Set<String> = emptySet(),
        val action: () -> Unit,
        val description: String = "",
    )

    /** Registers tasks declared in [block]; re-declaring a name replaces the earlier task. */
    fun tasks(block: TaskRegistry.() -> Unit) {
        TaskRegistry().apply(block)
    }

    inner class TaskRegistry {
        fun task(name: String, block: TaskBuilder.() -> Unit) {
            tasks[name] = TaskBuilder(name).apply(block).build()
        }
    }

    inner class TaskBuilder(val name: String) {
        private val dependsOn = mutableSetOf<String>()

        // Defaults to a no-op so a task without doLast { } can still be built (was a lateinit crash).
        private var action: () -> Unit = {}

        /** Public so the DSL can write `description = "..."` as well as `description("...")`. */
        var description: String = ""

        fun dependsOn(vararg tasks: String) {
            dependsOn.addAll(tasks)
        }

        fun doLast(block: () -> Unit) {
            action = block
        }

        fun description(text: String) {
            description = text
        }

        // Snapshot the dependency set so later builder changes cannot mutate the built Task.
        fun build(): Task = Task(name, dependsOn.toSet(), action, description)
    }

    /**
     * Runs the named task's own action (without its dependencies).
     *
     * NOTE(review): [block] is accepted for syntax (`"name" { }`) but is not used.
     *
     * @throws IllegalStateException if no task with that name exists.
     */
    operator fun String.invoke(block: () -> Unit) {
        val task = tasks[this] ?: error("Task $this not found")
        task.action()
    }

    /**
     * Adds [task] as a dependency of the receiver task: `"test" dependsOn "lint"`.
     *
     * An infix function must take exactly one non-vararg parameter, so the multi-name form is a
     * separate, non-infix overload.
     */
    infix fun String.dependsOn(task: String) {
        addDependencies(this, listOf(task))
    }

    /** Adds several dependencies at once: `"build".dependsOn("test", "lint")`. */
    fun String.dependsOn(vararg tasks: String) {
        addDependencies(this, tasks.asList())
    }

    /**
     * Runs [taskName] after all of its transitive dependencies, each task at most once.
     *
     * @return the task names in the order their actions ran.
     * @throws IllegalStateException if a task or dependency is missing, or dependencies form a cycle.
     */
    fun execute(taskName: String): List<String> {
        val done = linkedSetOf<String>()
        val visiting = mutableSetOf<String>()

        fun visit(name: String) {
            if (name in done) return
            check(visiting.add(name)) { "Dependency cycle detected at task $name" }
            val task = tasks[name] ?: error("Task $name not found")
            task.dependsOn.forEach { visit(it) }
            task.action()
            visiting.remove(name)
            done.add(name)
        }

        visit(taskName)
        return done.toList()
    }

    // Shared by both dependsOn overloads; throws if the receiver task is unknown.
    private fun addDependencies(taskName: String, dependencies: Collection<String>) {
        val existing = tasks[taskName] ?: error("Task $taskName not found")
        tasks[taskName] = existing.copy(dependsOn = existing.dependsOn + dependencies)
    }
}

// Usage
val buildSystem = TaskSystem()
buildSystem.tasks {
    task("clean") {
        description("Clean build directory")
        doLast {
            println("Cleaning...")
        }
    }

    task("lint") {
        description("Lint source code")
        doLast {
            println("Linting...")
        }
    }

    task("compile") {
        dependsOn("clean")
        description("Compile source code")
        doLast {
            println("Compiling...")
        }
    }

    task("test") {
        dependsOn("compile")
        description("Run tests")
        doLast {
            println("Testing...")
        }
    }

    task("build") {
        dependsOn("test")
        description("Build project")
        doLast {
            println("Building...")
        }
    }
}

// Add an extra dependency. The infix `dependsOn` is a member extension of TaskSystem,
// so the TaskSystem instance has to be brought into scope as a receiver.
with(buildSystem) {
    "test" dependsOn "lint"
}

五、DSL优化与高级技巧

5.1 使用DSL作用域控制

/** Demonstrates restricting which operations are callable in which DSL scope. */
class ScopedDSL {
    /**
     * DSL marker for the scopes below: inside a nested scope lambda, members of an enclosing
     * scope can no longer be called implicitly (e.g. `onlyInGlobal()` inside `local { }`).
     */
    @DslMarker
    @Target(AnnotationTarget.CLASS)
    annotation class ScopeMarker

    /** Operations available in every scope. */
    @ScopeMarker
    interface Scope {
        fun allowedOperation()
    }

    @ScopeMarker
    class GlobalScope : Scope {
        override fun allowedOperation() {
            println("Global operation")
        }

        fun onlyInGlobal() {
            println("Only in global scope")
        }

        /**
         * Opens a nested local scope. It is a member of [GlobalScope], so it resolves in any
         * `GlobalScope.() -> Unit` lambda without a [ScopedDSL] receiver in scope.
         */
        fun local(block: LocalScope.() -> Unit) {
            LocalScope().block()
        }
    }

    @ScopeMarker
    class LocalScope : Scope {
        override fun allowedOperation() {
            println("Local operation")
        }

        fun onlyInLocal() {
            println("Only in local scope")
        }
    }

    /** Runs [block] with [scope] as its receiver. */
    fun <S : Scope> withScope(scope: S, block: S.() -> Unit) {
        scope.block()
    }

    /** Runs [block] in a fresh [GlobalScope], then returns a new, unrelated [LocalScope]. */
    fun global(block: GlobalScope.() -> Unit): LocalScope {
        GlobalScope().block()
        return LocalScope()
    }
}

// Usage
val dsl = ScopedDSL()
// GlobalScope is a nested (not inner) class, so it is created through the class name, not the instance.
// `with(dsl)` also puts the ScopedDSL instance in scope as a receiver, so any member extensions it
// declares (such as a scope-switching `local`) resolve inside the lambdas.
with(dsl) {
    withScope(ScopedDSL.GlobalScope()) {
        allowedOperation()
        onlyInGlobal()
        local {
            allowedOperation()
            onlyInLocal()
        }
    }
}

5.2 构建流式处理管道 DSL(同步的 map/filter 管道,并非异步的响应式流)

/**
 * A small, lazily-evaluated element pipeline.
 *
 * - Stages run in the order they were added. A filter added before a map sees the un-mapped
 *   values; previously maps and filters were kept in separate lists, which lost that ordering.
 * - Nothing executes until [collect] is called.
 *
 * @param source initial elements. More can be emitted through the [stream] builder.
 */
class StreamDSL<T>(source: Iterable<T> = emptyList()) {
    private val elements: MutableList<T> = source.toMutableList()

    // One ordered list of stages, each transforming the whole sequence.
    private val stages = mutableListOf<(Sequence<T>) -> Sequence<T>>()

    /** Appends a transformation stage. */
    fun map(transform: (T) -> T): StreamDSL<T> {
        stages.add { it.map(transform) }
        return this
    }

    /** Appends a filtering stage. */
    fun filter(predicate: (T) -> Boolean): StreamDSL<T> {
        stages.add { it.filter(predicate) }
        return this
    }

    /** Terminal operation: runs every stage, in order, over the source elements. */
    fun collect(): List<T> =
        stages.fold(elements.asSequence()) { acc, stage -> stage(acc) }.toList()

    private fun emit(item: T) {
        elements.add(item)
    }

    /**
     * Builder used by [stream].
     *
     * `value map { ... } filter { ... }` transforms a single value, then emits it into the stream
     * only if it passes the filter.
     */
    class StreamBuilder<T> {
        private val dsl = StreamDSL<T>()

        /** Emits [items] into the stream unchanged. */
        fun from(vararg items: T) {
            items.forEach { dsl.emit(it) }
        }

        /** Element-level transform: returns `transform(this)` without emitting anything. */
        infix fun T.map(transform: (T) -> T): T = transform(this)

        /** Emits the receiver into the stream when [predicate] holds for it. */
        infix fun T.filter(predicate: (T) -> Boolean) {
            if (predicate(this)) dsl.emit(this)
        }

        fun build(): StreamDSL<T> = dsl
    }

    // DSL entry point
    companion object {
        fun <T> stream(block: StreamBuilder<T>.() -> Unit): StreamDSL<T> {
            return StreamBuilder<T>().apply(block).build()
        }
    }
}

// Usage: each line transforms one value and emits it if it passes the filter.
// stream.collect() yields [2, 6].
val stream = StreamDSL.stream<Int> {
    1 map { it * 2 } filter { it > 0 }
    2 map { it * 3 } filter { it < 10 }
}

六、DSL测试与调试

6.1 DSL单元测试

class DslTest {
    /** The HTML DSL output must contain the div tag and both text nodes. */
    @Test
    fun `test html dsl`() {
        val rendered = html {
            body {
                div {
                    +"Hello"
                    +"World"
                }
            }
        }

        for (fragment in listOf("<div>", "Hello", "World")) {
            assertTrue(rendered.contains(fragment))
        }
    }

    /** A well-typed query builds without throwing and contains the expected keywords. */
    @Test
    fun `test type safety`() {
        val sql = assertDoesNotThrow {
            TypedQueryBuilder()
                .select(users.id, users.name)
                .from(users)
                .where { users.id eq 1 }
                .build()
        }

        for (keyword in listOf("SELECT", "FROM")) {
            assertTrue(sql.contains(keyword))
        }
    }
}

6.2 DSL调试技巧

class DebuggableDSL {
    /**
     * 1. Debug output: instantiates [T] through its no-arg constructor and applies [block] to it.
     * When [debug] is true, start/end markers are printed around the block.
     *
     * Uses `getDeclaredConstructor().newInstance()` because `Class.newInstance()` is deprecated
     * (since Java 9): it propagates the constructor's checked exceptions without wrapping them.
     *
     * @throws NoSuchMethodException if [T] has no no-arg constructor.
     */
    inline fun <reified T : Any> dsl(block: T.() -> Unit, debug: Boolean = false): T {
        val instance = T::class.java.getDeclaredConstructor().newInstance()
        if (debug) println("DSL开始: ${T::class.simpleName}")
        instance.block()
        if (debug) println("DSL结束: ${T::class.simpleName}")
        return instance
    }

    /** 2. Call-chain tracing: records the currently open `method` calls as a stack. */
    class TracedBuilder {
        private val callStack = mutableListOf<String>()

        /**
         * Pushes [name], runs [block], then pops [name].
         *
         * The pop is in `finally` so a throwing block does not leave a stale entry behind.
         * `removeAt(lastIndex)` is used instead of `removeLast()`, which resolves to the JDK 21
         * `List.removeLast` and fails on older runtimes.
         */
        fun method(name: String, block: TracedBuilder.() -> Unit) {
            callStack.add("method: $name")
            println("进入: $name")
            try {
                block()
                println("退出: $name")
            } finally {
                callStack.removeAt(callStack.lastIndex)
            }
        }

        /** Prints the currently open calls, outermost first, indented by depth. */
        fun printTrace() {
            println("调用栈:")
            callStack.forEachIndexed { index, call ->
                println("  ${" ".repeat(index * 2)}$call")
            }
        }
    }
}

七、性能优化

7.1 避免DSL性能陷阱

class OptimizedDSL {
    // 1. Avoid needless allocation by reusing one StringBuilder across calls.
    // Not thread-safe: confine an OptimizedDSL instance to a single thread.
    private val sharedBuilder = StringBuilder()

    // True while sharedBuilder backs an in-progress buildHtml call.
    private var sharedBuilderInUse = false

    /**
     * Builds HTML with [block], reusing the shared buffer when it is free.
     *
     * A nested call made from inside [block] gets its own buffer rather than clearing the buffer
     * the outer call is still writing to. This assumes HtmlBuilder appends into the StringBuilder
     * it is given (HtmlBuilder is not shown here; confirm).
     */
    fun buildHtml(block: HtmlBuilder.() -> Unit): String {
        if (sharedBuilderInUse) {
            val nested = HtmlBuilder(StringBuilder())
            nested.block()
            return nested.toString()
        }
        sharedBuilderInUse = true
        try {
            val builder = HtmlBuilder(sharedBuilder.clear())
            builder.block()
            return builder.toString()
        } finally {
            sharedBuilderInUse = false
        }
    }

    /**
     * 2. Inline to avoid a lambda allocation. Runs [block] and prints its elapsed time if it took
     * longer than 1 ms.
     *
     * [block] is invoked directly, so it needs no `crossinline`. Callers may therefore use a
     * non-local `return` inside it.
     */
    inline fun inlineDsl(block: () -> Unit) {
        val start = System.nanoTime()
        block()
        val duration = System.nanoTime() - start
        if (duration > 1_000_000) { // more than 1 ms
            println("DSL执行时间: ${duration}ns")
        }
    }

    // 3. Cache frequently used DSL components by key.
    private val cachedComponents = mutableMapOf<String, Any>()

    /**
     * Returns the component cached under [key], creating it with [factory] on first use.
     *
     * The cast is unchecked. Reusing a key with a different type fails with ClassCastException at
     * the call site, not here.
     */
    @Suppress("UNCHECKED_CAST")
    fun <T : Any> component(key: String, factory: () -> T): T {
        return cachedComponents.getOrPut(key, factory) as T
    }
}

八、实战:完整配置DSL示例

// Complete enterprise-style configuration DSL
class AppConfigDsl {
    /**
     * DSL marker shared by every builder below. Inside a nested block (e.g. `pool { }`), members
     * of an enclosing builder (e.g. `server { }`) can no longer be called implicitly by mistake.
     */
    @DslMarker
    @Target(AnnotationTarget.CLASS)
    annotation class ConfigDslMarker

    // ---- Immutable configuration model produced by the builders ----

    data class Config(
        val server: ServerConfig,
        val database: DatabaseConfig,
        val security: SecurityConfig,
        val logging: LoggingConfig,
    )

    data class ServerConfig(
        val host: String,
        val port: Int,
        val ssl: SslConfig?,
    )

    data class SslConfig(
        val enabled: Boolean,
        val keystore: String,
        val keystorePassword: String,
        val truststore: String? = null,
    )

    data class DatabaseConfig(
        val url: String,
        val driver: String,
        val username: String,
        val password: String,
        val pool: ConnectionPoolConfig,
    )

    data class ConnectionPoolConfig(
        val maxSize: Int,
        val minIdle: Int,
        val connectionTimeout: Long,
        val validationTimeout: Long,
    )

    data class SecurityConfig(
        val cors: CorsConfig,
        val jwt: JwtConfig?,
        val rateLimit: RateLimitConfig,
    )

    data class CorsConfig(
        val allowedOrigins: List<String>,
        val allowedMethods: List<String>,
        val allowedHeaders: List<String>,
    )

    data class JwtConfig(
        val secret: String,
        val issuer: String,
        val audience: String,
        val expiration: Long,
    )

    data class RateLimitConfig(
        val requestsPerMinute: Int,
        val burstSize: Int,
    )

    data class LoggingConfig(
        val level: String,
        val file: FileConfig?,
        val format: String,
    )

    data class FileConfig(
        val path: String,
        val maxSize: String,
        val maxFiles: Int,
    )

    // ---- Builders ----

    /**
     * Root builder.
     *
     * Every section starts from its builder's defaults, so a section the DSL block never mentions
     * still yields a complete [Config]. Previously an omitted section was a `lateinit` crash.
     * Calling a section twice replaces the earlier settings.
     */
    @ConfigDslMarker
    class Builder {
        private var server = ServerBuilder()
        private var database = DatabaseBuilder()
        private var security = SecurityBuilder()
        private var logging = LoggingBuilder()

        fun server(block: ServerBuilder.() -> Unit) {
            server = ServerBuilder().apply(block)
        }

        fun database(block: DatabaseBuilder.() -> Unit) {
            database = DatabaseBuilder().apply(block)
        }

        fun security(block: SecurityBuilder.() -> Unit) {
            security = SecurityBuilder().apply(block)
        }

        fun logging(block: LoggingBuilder.() -> Unit) {
            logging = LoggingBuilder().apply(block)
        }

        fun build(): Config = Config(
            server.build(),
            database.build(),
            security.build(),
            logging.build(),
        )
    }

    @ConfigDslMarker
    class ServerBuilder {
        var host: String = "localhost"
        var port: Int = 8080
        var ssl: SslBuilder? = null

        fun ssl(block: SslBuilder.() -> Unit) {
            ssl = SslBuilder().apply(block)
        }

        /** @throws IllegalArgumentException if [port] is outside 0..65535. */
        fun build(): ServerConfig {
            require(port in 0..65535) { "port out of range: $port" }
            return ServerConfig(host, port, ssl?.build())
        }
    }

    @ConfigDslMarker
    class SslBuilder {
        var enabled: Boolean = true
        var keystore: String = ""
        var keystorePassword: String = ""
        var truststore: String? = null

        fun build(): SslConfig = SslConfig(enabled, keystore, keystorePassword, truststore)
    }

    @ConfigDslMarker
    class DatabaseBuilder {
        var url: String = ""
        var driver: String = ""
        var username: String = ""
        var password: String = ""
        var pool: ConnectionPoolBuilder = ConnectionPoolBuilder()

        fun pool(block: ConnectionPoolBuilder.() -> Unit) {
            pool = ConnectionPoolBuilder().apply(block)
        }

        fun build(): DatabaseConfig = DatabaseConfig(url, driver, username, password, pool.build())
    }

    @ConfigDslMarker
    class ConnectionPoolBuilder {
        var maxSize: Int = 10
        var minIdle: Int = 2
        var connectionTimeout: Long = 30000
        var validationTimeout: Long = 5000

        fun build(): ConnectionPoolConfig =
            ConnectionPoolConfig(maxSize, minIdle, connectionTimeout, validationTimeout)
    }

    /** Security section. CORS and rate limiting are always present; JWT only if `jwt { }` is used. */
    @ConfigDslMarker
    class SecurityBuilder {
        private var cors = CorsBuilder()
        private var jwt: JwtBuilder? = null
        private var rateLimit = RateLimitBuilder()

        fun cors(block: CorsBuilder.() -> Unit) {
            cors = CorsBuilder().apply(block)
        }

        fun jwt(block: JwtBuilder.() -> Unit) {
            jwt = JwtBuilder().apply(block)
        }

        fun rateLimit(block: RateLimitBuilder.() -> Unit) {
            rateLimit = RateLimitBuilder().apply(block)
        }

        fun build(): SecurityConfig = SecurityConfig(cors.build(), jwt?.build(), rateLimit.build())
    }

    @ConfigDslMarker
    class CorsBuilder {
        var allowedOrigins: List<String> = emptyList()
        var allowedMethods: List<String> = emptyList()
        var allowedHeaders: List<String> = emptyList()

        fun build(): CorsConfig = CorsConfig(allowedOrigins, allowedMethods, allowedHeaders)
    }

    /** JWT settings. The defaults here are placeholders for the example; set real values per deployment. */
    @ConfigDslMarker
    class JwtBuilder {
        var secret: String = ""
        var issuer: String = ""
        var audience: String = ""
        var expiration: Long = 3600

        fun build(): JwtConfig = JwtConfig(secret, issuer, audience, expiration)
    }

    /** Rate-limit settings; the default numbers are placeholders for the example. */
    @ConfigDslMarker
    class RateLimitBuilder {
        var requestsPerMinute: Int = 60
        var burstSize: Int = 10

        fun build(): RateLimitConfig = RateLimitConfig(requestsPerMinute, burstSize)
    }

    /** Logging section; file output only if `file { }` is used. Defaults are placeholders. */
    @ConfigDslMarker
    class LoggingBuilder {
        var level: String = "INFO"
        var format: String = "text"
        private var file: FileBuilder? = null

        fun file(block: FileBuilder.() -> Unit) {
            file = FileBuilder().apply(block)
        }

        fun build(): LoggingConfig = LoggingConfig(level, file?.build(), format)
    }

    /** Log-file settings; the defaults are placeholders for the example. */
    @ConfigDslMarker
    class FileBuilder {
        var path: String = ""
        var maxSize: String = "10MB"
        var maxFiles: Int = 5

        fun build(): FileConfig = FileConfig(path, maxSize, maxFiles)
    }

    companion object {
        /** Entry point: runs [block] against a fresh [Builder] and returns the built [Config]. */
        fun config(block: Builder.() -> Unit): Config {
            return Builder().apply(block).build()
        }
    }
}

// Usage example
val config = AppConfigDsl.config {
    // HTTPS server backed by a JKS keystore.
    server {
        port = 443
        host = "api.example.com"
        ssl {
            keystorePassword = "secret"
            keystore = "/path/to/keystore.jks"
        }
    }

    // MySQL connection with a larger pool and a longer connection timeout.
    database {
        driver = "com.mysql.cj.jdbc.Driver"
        url = "jdbc:mysql://localhost:3306/mydb"
        password = "password"
        username = "root"
        pool {
            connectionTimeout = 60000
            minIdle = 5
            maxSize = 20
        }
    }

    // CORS allow-list plus JWT issuing settings.
    security {
        cors {
            allowedMethods = listOf("GET", "POST", "PUT", "DELETE")
            allowedOrigins = listOf("https://example.com")
        }

        jwt {
            expiration = 3600
            issuer = "my-app"
            secret = "jwt-secret-key"
        }
    }
}

最佳实践总结

  1. 保持简洁:DSL应该让复杂任务变简单,而不是相反
  2. 类型安全:充分利用Kotlin类型系统,在编译时捕获错误
  3. 流畅易读:DSL代码应该读起来像自然语言
  4. 分层设计:复杂的DSL应该分层,每层关注特定职责
  5. 良好文档:为DSL提供清晰的文档和使用示例
  6. 性能意识:注意DSL的性能开销,特别是频繁调用的部分
  7. 测试覆盖:为DSL编写全面的测试,确保稳定性和正确性

通过合理使用Kotlin的DSL特性,可以创建出表达力强、类型安全且易于使用的领域特定语言,极大提升开发效率和代码质量。