Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions Header/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,39 @@ extern "C" {
enum rng_type_t {
STD_DEFAULT_RNG,
CUDA_RNG,
CPU_RNG,
RNG_TYPE_COUNT
};

enum sample_method_t {
SAMPLE_METHOD_DEFAULT,
EULER,
HEUN,
DPM2,
DPMPP2S_A,
DPMPP2M,
DPMPP2Mv2,
IPNDM,
IPNDM_V,
LCM,
DDIM_TRAILING,
TCD,
EULER_A,
EULER_SAMPLE_METHOD,
EULER_A_SAMPLE_METHOD,
HEUN_SAMPLE_METHOD,
DPM2_SAMPLE_METHOD,
DPMPP2S_A_SAMPLE_METHOD,
DPMPP2M_SAMPLE_METHOD,
DPMPP2Mv2_SAMPLE_METHOD,
IPNDM_SAMPLE_METHOD,
IPNDM_V_SAMPLE_METHOD,
LCM_SAMPLE_METHOD,
DDIM_TRAILING_SAMPLE_METHOD,
TCD_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};

enum scheduler_t {
DEFAULT,
DISCRETE,
KARRAS,
EXPONENTIAL,
AYS,
GITS,
SGM_UNIFORM,
SIMPLE,
SMOOTHSTEP,
SCHEDULE_COUNT

DISCRETE_SCHEDULER,
KARRAS_SCHEDULER,
EXPONENTIAL_SCHEDULER,
AYS_SCHEDULER,
GITS_SCHEDULER,
SGM_UNIFORM_SCHEDULER,
SIMPLE_SCHEDULER,
SMOOTHSTEP_SCHEDULER,
LCM_SCHEDULER,
SCHEDULER_COUNT
};

enum prediction_t {
Expand Down Expand Up @@ -166,11 +168,13 @@ typedef struct {
const char* lora_model_dir;
const char* embedding_dir;
const char* photo_maker_path;
const char* tensor_type_rules;
bool vae_decode_only;
bool free_params_immediately;
int n_threads;
enum sd_type_t wtype;
enum rng_type_t rng_type;
enum rng_type_t sampler_rng_type;
enum prediction_t prediction;
enum lora_apply_mode_t lora_apply_mode;
bool offload_params_to_cpu;
Expand Down Expand Up @@ -226,6 +230,13 @@ typedef struct {
float style_strength;
} sd_pm_params_t; // photo maker

typedef struct {
bool enabled;
float reuse_threshold;
float start_percent;
float end_percent;
} sd_easycache_params_t;

typedef struct {
const char* prompt;
const char* negative_prompt;
Expand All @@ -246,6 +257,7 @@ typedef struct {
float control_strength;
sd_pm_params_t pm_params;
sd_tiling_params_t vae_tiling_params;
sd_easycache_params_t easycache;
} sd_img_gen_params_t;

typedef struct {
Expand All @@ -265,6 +277,7 @@ typedef struct {
int64_t seed;
int video_frames;
float vace_strength;
sd_easycache_params_t easycache;
} sd_vid_gen_params_t;

typedef struct sd_ctx_t sd_ctx_t;
Expand All @@ -285,25 +298,30 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str);
SD_API const char* sd_scheduler_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_scheduler(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str);
SD_API const char* sd_preview_name(enum preview_t preview);
SD_API enum preview_t str_to_preview(const char* str);
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);

SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params);

SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);

SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);

SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);

SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
Expand Down Expand Up @@ -342,4 +360,4 @@ SD_API bool preprocess_canny(sd_image_t image,
}
#endif

#endif // __STABLE_DIFFUSION_H__
#endif // __STABLE_DIFFUSION_H__
3 changes: 2 additions & 1 deletion StableDiffusion.NET/Enums/RngType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
public enum RngType
{
Standard,
Cuda
Cuda,
Cpu
}
4 changes: 2 additions & 2 deletions StableDiffusion.NET/Enums/Sampler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

public enum Sampler
{
Default,
Euler,
Euler_A,
Heun,
DPM2,
DPMPP2SA,
Expand All @@ -14,5 +14,5 @@ public enum Sampler
LCM,
DDIM_Trailing,
TCD,
Euler_A,
Default
}
5 changes: 3 additions & 2 deletions StableDiffusion.NET/Enums/Schedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

public enum Scheduler
{
Default,
Discrete,
Karras,
Exponential,
AYS,
GITS,
SGM_Uniform,
Simple,
Smoothstep
Smoothstep,
LCM,
Default
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ public sealed class DiffusionModelParameter
/// </summary>
public RngType RngType { get; set; } = RngType.Standard;

public RngType SamplerRngType { get; set; } = RngType.Standard;

public Prediction Prediction { get; set; } = Prediction.Default;

public LoraApplyMode LoraApplyMode { get; set; } = LoraApplyMode.Auto;
Expand All @@ -110,6 +112,8 @@ public sealed class DiffusionModelParameter
/// </summary>
public string StackedIdEmbeddingsDirectory { get; set; } = string.Empty;

public string TensorTypeRules { get; set; } = string.Empty;

/// <summary>
/// path to full model
/// </summary>
Expand Down
11 changes: 11 additions & 0 deletions StableDiffusion.NET/Models/Parameter/EasyCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace StableDiffusion.NET;

public sealed class EasyCache
{
public bool IsEnabled { get; set; }
public float ReuseThreshold { get; set; }
public float StartPercent { get; set; }
public float EndPercent { get; set; }

internal EasyCache() { }
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public sealed class ImageGenerationParameter

public TilingParameter VaeTiling { get; } = new();

public EasyCache EasyCache { get; } = new();

#endregion

public static ImageGenerationParameter Create() => new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public sealed class VideoGenerationParameter

public float VaceStrength { get; set; }

public EasyCache EasyCache { get; } = new();

#endregion

public static VideoGenerationParameter Create() => new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ public static Native.Types.sd_ctx_params_t ConvertToUnmanaged(DiffusionModelPara
lora_model_dir = AnsiStringMarshaller.ConvertToUnmanaged(managed.LoraModelDirectory),
embedding_dir = AnsiStringMarshaller.ConvertToUnmanaged(managed.EmbeddingsDirectory),
photo_maker_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.StackedIdEmbeddingsDirectory),
tensor_type_rules = AnsiStringMarshaller.ConvertToUnmanaged(managed.TensorTypeRules),
vae_decode_only = (sbyte)(managed.VaeDecodeOnly ? 1 : 0),
free_params_immediately = (sbyte)(managed.FreeParamsImmediately ? 1 : 0),
n_threads = managed.ThreadCount,
wtype = managed.Quantization,
rng_type = managed.RngType,
sampler_rng_type = managed.SamplerRngType,
prediction = managed.Prediction,
lora_apply_mode = managed.LoraApplyMode,
offload_params_to_cpu = (sbyte)(managed.OffloadParamsToCPU ? 1 : 0),
Expand Down Expand Up @@ -64,11 +66,13 @@ public static DiffusionModelParameter ConvertToManaged(Native.Types.sd_ctx_param
LoraModelDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.lora_model_dir) ?? string.Empty,
EmbeddingsDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.embedding_dir) ?? string.Empty,
StackedIdEmbeddingsDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.photo_maker_path) ?? string.Empty,
TensorTypeRules = AnsiStringMarshaller.ConvertToManaged(unmanaged.tensor_type_rules) ?? string.Empty,
VaeDecodeOnly = unmanaged.vae_decode_only == 1,
FreeParamsImmediately = unmanaged.free_params_immediately == 1,
ThreadCount = unmanaged.n_threads,
Quantization = unmanaged.wtype,
RngType = unmanaged.rng_type,
SamplerRngType = unmanaged.sampler_rng_type,
Prediction = unmanaged.prediction,
LoraApplyMode = unmanaged.lora_apply_mode,
OffloadParamsToCPU = unmanaged.offload_params_to_cpu == 1,
Expand Down Expand Up @@ -101,5 +105,6 @@ public static void Free(Native.Types.sd_ctx_params_t unmanaged)
AnsiStringMarshaller.Free(unmanaged.lora_model_dir);
AnsiStringMarshaller.Free(unmanaged.embedding_dir);
AnsiStringMarshaller.Free(unmanaged.photo_maker_path);
AnsiStringMarshaller.Free(unmanaged.tensor_type_rules);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ public static unsafe ImageGenerationParameter ConvertToManaged(Native.Types.sd_i
TargetOverlap = unmanaged.vae_tiling_params.target_overlap,
RelSizeX = unmanaged.vae_tiling_params.rel_size_x,
RelSizeY = unmanaged.vae_tiling_params.rel_size_y
},
EasyCache =
{
IsEnabled = unmanaged.easycache.enabled == 1,
ReuseThreshold = unmanaged.easycache.reuse_threshold,
StartPercent = unmanaged.easycache.start_percent,
EndPercent = unmanaged.easycache.end_percent
}
};

Expand Down Expand Up @@ -126,6 +133,14 @@ public void FromManaged(ImageGenerationParameter managed)
rel_size_y = managed.VaeTiling.RelSizeY
};

Native.Types.sd_easycache_params_t easyCache = new()
{
enabled = (sbyte)(managed.EasyCache.IsEnabled ? 1 : 0),
reuse_threshold = managed.EasyCache.ReuseThreshold,
start_percent = managed.EasyCache.StartPercent,
end_percent = managed.EasyCache.EndPercent,
};

_imgGenParams = new Native.Types.sd_img_gen_params_t
{
prompt = AnsiStringMarshaller.ConvertToUnmanaged(managed.Prompt),
Expand All @@ -145,7 +160,8 @@ public void FromManaged(ImageGenerationParameter managed)
control_image = _controlNetImage,
control_strength = managed.ControlNet.Strength,
pm_params = photoMakerParams,
vae_tiling_params = tilingParams
vae_tiling_params = tilingParams,
easycache = easyCache
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ public static unsafe VideoGenerationParameter ConvertToManaged(Native.Types.sd_v
Strength = unmanaged.strength,
Seed = unmanaged.seed,
FrameCount = unmanaged.video_frames,
VaceStrength = unmanaged.vace_strength
VaceStrength = unmanaged.vace_strength,
EasyCache =
{
IsEnabled = unmanaged.easycache.enabled == 1,
ReuseThreshold = unmanaged.easycache.reuse_threshold,
StartPercent = unmanaged.easycache.start_percent,
EndPercent = unmanaged.easycache.end_percent
}
};

return parameter;
Expand Down Expand Up @@ -68,6 +75,14 @@ public void FromManaged(VideoGenerationParameter managed)
_initImage = managed.InitImage?.ToSdImage() ?? new Native.Types.sd_image_t();
_endImage = managed.EndImage?.ToSdImage() ?? new Native.Types.sd_image_t();
_controlFrames = managed.ControlFrames == null ? null : managed.ControlFrames.ToSdImage();

Native.Types.sd_easycache_params_t easyCache = new()
{
enabled = (sbyte)(managed.EasyCache.IsEnabled ? 1 : 0),
reuse_threshold = managed.EasyCache.ReuseThreshold,
start_percent = managed.EasyCache.StartPercent,
end_percent = managed.EasyCache.EndPercent,
};

_vidGenParams = new Native.Types.sd_vid_gen_params_t
{
Expand All @@ -87,6 +102,7 @@ public void FromManaged(VideoGenerationParameter managed)
seed = managed.Seed,
video_frames = managed.FrameCount,
vace_strength = managed.VaceStrength,
easycache = easyCache,
};
}

Expand Down
Loading