TF-Operator源码分析

2,663 阅读5分钟

调用流程

虽然KubeFlow提供了一大堆组件,涵盖了机器学习的方方面面,但模型训练肯定是KubeFlow最重要的功能。 KubeFlow针对各种各样的机器学习框架提供了训练的能力。方式是定义了各种各样的Operator,这些Operator的本质,是K8SCRD
一句话,TF-Operator就是开源社区基于K8S提供的扩展API,提供了TensorFlow的训练能力,从名字也能看出来,这个实现是类似Job的一种方式。
TF-Operator的代码不太多,但是由于用了大量的K8SAPI,结构有点复杂,我们只把重要的地方摘出来。最重要的是下面这几个源代码文件(后附源码分析)。

TF-Operator流程图

源码分析

为了帮助我们更加深刻的理解Kubeflow@TFJob的工作流程和实现机制,下面将TF-Operator重点代码拿出来一起过一遍。
pkg/controller.v1/tensorflow/controller.go:
NewTFController返回一个新的TFJob控制器:

func NewTFController (...) *TFController {
    ... ...
    // 设置同步处理程序。
    tc.syncHandler = tc.syncTFJob
    ... ...
    return tc
}

processNextWorkItem将从WorkQueue中读取单个工作项,并尝试通过调用syncHandler来处理它:

func (tc *TFController) processNextWorkItem() bool {
	obj, quit := tc.WorkQueue.Get()
	... ...
	// 同步TFJob以将实际状态匹配到所需的状态。
	forget, err := tc.syncHandler(key=obj.(string))
	if err == nil {
		if forget {
			tc.WorkQueue.Forget(key)
		}
	}
}

如果tfjob的期望值已经实现,那么syncTFJob就会用给定的key来同步tfjob,这意味着它不希望更多的 pod/service被创建或删除:

// 这个函数不能与同一个key同时调用
func (tc *TFController) syncTFJob(key string) (bool, error) {
	... ...
	sharedTFJob, err := tc.getTFJobFromName(namespace, name)
	
	tfjob := sharedTFJob.DeepCopy()

	// 为新tfjob设置默认值。
	scheme.Scheme.Default(tfjob)

	if tfjobNeedsSync && tfjob.DeletionTimestamp == nil {
        // 调用reconcileTFJobs来启动TFJobs
		reconcileTFJobsErr = tc.reconcileTFJobs(tfjob)
	}
    ... ...
}

pkg/controller.v1/tensorflow/pod.go:
reconcileTFJobs检查并更新每个给定TFReplicaSpecreplicas

// 如果在创建/删除 pods/services时发生错误,它将请求tfjob。 
func (tc *TFController) reconcileTFJobs(tfjob *tfv1.TFJob) error {
    ... ...
	// 如果TFJob terminated,则delete所有pod和service。
	if isSucceeded(tfjob.Status) || isFailed(tfjob.Status) {
		if err := tc.deletePodsAndServices(tfjob, pods); err != nil {
			return err
		}
		if err := tc.cleanupTFJob(tfjob); err != nil {
			return err
		}
		if tc.Config.EnableGangScheduling {
			if err := tc.DeletePodGroup(tfjob); err != nil {
				return err
			}
		}
        ... ...
	}

	// 检索以前的重试次数
	previousRetry := tc.WorkQueue.NumRequeues(tfjobKey)

	if tfJobExceedsLimit {
		// 如果TFJob超过了backofflimit或超过了active deadline,删除所有pod和service,然后将状态设置为failed(代码同上)
		... ...
		// 遍历配置文件的TFReplicaSpecs部分,分别为不同类型的节点启动相应的Pod。
        // 在启动Pod之后,还要为其启动一个Service。
		for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
			err = tc.reconcilePods(tfjob, pods, rtype, spec, replicasStatus)
			... ...
			err = tc.reconcileServices(tfjob, services, rtype, spec)
            ... ...
		}
	}	
}

reconcilePods为每个给定的TFReplicaSpec检查和更新pod

// 如果在创建/删除pod时发生错误,它将请求tfjob。
func (tc *TFController) reconcilePods(...) error {	
	... ...
	// 获取rtype类型的所有pod。
	pods, err := tc.FilterPodsForReplicaType(pods, rt)
    ... ...
	podSlices, podsToBeRemoved := tc.GetPodSlices(pods, replicas, logger)

	// 缩减
	if tfjob.Spec.EnableDynamicWorker && len(podsToBeRemoved) > 0 {
		// 目前只允许缩减workers
		if rtype == tfv1.TFReplicaTypeWorker {
			for _, pod := range podsToBeRemoved {
				err := tc.PodControl.DeletePod(tfjob.Namespace, pod.Name, tfjob)
			}
		} 
	}

	for index, podSlice := range podSlices {
		if len(podSlice) == 0 {
			// 如果master pod存在,选择master pod
			// 如果没有master,第一个worker pod被选为master。
			if ContainChieforMasterSpec(tfjob) {
				if tfv1.IsChieforMaster(rtype) {
					masterRole = true
				}
			} else {
				if tfv1.IsWorker(rtype) && (index == 0) {
					masterRole = true
				}
			}
            // 调用createNewPod创建Pod
			err = tc.createNewPod(tfjob, rt, strconv.Itoa(index), spec, masterRole)
		} 
        ... ...
	}

	return tc.updateStatusSingle(tfjob, rtype, replicas, restart, worker0Completed)
}

createNewPod为给定的indextype创建一个新的pod

func (tc *TFController) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec *common.ReplicaSpec, masterRole bool) error {
	
	expectationPodsKey := jobcontroller.GenExpectationPodsKey(tfjobKey, rt)
	err = tc.Expectations.ExpectCreations(expectationPodsKey, 1)
	
	// 创建 OwnerReference.
	controllerRef := tc.GenOwnerReference(tfjob)

	podTemplate := spec.Template.DeepCopy()
	... ...
    // 生成集群的配置信息,这里最关键,看一下实现
	if err := setClusterSpec(podTemplate, tfjob, rt, index); err != nil {
		return err
	}
	... ...
    // 使用上面的配置信息,真正启动Pod的创建
	err = tc.PodControl.CreatePodsWithControllerRef(tfjob.Namespace, podTemplate, tfjob, controllerRef)
}

setClusterSpec为给定的podTemplateSpec生成并设置TF_CONFIG

func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, index string) error {
    ... ...
	// 生成TF_CONFIG JSON字符串。
	tfConfigStr, err := genTFConfigJSONStr(tfjob, rt, index)
	... ...
	// 将TF_CONFIG环境变量添加到pod中的tensorflow容器中。
	for i := range podTemplateSpec.Spec.Containers {
		if podTemplateSpec.Spec.Containers[i].Name == tfv1.DefaultContainerName {
			... ...
			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
				Name:  tfConfig,
				Value: tfConfigStr,
			})
			break
		}
	}
}

pkg/controller.v1/tensorflow/tensorflow.go:
genTFConfig将生成环境变量TF_CONFIG

{
    "cluster": {
        "ps": ["ps1:2222", "ps2:2222"],
        "worker": ["worker1:2222", "worker2:2222", "worker3:2222"]
    },
    "task": {
        "type": "ps",
        "index": 1
    },
}

主要代码如下:

func genTFConfigJSONStr(tfjob *tfv1.TFJob, rtype, index string) (string, error) {
	// 配置TFCONFIG环境变量。
	cluster, err := genClusterSpec(tfjob)
    ... ...
    // 组装形成TF_CONFIG
	if tfjob.Spec.EnableDynamicWorker {
		sparseCluster := convertClusterSpecToSparseClusterSpec(cluster, rtype, int32(i))
		sparseTFConfig := SparseTFConfig{
			Cluster: sparseCluster,
			Task: TaskSpec{
				Type:  rtype,
				Index: int(i),
			},
		}
		tfConfigJSONByteSlice, err = json.Marshal(sparseTFConfig)
	} else {
		tfConfig := TFConfig{
			Cluster: cluster,
			Task: TaskSpec{
				Type:  rtype,
				Index: int(i),
			},
			// 我们需要设置环境为cloud,否则它会默认为local,这不是我们想要的。
			Environment: "cloud",
		}
		tfConfigJSONByteSlice, err = json.Marshal(tfConfig)
	}
	return string(tfConfigJSONByteSlice), nil
}

genClusterSpec将生成ClusterSpec

func genClusterSpec(tfjob *tfv1.TFJob) (ClusterSpec, error) {
    ... ...
	for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
		port, err := GetPortFromTFJob(tfjob, rtype)
		// 这里循环生成了TF_CONFIG里面的Cluster信息。注意看注释,使用DNS配合Service,解决的还是各个节点IP不固定的问题
		for i := int32(0); i < *spec.Replicas; i++ {
			// 如下所述:https://kubernetes.io/docs/concepts/services-networking/dns-pos-service/#a-records。
			// Headless service为"my-svc.my-namespace.svc.cluster.local"的名称分配一个DNS记录。
			// 最后一部分是"svc.cluster.local"被称为cluster domain,在不同的kubernetes集群之间可能存在差异。
			hostName := jobcontroller.GenGeneralName(tfjob.Name, rt, fmt.Sprintf("%d", i))
			svcName := hostName + "." + tfjob.Namespace + "." + "svc"
			cluserDomain := os.Getenv(EnvCustomClusterDomain)
			if len(cluserDomain) > 0 {
				svcName += "." + cluserDomain
			}
			endpoint := fmt.Sprintf("%s:%d", svcName, port)
			replicaNames = append(replicaNames, endpoint)
		}
		clusterSpec[rt] = replicaNames
	}
	return clusterSpec, nil
}

pkg/control/pod_control.go:
使用集群的配置信息,真正启动Pod的创建:

func (r RealPodControl) CreatePodsWithControllerRef(...) error {
	... ...
	return r.createPods("", namespace, template, controllerObject, controllerRef)
}

调用K8S接口创建pod

func (r RealPodControl) createPods(...) error {
	pod, err := GetPodFromTemplate(template, object, controllerRef)
	... ...
	if newPod, err := r.KubeClient.CoreV1().Pods(namespace).Create(pod); err != nil {
		r.Recorder.Eventf(object, v1.EventTypeWarning, FailedCreatePodReason, "Error creating: %v", err)
		return err
	} 
    ... ...
}

pkg/controller.v1/tensorflow/service.go:
为每个给定的TFReplicaSpec检查和更新service

// 它将在创建/删除服务时发生错误时请求tfjob。
func (tc *TFController) reconcileServices(...) error {

	// 获取rt类型的所有service。
	services, err := tc.FilterServicesForReplicaType(services, rt)
	
	serviceSlices, servicesToBeRemoved := tc.GetServiceSlices(services, replicas, tflogger.LoggerForReplica(tfjob, rt))

	// 缩减
	if tfjob.Spec.EnableDynamicWorker && len(servicesToBeRemoved) > 0 {
		// 目前只允许缩小worker的service范围
		if rtype == tfv1.TFReplicaTypeWorker {
			for _, service := range servicesToBeRemoved {
				if err := tc.ServiceControl.DeleteService(tfjob.Namespace, service.Name, tfjob); err != nil {
					return err
				}
			}
		}
	}

	for index, serviceSlice := range serviceSlices {
		if len(serviceSlice) == 0 {
			err = tc.createNewService(tfjob, rtype, strconv.Itoa(index), spec)
			
		}
	}
}

为给定的indextype创建一个新service

func (tc *TFController) createNewService(tfjob *tfv1.TFJob, rtype tfv1.TFReplicaType, index string, spec *common.ReplicaSpec) error {
    ... ...
	expectationServicesKey := jobcontroller.GenExpectationServicesKey(tfjobKey, rt)
	err = tc.Expectations.ExpectCreations(expectationServicesKey, 1)
	
	// 创建 OwnerReference.
	controllerRef := tc.GenOwnerReference(tfjob)
    ... ...
    // 直接生成了Service的配置信息
	service := &v1.Service{
		Spec: v1.ServiceSpec{
			ClusterIP: "None",
			Selector:  labels,
			Ports: []v1.ServicePort{
				{
					Name: tfv1.DefaultPortName,
					Port: port,
				},
			},
		},
	}
    ... ...
	err = tc.ServiceControl.CreateServicesWithControllerRef(tfjob.Namespace, service, tfjob, controllerRef)
	... ...
}

pkg/control/service_control.go:
使用集群的配置信息,真正启动Service的创建:

func (r RealServiceControl) CreateServicesWithControllerRef(...) error {
	... ...
	return r.createServices(namespace, service, controllerObject, controllerRef)
}

调用K8S接口创建service

func (r RealServiceControl) createServices(namespace string, service *v1.Service, object runtime.Object, controllerRef *metav1.OwnerReference) error {
	serviceWithOwner, err := getServiceFromTemplate(service, object, controllerRef)
	... ...
	newService, err := r.KubeClient.CoreV1().Services(namespace).Create(serviceWithOwner)
	... ...
}

Good!要想真正搞懂Kubeflow,就必须要搞懂其核心TFJob的实现机制,如我们所见,TFJob代码量并不多,实现逻辑也不难掌握,以此为突破口,如果有必要,我们完全可以在参照它实现一套自己的定制化分布式训练框架。后续会有Kubeflow@Pipelines系列,如果本文对你有帮助,需要你的点赞收藏或直接关注我,会不定时更新技术干货和学习感悟,感谢支持~。

技术分享:

  1. Kubeflow-K8S的机器学习工具包,太牛了!
  2. 从原理到实战,彻底搞懂 Nginx!
  3. 掌握Shell编程,一篇就够了
  4. Kafka 概述:深入理解架构