Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions Header/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ enum sd_log_level_t {
SD_LOG_ERROR
};

// Preview mode selector; this is the type of the `mode` argument of
// sd_set_preview_callback(). Values are spelled out explicitly so that
// bindings mirroring this enum can be checked against it at a glance.
enum preview_t {
    PREVIEW_NONE  = 0,
    PREVIEW_PROJ  = 1,
    PREVIEW_TAE   = 2,
    PREVIEW_VAE   = 3,
    PREVIEW_COUNT = 4  // equals the number of enumerators declared above it
};

// LoRA apply mode selector; stored in the `lora_apply_mode` field of the
// context parameters. Values are spelled out explicitly so that bindings
// mirroring this enum can be checked against it at a glance.
enum lora_apply_mode_t {
    LORA_APPLY_AUTO        = 0,
    LORA_APPLY_IMMEDIATELY = 1,
    LORA_APPLY_AT_RUNTIME  = 2,
    LORA_APPLY_MODE_COUNT  = 3,  // equals the number of enumerators declared above it
};

typedef struct {
bool enabled;
int tile_size_x;
Expand Down Expand Up @@ -157,11 +172,13 @@ typedef struct {
enum sd_type_t wtype;
enum rng_type_t rng_type;
enum prediction_t prediction;
enum lora_apply_mode_t lora_apply_mode;
bool offload_params_to_cpu;
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;
bool vae_conv_direct;
bool force_sdxl_vae_conv_scale;
Expand Down Expand Up @@ -254,9 +271,11 @@ typedef struct sd_ctx_t sd_ctx_t;

typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
// Preview callback signature. Unlike sd_log_cb_t and sd_progress_cb_t it has
// no `void* data` user pointer, so a callback cannot receive caller state
// through its arguments.
// NOTE(review): `frame_count` is presumably the number of elements in
// `frames` -- confirm against the implementation.
typedef void (*sd_preview_cb_t)(int step, int frame_count, sd_image_t* frames, bool is_noisy);

SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
// Registers the preview callback together with the preview mode and options.
// NOTE(review): the units of `interval` (presumably sampling steps) and the
// exact meaning of `denoised` / `noisy` are not visible in this header --
// confirm against the implementation before relying on them.
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, enum preview_t mode, int interval, bool denoised, bool noisy);
SD_API int32_t get_num_physical_cores();
SD_API const char* sd_get_system_info();

Expand All @@ -270,6 +289,10 @@ SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str);
// Conversions between the preview / LoRA-apply-mode enums and C strings.
// The *_name functions return `const char*`; callers must not modify the
// returned string.
// NOTE(review): ownership/lifetime of the returned string and the result of
// str_to_* for an unrecognised input are not visible in this header -- confirm.
SD_API const char* sd_preview_name(enum preview_t preview);
SD_API enum preview_t str_to_preview(const char* str);
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);

SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
Expand Down
8 changes: 8 additions & 0 deletions StableDiffusion.NET/Enums/LoraApplyMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace StableDiffusion.NET;

/// <summary>
/// Managed counterpart of the native <c>lora_apply_mode_t</c> enum.
/// The numeric values follow the native declaration order because the value
/// is handed to the native library unchanged.
/// </summary>
public enum LoraApplyMode
{
    /// <summary>Corresponds to native <c>LORA_APPLY_AUTO</c>.</summary>
    Auto = 0,

    /// <summary>Corresponds to native <c>LORA_APPLY_IMMEDIATELY</c>.</summary>
    Immediately = 1,

    /// <summary>Corresponds to native <c>LORA_APPLY_AT_RUNTIME</c>.</summary>
    AtRuntime = 2
}
9 changes: 9 additions & 0 deletions StableDiffusion.NET/Enums/Preview.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace StableDiffusion.NET;

/// <summary>
/// Managed counterpart of the native <c>preview_t</c> enum.
/// The numeric values follow the native declaration order because the value
/// is handed to the native library unchanged.
/// </summary>
public enum Preview
{
    /// <summary>Corresponds to native <c>PREVIEW_NONE</c>.</summary>
    None = 0,

    /// <summary>Corresponds to native <c>PREVIEW_PROJ</c>.</summary>
    Proj = 1,

    /// <summary>Corresponds to native <c>PREVIEW_TAE</c>.</summary>
    TAE = 2,

    /// <summary>Corresponds to native <c>PREVIEW_VAE</c>.</summary>
    VAE = 3
}
15 changes: 15 additions & 0 deletions StableDiffusion.NET/EventArgs/StableDiffusionPreviewEventArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System;
using HPPH;

namespace StableDiffusion.NET;

/// <summary>
/// Event data carrying a preview image together with the step and the noisy flag it was reported with.
/// </summary>
public sealed class StableDiffusionPreviewEventArgs : EventArgs
{
    #region Properties & Fields

    /// <summary>
    /// The step value the preview was reported for.
    /// </summary>
    public int Step { get; }

    /// <summary>
    /// The noisy flag the preview was reported with.
    /// </summary>
    public bool IsNoisy { get; }

    /// <summary>
    /// The preview image.
    /// </summary>
    public Image<ColorRGB> Image { get; }

    #endregion

    #region Constructors

    /// <summary>
    /// Initializes a new instance of the <see cref="StableDiffusionPreviewEventArgs"/> class.
    /// </summary>
    /// <param name="step">The step value the preview was reported for.</param>
    /// <param name="isNoisy">The noisy flag the preview was reported with.</param>
    /// <param name="image">The preview image.</param>
    public StableDiffusionPreviewEventArgs(int step, bool isNoisy, Image<ColorRGB> image)
    {
        Step = step;
        IsNoisy = isNoisy;
        Image = image;
    }

    #endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public sealed class DiffusionModelParameter
/// </summary>
public bool FlashAttention { get; set; } = false;

/// <summary>
/// forwarded to the native <c>tae_preview_only</c> context parameter
/// NOTE(review): presumably limits the TAE model to preview decoding only - confirm against the native library.
/// </summary>
public bool TaePreviewOnly { get; set; } = false;

/// <summary>
/// use Conv2d direct in the diffusion model
/// This might crash if it is not supported by the backend.
Expand All @@ -91,6 +93,8 @@ public sealed class DiffusionModelParameter

public Prediction Prediction { get; set; } = Prediction.Default;

/// <summary>
/// forwarded to the native <c>lora_apply_mode</c> context parameter
/// NOTE(review): the precise effect of each mode is defined by the native library - confirm there.
/// </summary>
public LoraApplyMode LoraApplyMode { get; set; } = LoraApplyMode.Auto;

/// <summary>
/// quantizes on load
/// not really useful in most cases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ public static Native.Types.sd_ctx_params_t ConvertToUnmanaged(DiffusionModelPara
wtype = managed.Quantization,
rng_type = managed.RngType,
prediction = managed.Prediction,
lora_apply_mode = managed.LoraApplyMode,
offload_params_to_cpu = (sbyte)(managed.OffloadParamsToCPU ? 1 : 0),
keep_clip_on_cpu = (sbyte)(managed.KeepClipOnCPU ? 1 : 0),
keep_control_net_on_cpu = (sbyte)(managed.KeepControlNetOnCPU ? 1 : 0),
keep_vae_on_cpu = (sbyte)(managed.KeepVaeOnCPU ? 1 : 0),
diffusion_flash_attn = (sbyte)(managed.FlashAttention ? 1 : 0),
tae_preview_only = (sbyte)(managed.TaePreviewOnly ? 1 : 0),
diffusion_conv_direct = (sbyte)(managed.DiffusionConvDirect ? 1 : 0),
vae_conv_direct = (sbyte)(managed.VaeConvDirect ? 1 : 0),
force_sdxl_vae_conv_scale = (sbyte)(managed.ForceSdxlVaeConvScale ? 1 : 0),
Expand Down Expand Up @@ -68,11 +70,13 @@ public static DiffusionModelParameter ConvertToManaged(Native.Types.sd_ctx_param
Quantization = unmanaged.wtype,
RngType = unmanaged.rng_type,
Prediction = unmanaged.prediction,
LoraApplyMode = unmanaged.lora_apply_mode,
OffloadParamsToCPU = unmanaged.offload_params_to_cpu == 1,
KeepClipOnCPU = unmanaged.keep_clip_on_cpu == 1,
KeepControlNetOnCPU = unmanaged.keep_control_net_on_cpu == 1,
KeepVaeOnCPU = unmanaged.keep_vae_on_cpu == 1,
FlashAttention = unmanaged.diffusion_flash_attn == 1,
TaePreviewOnly = unmanaged.tae_preview_only == 1,
DiffusionConvDirect = unmanaged.diffusion_conv_direct == 1,
VaeConvDirect = unmanaged.vae_conv_direct == 1,
ForceSdxlVaeConvScale = unmanaged.force_sdxl_vae_conv_scale == 1,
Expand Down
22 changes: 22 additions & 0 deletions StableDiffusion.NET/Native/Native.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace StableDiffusion.NET;
using sd_log_level_t = LogLevel;
using sd_type_t = Quantization;
using sd_vid_gen_params_t = VideoGenerationParameter;
using lora_apply_mode_t = LoraApplyMode;
using preview_t = Preview;
using size_t = nuint;
using uint32_t = uint;
using uint8_t = byte;
Expand Down Expand Up @@ -73,11 +75,13 @@ internal struct sd_ctx_params_t
public sd_type_t wtype;
public rng_type_t rng_type;
public prediction_t prediction;
public lora_apply_mode_t lora_apply_mode;
public sbyte offload_params_to_cpu;
public sbyte keep_clip_on_cpu;
public sbyte keep_control_net_on_cpu;
public sbyte keep_vae_on_cpu;
public sbyte diffusion_flash_attn;
public sbyte tae_preview_only;
public sbyte diffusion_conv_direct;
public sbyte vae_conv_direct;
public sbyte force_sdxl_vae_conv_scale;
Expand Down Expand Up @@ -188,6 +192,7 @@ internal struct upscaler_ctx_t;

internal delegate void sd_log_cb_t(sd_log_level_t level, [MarshalAs(UnmanagedType.LPStr)] string text, void* data);
internal delegate void sd_progress_cb_t(int step, int steps, float time, void* data);
/// <summary>
/// Managed signature of the native <c>sd_preview_cb_t</c> callback.
/// </summary>
/// <remarks>
/// The native declaration uses a C <c>bool</c> (one byte) for <c>is_noisy</c>. A delegate <see cref="bool"/>
/// parameter without an explicit <see cref="MarshalAsAttribute"/> is marshalled as a 4-byte Win32 BOOL, which
/// would read three bytes the native caller never initialised. <see cref="UnmanagedType.I1"/> makes the
/// callback read exactly the one byte that is passed, matching the <c>bool</c> marshalling used on
/// <c>sd_set_preview_callback</c>.
/// </remarks>
internal delegate void sd_preview_cb_t(int step, int frame_count, sd_image_t* frames, [MarshalAs(UnmanagedType.I1)] bool is_noisy);

#endregion

Expand All @@ -199,6 +204,9 @@ internal struct upscaler_ctx_t;
[LibraryImport(LIB_NAME, EntryPoint = "sd_set_progress_callback")]
internal static partial void sd_set_progress_callback(sd_progress_cb_t cb, void* data);

/// <summary>
/// Binding for the native <c>sd_set_preview_callback</c> export.
/// </summary>
/// <remarks>
/// <paramref name="cb"/> is declared nullable, so <c>null</c> can be passed.
/// Both <see cref="bool"/> parameters are marshalled as single bytes (<see cref="UnmanagedType.I1"/>).
/// NOTE(review): the delegate passed as <paramref name="cb"/> must be kept reachable by the caller for as long
/// as the native side may invoke it - verify at every call site.
/// </remarks>
[LibraryImport(LIB_NAME, EntryPoint = "sd_set_preview_callback")]
internal static partial void sd_set_preview_callback(sd_preview_cb_t? cb, preview_t mode, int interval, [MarshalAs(UnmanagedType.I1)] bool denoised, [MarshalAs(UnmanagedType.I1)] bool noisy);

[LibraryImport(LIB_NAME, EntryPoint = "get_num_physical_cores")]
internal static partial int32_t get_num_physical_cores();

Expand Down Expand Up @@ -243,6 +251,20 @@ internal struct upscaler_ctx_t;
[LibraryImport(LIB_NAME, EntryPoint = "str_to_prediction")]
internal static partial prediction_t str_to_prediction([MarshalAs(UnmanagedType.LPStr)] string str);

// Bindings for the native preview / LoRA-apply-mode <-> string conversion exports.
// NOTE(review): the two *_name exports return a native string that is marshalled to a managed
// string here. Confirm that the generated return-value marshalling does not try to free the
// returned pointer; if the native side returns a pointer it still owns (e.g. a string literal),
// the return type should become a raw pointer that is converted manually.
[LibraryImport(LIB_NAME, EntryPoint = "sd_preview_name")]
[return: MarshalAs(UnmanagedType.LPStr)]
internal static partial string sd_preview_name(preview_t preview);

[LibraryImport(LIB_NAME, EntryPoint = "str_to_preview")]
internal static partial preview_t str_to_preview([MarshalAs(UnmanagedType.LPStr)] string str);

[LibraryImport(LIB_NAME, EntryPoint = "sd_lora_apply_mode_name")]
[return: MarshalAs(UnmanagedType.LPStr)]
internal static partial string sd_lora_apply_mode_name(lora_apply_mode_t mode);

[LibraryImport(LIB_NAME, EntryPoint = "str_to_lora_apply_mode")]
internal static partial lora_apply_mode_t str_to_lora_apply_mode([MarshalAs(UnmanagedType.LPStr)] string str);

//

[LibraryImport(LIB_NAME, EntryPoint = "sd_ctx_params_init")]
Expand Down
32 changes: 30 additions & 2 deletions StableDiffusion.NET/StableDiffusionCpp.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System;
using HPPH;
using HPPH;
using JetBrains.Annotations;
using System;

namespace StableDiffusion.NET;

Expand All @@ -12,6 +12,7 @@ public static unsafe class StableDiffusionCpp
// ReSharper disable NotAccessedField.Local - They are important, the delegate can be collected if it's not stored!
private static Native.sd_log_cb_t? _logCallback;
private static Native.sd_progress_cb_t? _progressCallback;
private static Native.sd_preview_cb_t? _previewCallback;
// ReSharper restore NotAccessedField.Local

#endregion
Expand All @@ -20,6 +21,7 @@ public static unsafe class StableDiffusionCpp

public static event EventHandler<StableDiffusionLogEventArgs>? Log;
public static event EventHandler<StableDiffusionProgressEventArgs>? Progress;
/// <summary>
/// Raised with a preview image when the native preview callback fires; the sender argument is <c>null</c>.
/// Previews are only delivered after the callback has been registered through <c>EnablePreview</c>.
/// </summary>
public static event EventHandler<StableDiffusionPreviewEventArgs>? Preview;

#endregion

Expand All @@ -33,6 +35,19 @@ public static void InitializeEvents()
Native.sd_set_progress_callback(_progressCallback = OnNativeProgress, null);
}

/// <summary>
/// Registers (or, for <c>Preview.None</c>, clears) the managed preview callback and forwards the
/// preview settings to the native library.
/// </summary>
/// <param name="mode">The preview mode passed to the native library.</param>
/// <param name="interval">The preview interval passed to the native library; must not be negative.</param>
/// <param name="denoised">Passed through to the native library unchanged.</param>
/// <param name="noisy">Passed through to the native library unchanged.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="interval"/> is negative.</exception>
public static void EnablePreview(Preview mode, int interval, bool denoised, bool noisy)
{
    ArgumentOutOfRangeException.ThrowIfNegative(interval);

    if (mode != NET.Preview.None)
    {
        // Reuse an already stored delegate; the static field keeps it alive while native code holds it.
        _previewCallback ??= OnPreview;
    }
    else
    {
        _previewCallback = null;
    }

    Native.sd_set_preview_callback(_previewCallback, mode, interval, denoised, noisy);
}

public static void Convert(string modelPath, string vaePath, Quantization quantization, string outputPath, string tensorTypeRules = "")
{
ArgumentException.ThrowIfNullOrWhiteSpace(nameof(modelPath));
Expand Down Expand Up @@ -89,5 +104,18 @@ private static void OnNativeProgress(int step, int steps, float time, void* data
catch { /**/ }
}

/// <summary>
/// Handler registered as the native preview callback. Converts the frame at index 0 to an image and
/// raises <see cref="Preview"/> with it; calls that carry no frames are ignored.
/// </summary>
/// <param name="step">The step value reported by the native side.</param>
/// <param name="frameCount">The frame count reported by the native side.</param>
/// <param name="frames">Pointer to the reported frames; may be <c>null</c>.</param>
/// <param name="isNoisy">The noisy flag reported by the native side.</param>
private static void OnPreview(int step, int frameCount, Native.Types.sd_image_t* frames, bool isNoisy)
{
    try
    {
        if (frameCount > 0 && frames != null)
        {
            // Only the frame at index 0 is surfaced, regardless of how many frames were reported.
            Image<ColorRGB> firstFrame = ImageHelper.GetImage(frames, 0);

            EventHandler<StableDiffusionPreviewEventArgs>? subscribers = Preview;
            subscribers?.Invoke(null, new StableDiffusionPreviewEventArgs(step, isNoisy, firstFrame));
        }
    }
    catch
    {
        // Any exception raised by the image conversion or by a subscriber is swallowed here.
    }
}

#endregion
}