Skip to content

Commit

Permalink
Add some GGUF model detection
Browse files Browse the repository at this point in the history
- this makes eg. SD3.5 medium GGUF work
  • Loading branch information
Acly committed Nov 29, 2024
1 parent fbe3781 commit a1571c2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita"""

__version__ = "1.28.1"
__version__ = "1.29.0"

import importlib.util

Expand Down
22 changes: 11 additions & 11 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from collections import deque
from itertools import chain, product
from typing import NamedTuple, Optional, Sequence
from typing import Any, NamedTuple, Optional, Sequence

from .api import WorkflowInput
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
Expand Down Expand Up @@ -163,6 +163,7 @@ async def connect(url=default_url, access_token=""):
# Retrieve list of checkpoints
checkpoints = await client.try_inspect("checkpoints")
diffusion_models = await client.try_inspect("diffusion_models")
diffusion_models.update(await client.try_inspect("unet_gguf"))
client._refresh_models(nodes, checkpoints, diffusion_models)

# Check supported SD versions and make sure there is at least one
Expand Down Expand Up @@ -369,11 +370,11 @@ async def disconnect(self):
self._unsubscribe_workflows(),
)

async def try_inspect(self, folder_name: str):
async def try_inspect(self, folder_name: str) -> dict[str, Any]:
try:
return await self._get(f"api/etn/model_info/{folder_name}")
except NetworkError:
return None # server has old external tooling version
return {} # server has old external tooling version

@property
def queued_count(self):
Expand All @@ -384,11 +385,13 @@ def is_executing(self):
return self._active is not None

async def refresh(self):
nodes, checkpoints, diffusion_models = await asyncio.gather(
nodes, checkpoints, diffusion_models, diffusion_gguf = await asyncio.gather(
self._get("object_info"),
self.try_inspect("checkpoints"),
self.try_inspect("diffusion_models"),
self.try_inspect("unet_gguf"),
)
diffusion_models.update(diffusion_gguf)
self._refresh_models(nodes, checkpoints, diffusion_models)

def _refresh_models(self, nodes: dict, checkpoints: dict | None, diffusion_models: dict | None):
Expand All @@ -407,7 +410,7 @@ def parse_model_info(models: dict, model_format: FileFormat):
return {
filename: CheckpointInfo(filename, arch, model_format)
for filename, arch, is_inpaint, is_refiner in parsed
if not (arch is None or is_inpaint or is_refiner)
if not (arch is None or (is_inpaint and arch is not Arch.flux) or is_refiner)
}

if checkpoints:
Expand All @@ -424,12 +427,9 @@ def parse_model_info(models: dict, model_format: FileFormat):
models.loras = nodes["LoraLoader"]["input"]["required"]["lora_name"][0]

if gguf_node := nodes.get("UnetLoaderGGUF", None):
gguf_models = {
name: CheckpointInfo(name, Arch.flux, FileFormat.diffusion)
for name in gguf_node["input"]["required"]["unet_name"][0]
}
models.checkpoints.update(gguf_models)
log.info(f"GGUF support: {len(gguf_models)} models found.")
for name in gguf_node["input"]["required"]["unet_name"][0]:
if name not in models.checkpoints:
models.checkpoints[name] = CheckpointInfo(name, Arch.flux, FileFormat.diffusion)
else:
log.info(f"GGUF support: node is not installed.")

Expand Down
9 changes: 5 additions & 4 deletions ai_diffusion/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

# Version identifier for all the resources defined here. This is used as the server version.
# It usually follows the plugin version, but not all new plugin versions also require a server update.
version = "1.28.0"
version = "1.29.0"

comfy_url = "https://github.com/comfyanonymous/ComfyUI"
comfy_version = "61196d88576c95c1cd8535e881af48172d5af525"
comfy_version = "bf2650a80e5a7a888da206eab45c53dbb22940f7"


class CustomNode(NamedTuple):
Expand Down Expand Up @@ -39,7 +39,7 @@ class CustomNode(NamedTuple):
"External Tooling Nodes",
"comfyui-tooling-nodes",
"https://github.com/Acly/comfyui-tooling-nodes",
"e10daee9edea458fc709f60e725970a25567fca4",
"d7d421baaa7d3140fd7fc500d928244045211217",
["ETN_LoadImageBase64", "ETN_LoadMaskBase64", "ETN_SendImageWebSocket", "ETN_Translate"],
),
CustomNode(
Expand All @@ -56,7 +56,7 @@ class CustomNode(NamedTuple):
"GGUF",
"ComfyUI-GGUF",
"https://github.com/city96/ComfyUI-GGUF",
"8e898fad4caab59bf4144e0cf11978b893de7e54",
"4a8432884167f2526d60ef36e985bdabebb9e1e0",
["UnetLoaderGGUF", "DualCLIPLoaderGGUF"],
)
]
Expand Down Expand Up @@ -939,6 +939,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.text_encoder, Arch.all, "t5"): ["t5"],
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.sd3, "default"): ["sd3"],
resource_id(ResourceKind.vae, Arch.flux, "default"): ["ae.s"],
}
# fmt: on
Expand Down

0 comments on commit a1571c2

Please sign in to comment.