ComfyUI Source Code Analysis


Preface

As one of the representative techniques of AIGC, the Stable Diffusion model has accelerated AIGC's move from the research stage to the application stage. Many open-source frameworks now provide Stable Diffusion based workflows, and ComfyUI is one of the most representative among them.

ComfyUI is a DAG-based graphical user interface (GUI) designed specifically for Stable Diffusion. In short, it decomposes the image-generation process into independent nodes, each with its own function, such as loading a model, entering a text prompt, or generating an image. Nodes are wired together through their inputs and outputs to form a complete workflow. Thanks to this approachable and highly extensible design, ComfyUI has been widely adopted in production applications of Stable Diffusion models.
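
To make the DAG structure concrete, here is a minimal sketch of the API-format workflow ("prompt") structure that ComfyUI's frontend submits, written as a Python dict; the node ids and the checkpoint file name are hypothetical:

# illustrative sketch of ComfyUI's API-format workflow (node ids and ckpt name are made up);
# an inputs value like ["3", 1] is an edge: it references output index 1 of node "3"
prompt = {
    "3": {"class_type": "CheckpointLoaderSimple",
          "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}},
    "4": {"class_type": "CLIPTextEncode",          # text prompt node
          "inputs": {"text": "a cat", "clip": ["3", 1]}},
    "5": {"class_type": "EmptyLatentImage",
          "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
}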

This article analyzes a selection of representative processing flows in ComfyUI. It aims to help readers understand ComfyUI's core implementations and the details of its integration with the Stable Diffusion algorithm libraries, so they can troubleshoot problems that come up when using ComfyUI, or build on it for secondary development.

Overview

Repo: github.com/comfyanonymous/ComfyUI (commit: 2ec615)

The core code files are as follows:

| File | Description |
| --- | --- |
| execution.py | DAG scheduling logic |
| nodes.py | implementations of the built-in DAG nodes |
| comfy/samplers.py | pre-/post-processing logic around the core sampling algorithms |
| comfy/sd.py | pre-/post-processing logic for the non-sampling modules, including VAE, CLIP, etc. |
| comfy/lora.py | LoRA model loading logic |
| comfy/controlnet.py | handling of ControlNet-style control models (including T2IAdapter) |
| comfy/model_base.py | wraps the key diffusion model objects on top of the open-source SD code |
| comfy/model_management.py | model loading/unloading across devices and VRAM management |
| comfy/model_patcher.py | patches model weights, supporting dynamic LoRA loading and more |

Algorithm Flow

Algorithm Parameters

The input parameters of every ComfyUI node can be found in the node definitions in ComfyUI/nodes.py. Following ComfyUI's default node categorization, the parameters of the key nodes are analyzed below:

Load LoRA

| Parameter | Type | Range | Default | Description |
| --- | --- | --- | --- | --- |
| strength_model | float | [-20.0, 20.0], step 0.01 | 1.0 | strength of the LoRA's influence on the base model |
| strength_clip | float | [-20.0, 20.0], step 0.01 | 1.0 | strength of the LoRA's influence on the CLIP model |

CLIP Text Encode (Prompt)

| Parameter | Type | Range | Default | Description |
| --- | --- | --- | --- | --- |
| text | str | / | / | the prompt text |

KSampler

| Parameter | Type | Range | Default | Description |
| --- | --- | --- | --- | --- |
| prompt | tensor | / | / | positive prompt after Text Encoder encoding |
| negative_prompt | tensor | / | / | negative prompt after Text Encoder encoding |
| latent_image | tensor | / | / | latent-space tensor representation of the input image after VAE encoding, e.g. (1, 4, 64, 64) |
| control_after_generate | - | randomize / increment / decrement / fixed | | changes the seed between requests; supports the four strategies: randomize, increment, decrement, and fixed |
| seed | int | [0, +∞) | 0 | random seed; the same seed with the same prompt produces the same set of images |
| steps | int | [1, 10000] | 20 | number of iterative denoising steps of the diffusion model |
| cfg | float | [0.0, 100.0], step 0.5 | 8.0 | Classifier-Free Guidance scale, controls how closely the output follows the prompt: larger values tie the image more tightly to the text prompt but may distort it; smaller values lower the correlation and allow more deviation from the prompt or input image, but can give better quality |
| sampler_name | comfy.samplers.KSampler.SAMPLERS | / | | sampler (each denoising step is performed by the sampler); common samplers: euler, euler_ancestral, heun, dpm_2, dpm_2_ancestral, lms, dpm_fast, dpm_adaptive, dpmpp_2s_ancestral, dpmpp_sde, dpmpp_sde_gpu, dpmpp_2m, dpmpp_2m_sde, dpmpp_2m_sde_gpu, dpmpp_3m_sde, dpmpp_3m_sde_gpu, ddpm, ddim, uni_pc, uni_pc_bh2 |
| scheduler | comfy.samplers.KSampler.SCHEDULERS | / | | noise scheduler, controls the noise level of each sampling step: highest at the first step, gradually falling to zero at the end; common types: normal, karras, exponential, sgm_uniform, simple, ddim_uniform |
| denoise | float | [0.0, 1.0], step 0.01 | 1.0 | strength of the added noise; the higher it is, the more creative freedom the model has and the less the result resembles the input image |
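
As a reading aid for the cfg parameter: it is applied as the classifier-free guidance scale when the noise prediction is assembled, which the DDIM walkthrough below also shows. Conceptually:

# classifier-free guidance combination (as it appears later in p_sample_ddim):
# model_output = model_uncond + cfg * (model_t - model_uncond)
# cfg == 1.0 disables guidance; larger values push the sample toward the prompt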

Algorithm Walkthrough

The analysis below uses the following simplified workflow:

(workflow diagram omitted)

The model-related data used for the analysis is as follows:

KSampler

The node's parameters are set as follows:

| Parameter | Value |
| --- | --- |
| seed | 734689874417207 |
| steps | 2 |
| cfg | 8.0 |
| sampler_name | ddim |
| scheduler | normal |
| denoise | 1.0 |

The call chain is as follows:

(call-chain diagram omitted)

The KSampler class in nodes.py is the implementation entry point of the KSampler node in the flow chart, and the starting point of the sampling pipeline:

# source file: nodes.py

class KSampler:
    def sample(...):  # entry point of the sampling node
        return common_ksampler(...)  # the work is delegated to the common_ksampler function
# source file: nodes.py
def common_ksampler(...):
    if disable_noise:
        noise = torch.zeros(...)
    else:
        # generate the initial random noise from a standard normal distribution
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)  
    ...
    
    samples = comfy.sample.sample(...) # entry point of the core sampling flow
    
    return (out, )
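
For intuition, here is a minimal sketch of how deterministic initial noise can be derived from the seed; comfy/sample.py::prepare_noise works along these lines, with extra handling for batch indices:

# illustrative sketch, not the exact ComfyUI implementation
import torch

def prepare_noise_sketch(latent_image, seed):
    generator = torch.manual_seed(seed)  # seeding makes the noise reproducible
    return torch.randn(latent_image.size(), dtype=latent_image.dtype,
                       layout=latent_image.layout, generator=generator, device="cpu")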

comfy/sample.py defines common pre-sampling work such as model loading and initial-noise generation. Its sample function likewise only performs preparation and forwards the core flow to the sampler classes defined in comfy/samplers.py.

# source_file: comfy/sample.py

...

def prepare_sampling(...):
    comfy.model_management.load_models_gpu(...) # model loading
    
    # load the conditioning parameters
    positive_copy = broadcast_cond(positive, noise_shape[0], device)
    negative_copy = broadcast_cond(negative, noise_shape[0], device)
    
    return real_model, positive_copy, negative_copy, noise_mask, models

...

def sample(...):
    ... = prepare_sampling(...)   # pre-sampling preparation: model loading, conditioning loading, etc.
    
    ...
    
    sampler = comfy.samplers.KSampler(...)  # instantiate the sampler class
    samples = sampler.sample(...)  # the core sampling work is forwarded to the sampler's sample function
    
    ...
    
    return samples

KSampler.sample() -> sample() again just does preparation and forwards the request to the DDIM sampler (the DDIM class) selected via the user's parameters:

# source_file: comfy/samplers.py

...

def sample(...):
    ...
    # choose the DDIM sampler from the input parameters; the actual sampling work is forwarded to the DDIM class
    samples = sampler.sample(...)  
    return model.process_latent_out(samples.to(torch.float32))

...

class KSampler:
    def __init__(self, ...):
        ...
        self.scheduler = scheduler
        self.set_steps(steps, denoise)
        
        
    def set_steps(self, ...):
        ...
        # pre-compute the per-step sigma values from the user's scheduler type and step count, controlling the noise level of each step
        self.sigmas = self.calculate_sigmas(steps).to(self.device)
        ...
    
    ...
    def sample(self, ...):
        ...
        sampler = sampler_class(self.sampler) 
        return sample(...)
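
As an illustration of what calculate_sigmas produces for the karras scheduler, here is a sketch mirroring k_diffusion's get_sigmas_karras; the sigma_min/sigma_max defaults are assumptions:

# illustrative sketch of the karras noise schedule (Karras et al., 2022); defaults are assumptions
import torch

def karras_sigmas(n, sigma_min=0.03, sigma_max=14.6, rho=7.0):
    ramp = torch.linspace(0, 1, n)                    # interpolate in sigma^(1/rho) space
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])   # the schedule ends at sigma = 0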

The DDIM class inherits from the Sampler base class, reimplements the sample method, and ultimately forwards the sampling work to the ddim_sampler class from the Stable Diffusion source code.

# source_file: comfy/samplers.py

class DDIM(Sampler):
    def sample(...):
        ...
        ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
        ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
        z_enc = ddim_sampler.stochastic_encode(...)
        
        # everything is ultimately forwarded to the DDIMSampler class from the Stable Diffusion source code
        samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
                                                batch_size=noise.shape[0],
                                                shape=noise.shape[1:],
                                                verbose=False,
                                                eta=0.0,
                                                x_T=z_enc,
                                                x0=latent_image,
                                                img_callback=ddim_callback,
                                                denoise_function=model_wrap.predict_eps_discrete_timestep,
                                                extra_args=extra_args,
                                                mask=noise_mask,
                                                to_zero=sigmas[-1]==0,
                                                end_step=sigmas.shape[0] - 1,
                                                disable_pbar=disable_pbar)
        return samples

The flow from here on essentially follows the official Stable Diffusion source code. The core steps are analyzed below:

# source_file: comfy/ldm/models/diffusion/ddim.py

class DDIMSampler(object):
    ...
    def sample_custom(self, ...):
        self.make_schedule_timesteps(...)
        samples, intermediates = self.ddim_sampling(...) # the core sampling procedure
        
        return samples, intermediates
        
        
    def ddim_sampling(self, ...):
        if x_T is None:
            img = torch.randn(shape, device=device)  # text2img
        else:
            img = x_T # img2img
        ...
        
        # the multi-step DDIM sampling loop
        for i, step in enumerate(iterator):
            outs = self.p_sample_ddim(...) # run one concrete sampling step
            
        return img, intermediates
        
        
    def p_sample_ddim(self, ...):
        ...
        
        # the code below predicts the noise with the UNet model
        if denoise_function is not None:
             # this step calls ComfyUI/comfy/k_diffusion/external.py:predict_eps_discrete_timestep()
             model_output = denoise_function(x, t, **extra_args) 
        elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
        else:
             ...
             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
        
        ...
        
        
        # the code below expands the corresponding formulas from the diffusion paper
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # obtain x_{t-1} from x_t, completing one denoising step of the image
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        
        return x_prev, pred_x0
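
For reference, the quantities in the final update correspond to the DDIM update rule (Song et al., 2020); a sketch of the expansion elided by the ... above:

# how the final update is assembled from the predicted noise e_t (sketch):
# pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()       # predicted clean latent
# dir_xt  = (1. - a_prev - sigma_t**2).sqrt() * e_t          # direction pointing to x_t
# noise   = sigma_t * torch.randn_like(x)                    # zero when eta == 0 (deterministic DDIM)
# x_prev  = a_prev.sqrt() * pred_x0 + dir_xt + noise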

CLIP Text Encode

The call chain is as follows:

(call-chain diagram omitted)

The CLIPTextEncode class in nodes.py is the execution entry point of the CLIPTextEncode node in the flow chart:

# source file: nodes.py

class CLIPTextEncode:
    ...
    
    def encode(self, clip, text):
        tokens = clip.tokenize(text)  # split the input into tokens
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) # encode the preprocessed input with the CLIP model
        return ([[cond, {"pooled_output": pooled}]],)

The clip object is defined in comfy/sd.py; all the core algorithm steps are invoked through it:

  • clip.tokenize(...): splits the input text into tokens

  • clip.encode_from_tokens(...): runs the CLIP model inference.

Let us first look at clip.tokenize(...), where self.tokenizer is the SDXLTokenizer object defined in comfy/sdxl_clip.py.

# source file: comfy/sd.py

class CLIP:
    ...
    
    def tokenize(self, ...):
        # self.tokenizer is an instantiated SDXLTokenizer object
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

SDXL uses two CLIP models (clip_l and clip_g) for text encoding; the results are eventually merged:

# source file: comfy/sdxl_clip.py

class SDXLTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(...):
        self.clip_l = sd1_clip.SD1Tokenizer(...)
        self.clip_g = SDXLClipGTokenizer(...)
        
    def tokenize_with_weights(self, ...):  # ultimately calls SD1Tokenizer.tokenize_with_weights(...)
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(...)
        out["l"] = self.clip_l.tokenize_with_weights(...)
        return out
# source file: comfy/sd1_clip.py
class SD1Tokenizer:
   ...
   def tokenize_with_weights(self, ...):
       '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        ...
        
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)
        ...
        
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            ...
            for word in to_tokenize:
                ...
                # tokenize the word via the CLIPTokenizer from the transformers package
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
                
            ...
            
        return batched_tokens
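
Outside ComfyUI, the same underlying call can be reproduced with the transformers package directly (the model name below is the standard SD1.x text encoder, an assumption):

# illustrative: reproducing the tokenizer call above with transformers directly
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer("a photo of a cat")["input_ids"][1:-1]  # [1:-1] strips the start/end tokens, as above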

Once the text is tokenized, the CLIP model encodes the preprocessed tokens:

# source file: comfy/sd.py

class CLIP:
    ...
    def encode_from_tokens(self, ...):
        if self.layer_idx is not None:
            self.cond_stage_model.clip_layer(self.layer_idx)
        else:
            self.cond_stage_model.reset_clip_layer()

        # finally loads the CLIP model via the comfy/model_management module
        self.load_model()
        
        # calls SDXLClipModel.encode_token_weights(...)
        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
        
        # return the result
        if return_pooled:
            return cond, pooled
        return cond
# source file: comfy/sdxl_clip.py

class SDXLClipModel(torch.nn.Module):
    ...
    def reset_clip_layer(self):
        self.clip_g.reset_clip_layer()
        self.clip_l.reset_clip_layer()
        
    ...
    
    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_l = token_weight_pairs["l"]
        # calls ClipTokenWeightEncoder.encode_token_weights(...) from comfy/sd1_clip.py
        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return torch.cat([l_out, g_out], dim=-1), g_pooled
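
A note on shapes (assuming the standard SDXL text encoders): clip_l produces 768-dimensional embeddings and clip_g 1280-dimensional ones, so the concatenation yields 2048-dimensional conditioning:

# illustrative shapes (assumed standard SDXL encoders):
# l_out: (batch, 77, 768), g_out: (batch, 77, 1280)
# torch.cat([l_out, g_out], dim=-1) -> (batch, 77, 2048), consumed by the SDXL UNet's cross-attention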

The CLIP encoding itself is finally performed by the forward method of the SD1ClipModel object:

# source file: comfy/sd1_clip.py

class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        ...
        # calls SD1ClipModel's encode method
        out, pooled = self.encode(to_encode)
        
        
...

class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    ...
    
    # CLIP model inference code
    def forward(self, tokens):
        ...
        # self.transformer is the CLIPTextModel object from the transformers package, which performs the actual model inference
        outputs = self.transformer(...)
        ...
        
        return z.float(), pooled_output.float()
    
    def encode(self, tokens):
        return self(tokens)

VAE Encoder/Decoder

The call chains are as follows:

(call-chain diagrams omitted)

The VAEEncode and VAEDecode classes in nodes.py are the execution entry points of the VAE Encode and VAE Decode nodes in the flow chart:

# source file: nodes.py

class VAEDecode:
    ...
    def decode(self, vae, samples):
        # delegates to the VAE object in comfy/sd.py
        return (vae.decode(samples["samples"]), )


class VAEEncode:
    ...
    def encode(self, ...):
        # crop the pixels (via a method on this class) so they fit the VAE model's input
        pixels = self.vae_encode_crop_pixels(pixels)
        # delegates to the VAE object in comfy/sd.py
        t = vae.encode(pixels[:, :, :, :3])
        return ({"samples": t},)
        
        

As the decode and encode implementations of the VAE class in comfy/sd.py show below, the key algorithm steps are ultimately delegated to the AutoencoderKL object from the open-source SD code.

# source file: comfy/sd.py

class VAE:
    def __init__(self, ...):
        ...
        # model inference is delegated to the AutoencoderKL object from the open-source SD code
        self.first_stage_model = AutoencoderKL(...)
        ...
        
    def decode(self, ...):
        ...
        for x in range(0, samples_in.shape[0], batch_number):
            ...
            samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
            pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
            ...
   
        
    def encode(self, ...):
        ...
        for x in range(0, pixel_samples.shape[0], batch_number):
            ...
            samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
            ...
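
As a shape-level summary (assuming the standard 8x-downsampling SD VAE), this is the roundtrip that links the pixel and latent representations:

# illustrative shapes for the standard 8x-downsampling SD VAE (assumption):
# encode: (B, 512, 512, 3) pixels in [0, 1]  -> (B, 4, 64, 64) latent
# decode: (B, 4, 64, 64) latent              -> (B, 512, 512, 3) pixels, via clamp((x + 1) / 2, 0, 1)
# this (B, 4, 64, 64) latent is exactly the latent_image tensor consumed by KSampler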

LoRA

The LoraLoader class in nodes.py is the entry point of LoRA loading:

# source file: nodes.py

class LoraLoader:
    ...
    def load_lora(self, ...):
        ...
        model_lora, clip_lora = comfy.sd.load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        return (model_lora, clip_lora)
# source file: comfy/sd.py

def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
    # build the list of UNet parameter names the LoRA should update
    key_map = comfy.lora.model_lora_keys_unet(model.model)
    # build the list of CLIP parameter names the LoRA should update
    key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
    
    # load the LoRA parameters according to the key map
    loaded = comfy.lora.load_lora(lora, key_map)
    new_modelpatcher = model.clone()
    # register the patches; they are merged into the weights when the model is loaded at node execution time (comfy/model_patcher.py::ModelPatcher.add_patches)
    k = new_modelpatcher.add_patches(loaded, strength_model)
    new_clip = clip.clone()
    # same for the CLIP model
    k1 = new_clip.add_patches(loaded, strength_clip)
    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            print("NOT LOADED", x)

    return (new_modelpatcher, new_clip)

The LoRA loading logic is concentrated in comfy/lora.py:

# source file: comfy/lora.py

def model_lora_keys_unet(model, key_map={}):
    sdk = model.state_dict().keys()

    # build the parameter-name mapping between the LoRA model and the UNet model via specific pattern-matching rules
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
       
    # match diffusers-format keys against the UNet model
    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key
    return key_map
    
 
 def model_lora_keys_clip(model, key_map={}):
     ...
     
 
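To make the mapping concrete, here is an example of what model_lora_keys_unet builds, using a hypothetical layer name:

# illustrative mapping built by model_lora_keys_unet (the layer name is hypothetical):
# UNet key:         "diffusion_model.output_blocks.3.1.proj_in.weight"
# mapped LoRA key:  "lora_unet_output_blocks_3_1_proj_in"
# a LoRA file that stores "lora_unet_output_blocks_3_1_proj_in.lora_up.weight" and
# ".lora_down.weight" can therefore be matched back to the UNet weight it patches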

Model Management

The startup flags related to model management are as follows:

| Flag | Effect | Notes |
| --- | --- | --- |
| --gpu-only | Store and run everything (text encoders/CLIP models, etc.) on the GPU. | |
| --highvram | By default models are unloaded to CPU memory after use; this option keeps them in GPU memory. | |
| --normalvram | Force normal VRAM use if lowvram gets automatically enabled. | |
| --lowvram | Split the UNet into parts to use less VRAM. | uses the Accelerate library to manage models larger than VRAM |
| --novram | For when lowvram isn't enough. | |
| --cpu | Use the CPU for everything (slow). | |
| --disable-smart-memory | Force ComfyUI to aggressively offload to regular RAM instead of keeping models in VRAM when it can. | |

The model-management code is spread across the following four files:

  • comfy/model_base.py: wraps the key diffusion model objects on top of the open-source SD code
  • comfy/model_detection.py: analyzes properties of input models
  • comfy/model_management.py: implements model loading across devices and VRAM management
  • comfy/model_patcher.py: patches model weights, supporting dynamic LoRA loading and more

Model Loading and Unloading

All core ComfyUI model loads go through model_management.py::load_models_gpu(), and all model unloads go through model_management.py::cleanup_models(). Key parts of the source are analyzed below:

# source file: comfy/model_management.py

...
# core data structure: every model loaded onto the GPU is stored in this list
current_loaded_models = []
...

# the objects stored in current_loaded_models are LoadedModel instances
class LoadedModel:
    def __init__(self, model):
        self.model = model  # a model_patcher.py::ModelPatcher object
        self.model_accelerated = False
        self.device = model.load_device
        
    ...
    
    # model loading logic
    def model_load(self, lowvram_model_memory=0):
       # load the model via the ModelPatcher object's methods
       self.model.model_patches_to(self.device)
       self.model.model_patches_to(self.model.model_dtype())  
       ...
       # if VRAM is insufficient, use the accelerate library for multi-tier distributed loading
       if lowvram_model_memory > 0:
            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
            self.model_accelerated = True
            
       return self.real_model
    
    ...
    
    # model unloading logic
    def model_unload(self):
        if self.model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(self.real_model)
            self.model_accelerated = False
        # unload the model via the ModelPatcher object's methods
        self.model.unpatch_model(self.model.offload_device)
        self.model.model_patches_to(self.model.offload_device)
        ...
...

# entry point function for model loading
def load_models_gpu(models, memory_required=0):
    ...
    models_to_load = []
    
    for x in models:
        loaded_model = LoadedModel(x)
        ...
        models_to_load.append(loaded_model)
    ...
    for loaded_model in models_to_load:
        ...
        # load the model via LoadedModel.model_load
        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
        current_loaded_models.insert(0, loaded_model)
    return
    
...

# entry point function for model unloading
def cleanup_models():
    to_delete = []
    for i in range(len(current_loaded_models)):
        # decide via the reference count whether the model needs unloading
        if sys.getrefcount(current_loaded_models[i].model) <= 2:
            to_delete = [i] + to_delete

    for i in to_delete:
        # remove it from the loaded-models list
        x = current_loaded_models.pop(i)
        x.model_unload()
        del x
        
...

Model Weight Patching

model_patcher.py implements the methods for dynamically modifying model weights; the core ones are:

# source file: comfy/model_patcher.py

class ModelPatcher:
    ...
    # merge multiple sets of weight patches into the model
    def patch_model(self, device_to=None):
        model_sd = self.model_state_dict()
        for key in self.patches:
            if key not in model_sd:
                print("could not patch. key doesn't exist in model:", key)
                continue

            weight = model_sd[key]

            if key not in self.backup:
                self.backup[key] = weight.to(self.offload_device)

            if device_to is not None:
                temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
            else:
                temp_weight = weight.to(torch.float32, copy=True)
            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
            comfy.utils.set_attr(self.model, key, out_weight)
            del temp_weight

        if device_to is not None:
            self.model.to(device_to)
            self.current_device = device_to

        return self.model
    ...
    
    def calculate_weight(self, patches, weight, key):
        for p in patches:
            alpha = p[0]
            v = p[1] # the patch weights
            strength_model = p[2] # patch strength
            if strength_model != 1.0:
                weight *= strength_model # rescale the weight when the strength is not 1.0
                
            ...
            
            if len(v) == 1:
                ...
            elif len(v) == 4: # lora/locon
                ...
                # update the affected linear layers when merging in the lora/locon
                weight += (
                    (
                        alpha
                        * torch.mm(
                            mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
                        )
                    )
                    .reshape(weight.shape)
                    .type(weight.dtype)
                )
                ...
            elif len(v) == 8: # lokr
                ...
            else: # loha
                ...
            
        return weight
            
            

This shows that the core idea behind ComfyUI's dynamic LoRA loading is to back up the base model's parameters and simply restore them when a swap is needed.
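
A minimal sketch of that backup-then-restore idea (not ComfyUI's actual API; the class and method names are made up):

# illustrative sketch of backup/restore weight patching (hypothetical class)
import torch

class TinyPatcher:
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def patch(self, key, delta, strength=1.0):
        weight = dict(self.model.named_parameters())[key]
        self.backup.setdefault(key, weight.detach().clone())  # back up the base weight once
        weight.data += strength * delta                       # apply the (pre-merged) LoRA delta

    def unpatch(self):
        params = dict(self.model.named_parameters())
        for key, saved in self.backup.items():                # restore the base weights
            params[key].data.copy_(saved)
        self.backup.clear()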

Model Management Mechanics

  • Most of ComfyUI's built-in model loaders first load the model into RAM, then load it onto the GPU on demand at runtime (before the KSampler node runs);

  • ControlNet models are loaded into RAM by the ControlNetLoader node; the ControlnetApply node then creates a new model copy based on strength, and the model is finally loaded onto the GPU in the KSampler node;

  • Under some conditions the CheckpointLoaderSimple node loads the base model straight into VRAM; if the base-model reference changes later (e.g. a LoRA patch is added), the KSampler node unloads the old base model and loads the new one;

  • Core models are loaded onto the GPU through the comfy/model_management.py::load_models_gpu() API, and every model loaded onto the GPU is cached in the comfy/model_management.py::current_loaded_models list.

  • Core models are unloaded from the GPU back to memory through the comfy/model_management.py::cleanup_models() API. Unloading happens before each workflow is scheduled: the reference count of every model in comfy/model_management.py::current_loaded_models is checked, and any model with no other referents is unloaded.

  • During each workflow run, every node's execution result is cached; on the next request, whether to invalidate the cache and re-run a node is decided from changes in the node's parameters, changes in its upstream dependencies, or the node's IS_CHANGED() method. Node cache invalidation is tied to the models' reference counts.

Runtime Flow

Interfaces

Default Types

| Type | Notes |
| --- | --- |
| INT | primitive type |
| FLOAT | primitive type |
| STRING | primitive type |
| IMAGE | image pixels |
| MASK | masked region of an image |
| MODEL | model object |
| LATENT | latent-space representation of an image produced by the VAE encoder |
| CONDITIONING | / |
| VAE | VAE model |
| CLIP | CLIP model |
| CLIP_VISION | / |
| CLIP_VISION_OUTPUT | / |
| CONTROL_NET | / |
| STYLE_MODEL | / |
| GLIGEN | GLIGEN model |

The input-validation code below shows that, apart from the three primitive types, ComfyUI merely compares type names as strings to check whether one node's output matches another node's input, so custom types can be added freely.

# source file : execution.py

def validate_inputs(prompt, item, validated):
    ...
    # fetch the input info of the current node
    # eg: inputs = {'filename_prefix': 'ComfyUI', 'images': ['8', 0]}
    inputs = prompt[unique_id]["inputs"]
    ...
    # fetch the 'required' field definition of the node's inputs
    # eg: {'images': ('IMAGE',), 'filename_prefix': ('STRING', {'default': 'ComfyUI'})}
    required_inputs = class_inputs['required']
    ...
    for x in required_inputs:
        ...
        val = inputs[x]
        ...
        # if val is a list, this input is provided by an upstream node
        if isinstance(val, list):
            ...
            # upstream node id
            o_id = val[0]
            # upstream node class type
            o_class_type = prompt[o_id]["class_type"]
            # upstream node return types
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            # check that the upstream node's return type matches this input's type
            if r[val[1]] != type_input:
                ...
                errors.append(err)
                continue
            ...
            # recurse into the upstream node
            r = validate_inputs(prompt, o_id, validated)

        else: # the input is a literal value, not from an upstream node
            ...
            # try converting to one of the three primitive types
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                ...
            if hasattr(obj_class, "VALIDATE_INPUTS"):
                ...
        
        
        
def validate_prompt(prompt):
    outputs = set()
    
    # find the output nodes by traversing the prompt
    for x in prompt:
        ...
        if hasattr(class_, "OUTPUT_NODE") and class_.OUTPUT_NODE == True:
            outputs.add(x)
            
    ...
    for o in outputs:
        ...
        # validate each output node with validate_inputs (in fact this recursively checks every connected node backwards from the output)
        validate_inputs(prompt, o, validated)
        
    return (True, None, ...)
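
Because the match is by string only, a pair of custom nodes can exchange data under any made-up type name. A sketch (the type name "MY_PAYLOAD" and both node classes are hypothetical):

# illustrative: two custom nodes exchanging a made-up type (hypothetical names)
class ProducePayload:
    RETURN_TYPES = ("MY_PAYLOAD",)   # any unique string works as a type
    FUNCTION = "run"
    CATEGORY = "example"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    def run(self):
        return ({"anything": 42},)

class ConsumePayload:
    RETURN_TYPES = ()
    OUTPUT_NODE = True
    FUNCTION = "run"
    CATEGORY = "example"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"payload": ("MY_PAYLOAD",)}}   # matched by the same string

    def run(self, payload):
        print(payload)
        return ()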
        

Custom Nodes

Below is the tutorial shipped with the ComfyUI source at custom_nodes/example_node.py.example:

class Example:
    """
    An example node

    Class methods
    -------------
    INPUT_TYPES (dict): 
        Tell the main program input parameters of nodes.

    Attributes
    ----------
    RETURN_TYPES (`tuple`): 
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: The name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category the node should appear in the UI.
    execute(s) -> tuple || None:
        The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass
    
    @classmethod
    def INPUT_TYPES(s):
        """
            Return a dictionary which contains config for all input fields.
            Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
            Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
            The type can be a list for selection.

            Returns: `dict`:
                - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
                - Value input_fields (`dict`): Contains input fields config:
                    * Key field_name (`string`): Name of an entry-point method's argument
                    * Value field_config (`tuple`):
                        + First value is a string indicating the type of field or a list for selection.
                        + Second value is a config for type "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "int_field": ("INT", {
                    "default": 0, 
                    "min": 0, #Minimum value
                    "max": 4096, #Maximum value
                    "step": 64, #Slider's step
                    "display": "number" # Cosmetic only: display as "number" or "slider"
                }),
                "float_field": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 10.0,
                    "step": 0.01,
                    "round": 0.001, #The value represeting the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
                    "display": "number"}),
                "print_to_screen": (["enable", "disable"],),
                "string_field": ("STRING", {
                    "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
                    "default": "Hello World!"
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Example"

    def test(self, image, string_field, int_field, float_field, print_to_screen):
        if print_to_screen == "enable":
            print(f"""Your input contains:
                string_field aka input text: {string_field}
                int_field: {int_field}
                float_field: {float_field}
            """)
        #do some processing on the image, in this example I just invert it
        image = 1.0 - image
        return (image,)


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "Example": Example
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "Example": "Example Node"
}

Scheduling

Single Request

# source file : execution.py

...
 def recursive_execute(...):
     # the node currently being executed
     unique_id = current_item
     # fetch the input info of the current node
     # eg: inputs = {'filename_prefix': 'ComfyUI', 'images': ['8', 0]}
     inputs = prompt[unique_id]["inputs"]
     # fetch the node's class type
     class_type = prompt[unique_id]["class_type"]
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     
     for x in inputs:
         input_data = inputs[x]
         # if input_data is a list, the input comes from a dependency node
         if isinstance(input_data, list):
             if input_unique_id not in outputs:
                 # recursively execute the dependency node first
                 recursive_execute(...)
                 ...
     ...
     # assemble the current node's input data
     get_input_data(...)
     ...
     
     # actually invoke the node
     map_node_over_list(...)
     
     # collect the current node's results
     get_output_data(...)
     outputs[unique_id] = output_data
     ...


class PromptExecutor:
    def execute(...):
        ...
        # add the output nodes to to_execute; in this example only [(0, '9')]
        for node_id in list(execute_outputs):
            to_execute += [(0, node_id)]
            
        while len(to_execute) > 0:
            # sort by how many unexecuted nodes each output node depends on
            to_execute = sorted(
                list(
                    map(
                        lambda a: (
                            len(
                                recursive_will_execute(prompt, self.outputs, a[-1])
                            ),
                            a[-1],
                        ),
                        to_execute,
                    )
                )
            )
            # pick the output node that depends on the fewest unexecuted nodes
            output_node_id = to_execute.pop(0)[-1]
            
            # recursively execute starting from the chosen output node
            recursive_execute(...)
                 

Data is passed between nodes as follows:

# source file : execution.py

def get_input_data(...):
    ...
    input_data_all = {}
    # there may be multiple inputs, corresponding to multiple sockets on the graph node
    for x in inputs:
        # the input comes from an upstream node
        if isinstance(input_data, list):
            ...
            # fetch the upstream node's output from the global cache
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj   
        else: # a literal input
            ...
            input_data_all[x] = [input_data]
    
    return input_data_all

def get_output_data(...):
    ...

# obj: the node class of the node to be executed
# input_data_all: all input data
# func: the function the node class actually runs
def map_node_over_list(obj, input_data_all, func, ...):
    # check whether the node class defines the INPUT_IS_LIST attribute
    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
        input_is_list = obj.INPUT_IS_LIST
     
    ...
    
    
    if input_is_list:  # the node declares INPUT_IS_LIST, i.e. it handles list inputs itself, so pass them in unchanged
        ...
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        ...
        results.append(getattr(obj, func)())
    else: # otherwise unroll the inputs with a for loop and invoke the node once per element
        for i in range(max_len_input):
            ...
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
            
    return results

High Concurrency

Batching Within a Request

From the single-request scheduling code we can see that data scheduling between the nodes of the graph is unrestricted as long as the types match. Batched data falls into two cases:

  • If the node class defines the INPUT_IS_LIST attribute, the batch is handed to the node as-is, and the node decides internally how to process it (see the sketch after the next list).

  • If the node class does not define INPUT_IS_LIST, the batch is split up and the node is scheduled once per element.

Searching the ComfyUI source, the only batch-related nodes currently live under the latent/batch category:

  • Currently only the LatentRebatch node sets INPUT_IS_LIST; it can be seen as the batch node, aggregating multiple latent inputs into batched outputs.

  • LatentFromBatch and RepeatLatentBatch do handle multiple inputs in their execution functions, but they are ordinary nodes: the multiple inputs are wrapped as a normal input and then processed internally.
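
A sketch of what an INPUT_IS_LIST node looks like (the node below is hypothetical; LatentRebatch is the real in-tree example):

# illustrative INPUT_IS_LIST node (hypothetical); all queued inputs arrive as one Python list
import torch

class ConcatLatents:
    INPUT_IS_LIST = True
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "merge"
    CATEGORY = "example"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"samples": ("LATENT",)}}

    def merge(self, samples):
        # samples is a list of {"samples": tensor} dicts; concatenate along the batch dim
        merged = torch.cat([s["samples"] for s in samples], dim=0)
        return ({"samples": merged},)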

Batching Across Requests

Since ComfyUI was originally positioned as a webui-like tool, code analysis shows that it does not support batching across multiple requests.

Thread Safety

All computation-graph state is kept inside the PromptExecutor, and no use of global variables is observed:

# source file: execution.py

...

class PromptExecutor:
    def __init__(self, server):
        self.outputs = {}
        self.object_storage = {}
        self.outputs_ui = {}
        self.old_prompt = {}
        self.server = server
     ...

Therefore, as long as each thread owns its own PromptExecutor object, ComfyUI is thread-safe under high concurrency.
