Compose注入ViewModel源码分析并给构造方法传参

52 阅读6分钟

我们在Compose中使用viewModel()或hiltViewModel()创建ViewModel时,如何向ViewModel的构造方法中添加自定义参数呢?最常见的是用ViewModel来获取数据,会在ViewModel的构造方法中注入一个Repository用于请求数据,由于Repository管理全局数据,可以构建一个单例对象,在ViewModel的构造方法中直接使用。若想动态添加自定义参数,就得用其它方式了。

viewModel源码分析

hiltViewModel最终是调用viewModel来创建ViewModel对象,这里我们先分析viewModel的源码

@Composable
inline fun <reified VM : ViewModel> viewModel(
    viewModelStoreOwner: ViewModelStoreOwner =
        checkNotNull(LocalViewModelStoreOwner.current) {
            "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner"
        },
    key: String? = null,
    factory: ViewModelProvider.Factory? = null,
    extras: CreationExtras =
        if (viewModelStoreOwner is HasDefaultViewModelProviderFactory) {
            viewModelStoreOwner.defaultViewModelCreationExtras
        } else {
            CreationExtras.Empty
        }
): VM = viewModel(VM::class, viewModelStoreOwner, key, factory, extras)

@Composable
public fun <VM : ViewModel> viewModel(
  modelClass: KClass<VM>,
  viewModelStoreOwner: ViewModelStoreOwner =
    checkNotNull(LocalViewModelStoreOwner.current) {
      "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner"
    },
  key: String? = null,
  factory: ViewModelProvider.Factory? = null,
  extras: CreationExtras =
    if (viewModelStoreOwner is HasDefaultViewModelProviderFactory) {
      viewModelStoreOwner.defaultViewModelCreationExtras
    } else {
      CreationExtras.Empty
    }
): VM = viewModelStoreOwner.get(modelClass, key, factory, extras)

在Compose中调用viewModel方法来创建ViewModel,4个参数都是可选的,第1个参数viewModelStoreOwner获取当前环境中的ViewModelStoreOwner,默认由所在的Activity或Fragment来提供,我们使用默认值就可以了,第2个参数key用于ViewModelStore存储ViewModel对象,当不传时也有一个默认值。第3个参数factory用于自定义创建ViewModel,第4个参数用于创建ViewModel时添加的附加参数。接下来重点分析第3个参数和第4个参数的用法。

extras参数源码分析

当viewModelStoreOwner取默认值时,通过LocalViewModelStoreOwner.current获取到ComponentActivity中提供的ViewModelStoreOwner,ComponentActivity实现了HasDefaultViewModelProviderFactory接口,并实现了默认的factory和extras,创建extras源码如下

override val defaultViewModelCreationExtras: CreationExtras
  get() {
    val extras = MutableCreationExtras()
    if (application != null) {
      extras[APPLICATION_KEY] = application
    }
    extras[SAVED_STATE_REGISTRY_OWNER_KEY] = this
    extras[VIEW_MODEL_STORE_OWNER_KEY] = this
    val intentExtras = intent?.extras
    if (intentExtras != null) {
      extras[DEFAULT_ARGS_KEY] = intentExtras
    }
    return extras
  }

CreationExtras添加了APPLICATION_KEY、SAVED_STATE_REGISTRY_OWNER_KEY和VIEW_MODEL_STORE_OWNER_KEY 3个参数,如果intent的extras有值,则会把这部分数据也传给ViewModel。

ViewModel的创建

extras的创建我们知道了,接下来看下创建ViewModel时如何使用这些数据。在viewModel方法中,最终调用了viewModelStoreOwner.get来创建ViewModel,get源码如下

internal fun <VM : ViewModel> ViewModelStoreOwner.get(
    modelClass: KClass<VM>,
    key: String? = null,
    factory: ViewModelProvider.Factory? = null,
    extras: CreationExtras =
        if (this is HasDefaultViewModelProviderFactory) {
            this.defaultViewModelCreationExtras
        } else {
            CreationExtras.Empty
        }
): VM {
    val provider =
        if (factory != null) {
            ViewModelProvider.create(this.viewModelStore, factory, extras)
        } else if (this is HasDefaultViewModelProviderFactory) {
            ViewModelProvider.create(
                this.viewModelStore,
                this.defaultViewModelProviderFactory,
                extras
            )
        } else {
            ViewModelProvider.create(this)
        }
    return if (key != null) {
        provider[key, modelClass]
    } else {
        provider[modelClass]
    }
}

这里先创建一个Provider,Provider内部持有ViewModelStore,当ViewModel创建过,则Provider直接返回。由前面可知当前ViewModelStoreOwner实现了HasDefaultViewModelProviderFactory接口,我们继续看provider的get方法是如何返回的

actual open class ViewModelProvider
private constructor(private val impl: ViewModelProviderImpl) {
  @MainThread
  actual operator fun <T : ViewModel> get(modelClass: KClass<T>): T =
    impl.getViewModel(modelClass)
}

internal class ViewModelProviderImpl(
  private val store: ViewModelStore,
  private val factory: ViewModelProvider.Factory,
  private val defaultExtras: CreationExtras,
) {
  private val lock = SynchronizedObject()

  @Suppress("UNCHECKED_CAST")
  fun <T : ViewModel> getViewModel(
    modelClass: KClass<T>,
    key: String = ViewModelProviders.getDefaultKey(modelClass),
  ): T {
    return synchronized(lock) {
      val viewModel = store[key]
      if (modelClass.isInstance(viewModel)) {
        if (factory is ViewModelProvider.OnRequeryFactory) {
          factory.onRequery(viewModel!!)
        }
        return@synchronized viewModel as T
      }

      val modelExtras = MutableCreationExtras(defaultExtras)
      modelExtras[ViewModelProvider.VIEW_MODEL_KEY] = key

      return@synchronized createViewModel(factory, modelClass, modelExtras).also { vm ->
        store.put(key, vm)
      }
    }
  }
}

internal actual fun <VM : ViewModel> createViewModel(
  factory: ViewModelProvider.Factory,
  modelClass: KClass<VM>,
  extras: CreationExtras
): VM {
  return try {
    factory.create(modelClass, extras)
  } catch (e: AbstractMethodError) {
    try {
      factory.create(modelClass.java, extras)
    } catch (e: AbstractMethodError) {
      factory.create(modelClass.java)
    }
  }
}

ViewModelProvider内部持有ViewModelProviderImpl对象,get方法由ViewModelProviderImpl的getViewModel方法实现,若store中已经存在该key对应的ViewModel,则直接返回,没有的话,由factory调用create方法创建,而这里的factory是由ComponentActivity提供的默认Factory。

open class ComponentActivity(){
  override val defaultViewModelProviderFactory: ViewModelProvider.Factory by lazy {
    SavedStateViewModelFactory(application, this, if (intent != null) intent.extras else null)
  }
}

该默认Factory是SavedStateViewModelFactory对象,接下来分析SavedStateViewModelFactory的create方法

actual class SavedStateViewModelFactory :
ViewModelProvider.OnRequeryFactory, ViewModelProvider.Factory {
  private val factory: ViewModelProvider.Factory
  actual constructor() {
    factory = ViewModelProvider.AndroidViewModelFactory()
  }
  override fun <T : ViewModel> create(modelClass: Class<T>, extras: CreationExtras): T {
    val key =
      extras[ViewModelProvider.VIEW_MODEL_KEY]
        ?: throw IllegalStateException(
          "VIEW_MODEL_KEY must always be provided by ViewModelProvider"
        )

    return if (
      extras[SAVED_STATE_REGISTRY_OWNER_KEY] != null &&
      extras[VIEW_MODEL_STORE_OWNER_KEY] != null
    ) {
      val application = extras[ViewModelProvider.AndroidViewModelFactory.APPLICATION_KEY]
      val isAndroidViewModel = AndroidViewModel::class.java.isAssignableFrom(modelClass)
      val constructor: Constructor<T>? =
        if (isAndroidViewModel && application != null) {
          findMatchingConstructor(modelClass, ANDROID_VIEWMODEL_SIGNATURE)
        } else {
          findMatchingConstructor(modelClass, VIEWMODEL_SIGNATURE)
        }
      // doesn't need SavedStateHandle
      if (constructor == null) {
        return factory.create(modelClass, extras)
      }
      val viewModel =
        if (isAndroidViewModel && application != null) {
          newInstance(
            modelClass,
            constructor,
            application,
            extras.createSavedStateHandle()
          )
        } else {
          newInstance(modelClass, constructor, extras.createSavedStateHandle())
        }
      viewModel
    } else {
      val viewModel =
        if (lifecycle != null) {
          create(key, modelClass)
        } else {
          throw IllegalStateException(
            "SAVED_STATE_REGISTRY_OWNER_KEY and" +
              "VIEW_MODEL_STORE_OWNER_KEY must be provided in the creation extras to" +
              "successfully create a ViewModel."
          )
        }
      viewModel
    }
  }
}

create方法传入了我们提前提到的CreationExtras参数,在这里将CreationExtras对象中的不同数据取出来了,在newInstance方法创建ViewModel时,extras创建了一个SavedStateHandle,继续看createSavedStateHandle的源码

fun CreationExtras.createSavedStateHandle(): SavedStateHandle {
  val savedStateRegistryOwner =
    this[SAVED_STATE_REGISTRY_OWNER_KEY]
      ?: throw IllegalArgumentException(
        "CreationExtras must have a value by `SAVED_STATE_REGISTRY_OWNER_KEY`"
      )
  val viewModelStateRegistryOwner =
    this[VIEW_MODEL_STORE_OWNER_KEY]
      ?: throw IllegalArgumentException(
        "CreationExtras must have a value by `VIEW_MODEL_STORE_OWNER_KEY`"
      )
  val defaultArgs = this[DEFAULT_ARGS_KEY]
  val key =
    this[VIEW_MODEL_KEY]
      ?: throw IllegalArgumentException(
          "CreationExtras must have a value by `VIEW_MODEL_KEY`"
      )
  return createSavedStateHandle(
    savedStateRegistryOwner,
    viewModelStateRegistryOwner,
    key,
    defaultArgs
  )
}

前面是参数的校验,最后取了一个参数DEFAULT_ARGS_KEY,这个值的类型是Bundle类型。至此,我们就可以通过DEFAULT_ARGS_KEY这个key,将参数通过键值对的形式存起来,在ViewModel中就可以取出来了。

向ViewModel传递键值对参数

在调用viewModel时,先获取extras参数,再将我们的自定义参数骑过DEFAULT_ARGS_KEY存起来,在ViewModel中通过key就可以取到值了,比如我们需要动态传递一个name和age参数,示例如下

@Composable
fun mainPage(name: String = "jinhongke", age: Int = 18) {
  val viewModel: MainViewModel = viewModel<MainViewModel>(extras = remember {
    ((owner as? HasDefaultViewModelProviderFactory)?.defaultViewModelCreationExtras?.let {
      it as? MutableCreationExtras ?: MutableCreationExtras(it)
    } ?: MutableCreationExtras()).also {
      it[DEFAULT_ARGS_KEY] = (it[DEFAULT_ARGS_KEY] ?: bundleOf()).apply {
        putString("name", name)
        putInt("age", age)
      }
    }
  })
  viewModel.print()
}

class CameraViewModel(state: SavedStateHandle) : AndroidViewModel(application) {
  private val name = state["name"] ?: ""
  private val age = state["age"] ?: 0

  fun print() {
    println("$name is the $age years old.")
  }
}

向ViewModel传递对象参数

除了通过extras传递参数外,我们也可以通过factory自己来创建ViewModel,并传递相应的对象参数。前面在SavedStateViewModelFactory类中,我们注意到有一个默认的factory是ViewModelProvider.AndroidViewModelFactory对象,这里我们也可以效仿实现一个AndroidViewModelFactory子类,并添加我们需要的自定义参数,示例如下

data class User(val name: String = "jinhongke", val age: Int = 18)

class MainViewModelFactory(private val user: User) : ViewModelProvider.AndroidViewModelFactory() {

  override fun <T : ViewModel> create(modelClass: Class<T>, extras: CreationExtras): T {
    return MainViewModel(application, param) as T
  }
}

class MainViewModel(application: Application, user: User) : AndroidViewModel(application) {
  fun showToast(){
    Toast.makeText(application, "${user.name} is the ${user.age} years old.", Toast.LENGTH_SHORT).show()
  }
}

@Composable
fun mainPage(name: String = "jinhongke", age: Int = 18) {
  val viewModel = viewModel<MainViewModel>(factory = PictureViewModelFactory(User(name, age)))
  LaunchedEffect(Unit) {
    viewModel.showToast()
  }
}

使用hiltViewModel动态注入参数

前面我们通过extras以及factory的方式,虽然实现了我们想要的效果,但可以发现都存在一定的问题,通过extras添加附加参数的方式,默认使用了activity或fragment中提供的HasDefaultViewModelProviderFactory对象,兼容性不是很好,哪天版本升级源码改了,我们还得重新适配。factory虽然不存在这种问题,但通用性不好,如果每个ViewModel都有自己的参数,要么耦合到一个Factory,要么每个ViewModel都得创建一个自己的Factory,重复代码太多。

而Hilt也注意到这种问题,并提供了Factory注解,我们只需要定义好冷接口,告诉Hilt,我们的Factory需要哪些自定义参数,它来帮我们生成Factory。示例如下

data class User(val name: String, val age: Int)

class Repositories @Inject constructor() {
  fun getUserInfo(name: String, age: Int): String {
    return "${user.name} is the ${user.age} years old."
  }
}

@HiltViewModel(assistedFactory = MainViewModel.Factory::class)
class MainViewModel @AssistedInject constructor(
  private val repo: Repositories,
  @Assisted val user: User
) : ViewModel() {
  fun loadUserInfo(){
    val result = repo.getUserInfo(user.name, user.age)
    println(result)
  }
  @AssistedFactory
  interface Factory {
    fun create(user: User): MainViewModel
  }
}

@Composable
fun mainPage(viewModel: MainViewModel = hiltViewModel<MainViewModel, MainViewModel.Factory>(){
  it.create(User("jinhongke", 18))
}) {
  LaunchedEffect(Unit) {
    viewModel.loadUserInfo()
  }
}

Repositories是全局注入的对象,创建ViewModel时我们不用处理,由Hilt自动帮我们传入,User是由Hilt提供的Factory动态注入的参数,实现创建ViewModel时,也需要传入我们需要的动态参数。