Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2811,7 +2811,8 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
{request->width / request->vae_scale_factor,
request->height / request->vae_scale_factor,
1,
1});
1},
sd::ops::InterpolateMode::NearestMax);

sd::Tensor<float> init_latent;
sd::Tensor<float> control_latent;
Expand Down Expand Up @@ -2956,9 +2957,13 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
latents.ref_latents = std::move(ref_latents);

if (sd_version_is_inpaint(sd_ctx->sd->version)) {
latents.denoise_mask = std::move(latent_mask);
}

latent_mask = sd::ops::maxPool2D(latent_mask,
{3, 3},
{1, 1},
{1, 1});
}
latents.denoise_mask = std::move(latent_mask);

return latents;
}

Expand Down
172 changes: 163 additions & 9 deletions src/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ namespace sd {
return data_.at(static_cast<size_t>(index));
}

// Public wrapper around offset_of(): maps a full multi-dimensional
// coordinate to the element's linear (flat) index in row-major order.
// Throws (via offset_of) if coord's rank does not match the tensor's.
int64_t get_flat_index(const std::vector<int64_t>& coord) const {
    const size_t flat_offset = offset_of(coord);
    return static_cast<int64_t>(flat_offset);
}

private:
size_t offset_of(const std::vector<int64_t>& coord) const {
if (coord.size() != shape_.size()) {
Expand Down Expand Up @@ -815,6 +819,9 @@ namespace sd {
namespace ops {
// Resampling strategies accepted by ops::interpolate().
// NOTE(review): the integer values are reported in error messages via
// static_cast<int>(mode), so enumerator order should not be changed.
enum class InterpolateMode {
Nearest,     // classic nearest-neighbour: copy the single closest source element
NearestMax,  // like Nearest when upsampling; takes the max over the source window when downsampling
NearestMin,  // like Nearest when upsampling; takes the min over the source window when downsampling
NearestAvg,  // like Nearest when upsampling; averages the source window when downsampling
};

inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) {
Expand Down Expand Up @@ -1012,12 +1019,16 @@ namespace sd {
std::vector<int64_t> output_shape,
InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) {
if (mode != InterpolateMode::Nearest) {
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
bool is_nearest_like_mode = (mode == InterpolateMode::Nearest ||
mode == InterpolateMode::NearestMax ||
mode == InterpolateMode::NearestMin ||
mode == InterpolateMode::NearestAvg);
if (!is_nearest_like_mode) {
tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" +
std::to_string(static_cast<int>(mode)));
}
if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" +
tensor_shape_to_string(input.shape()) + ", output_shape=" +
tensor_shape_to_string(output_shape));
}
Expand All @@ -1044,14 +1055,82 @@ namespace sd {
}
}

bool pure_upsampling = true;
for(int64_t i=0; i<input.dim(); ++i) {
if (input.shape()[i] > output_shape[i]) pure_upsampling = false;
}

Tensor<T> output(std::move(output_shape));
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
if (!pure_upsampling && (mode != InterpolateMode::Nearest)) {
// Pooling modes only differ from nearest mode when downsampling
for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());

std::vector<int64_t> input_start(output.dim(), 0);
std::vector<int64_t> input_end(output.dim(), 0);

for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
int64_t I_dim = input.shape()[i];
int64_t O_dim = output.shape()[i];

if (I_dim > 0 && O_dim > 0) {
input_start[i] = std::max(int64_t(0), static_cast<int64_t>(output_coord[i] * I_dim / O_dim));
input_end[i] = std::min(I_dim, ((output_coord[i] + 1) * I_dim + O_dim - 1) / O_dim);
} else {
input_start[i] = 0;
input_end[i] = 1;
}
}

T val;
if (mode == InterpolateMode::NearestMax) {
val = std::numeric_limits<T>::lowest();
} else if(mode == InterpolateMode::NearestMin) {
val = std::numeric_limits<T>::max();
} else if(mode == InterpolateMode::NearestAvg) {
val = T(0);
}

bool done_window = false;
std::vector<int64_t> current_in_coord = input_start;

while (!done_window) {
if (mode == InterpolateMode::NearestMax) {
val = std::max(val, input.index(current_in_coord));
} else if(mode == InterpolateMode::NearestMin) {
val = std::min(val, input.index(current_in_coord));
} else if(mode == InterpolateMode::NearestAvg) {
val += input.index(current_in_coord);
}

for (int d = static_cast<int>(output.dim()) - 1; d >= 0; --d) {
if (++current_in_coord[d] < input_end[d]) {
break;
}
current_in_coord[d] = input_start[d];
if (d == 0) {
done_window = true;
}
}
}
if (mode == InterpolateMode::NearestAvg) {
int64_t window_size = 1;
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
window_size *= (input_end[i] - input_start[i]);
}
val /= static_cast<T>(window_size);
}
output[flat_out] = val;
}
} else {
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
}
output[flat] = input.index(input_coord);
}
output[flat] = input.index(input_coord);
}

return output;
Expand Down Expand Up @@ -1128,6 +1207,81 @@ namespace sd {
align_corners);
}

// 2D max pooling over a 4D tensor laid out {height, width, channels, batch}
// (dim 0 is height, dim 3 is batch — this file's Tensor convention; TODO
// confirm against the Tensor class docs, which are outside this view).
//
// Semantics follow torch.nn.MaxPool2d: the input is conceptually padded
// with -infinity, so padded positions can never win the max.  As in
// PyTorch, padding must be at most half the kernel size; that invariant
// guarantees every pooling window overlaps at least one real input
// element, so no window is ever "all padding".
//
// @param input        4D tensor, shape {H, W, C, N}
// @param kernel_size  {kh, kw}; both entries must be positive
// @param stride       {sh, sw}; both entries must be positive
// @param padding      {ph, pw}; non-negative and <= kernel_size/2
// @return             pooled tensor, shape {outH, outW, C, N} where
//                     out = (in + 2*pad - kernel) / stride + 1
// @throws             invalid_argument (via tensor_throw_invalid_argument)
//                     on any malformed argument or non-positive output size
template <typename T>
inline Tensor<T> maxPool2D(const Tensor<T>& input,
                           std::vector<int64_t> kernel_size,
                           std::vector<int64_t> stride,
                           std::vector<int64_t> padding) {
    if (input.dim() != 4) {
        tensor_throw_invalid_argument("Tensor maxPool2D requires 4D input: input_dim=" +
                                      std::to_string(input.dim()) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) {
        tensor_throw_invalid_argument("Tensor maxPool2D requires kernel_size, stride, and padding to have length 2");
    }
    for (size_t i = 0; i < 2; ++i) {
        if (kernel_size[i] <= 0) {
            tensor_throw_invalid_argument("Tensor maxPool2D kernel_size must be positive: kernel_size=" +
                                          tensor_shape_to_string(kernel_size));
        }
        if (stride[i] <= 0) {
            tensor_throw_invalid_argument("Tensor maxPool2D stride must be positive: stride=" +
                                          tensor_shape_to_string(stride));
        }
        if (padding[i] < 0) {
            tensor_throw_invalid_argument("Tensor maxPool2D padding must be non-negative: padding=" +
                                          tensor_shape_to_string(padding));
        }
        // PyTorch's MaxPool2d contract: pad <= kernel/2.  Without this,
        // a window could fall entirely inside the padding and the result
        // for that position would be undefined (there is no real input
        // to take the max of).
        if (padding[i] * 2 > kernel_size[i]) {
            tensor_throw_invalid_argument("Tensor maxPool2D padding must be at most half of kernel_size: padding=" +
                                          tensor_shape_to_string(padding) + ", kernel_size=" +
                                          tensor_shape_to_string(kernel_size));
        }
    }

    const int64_t in_height = input.shape()[0];
    const int64_t in_width = input.shape()[1];
    const int64_t in_channels = input.shape()[2];
    const int64_t batch_size = input.shape()[3];

    // Standard pooling output-size formula (floor division).
    const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1;
    const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1;

    if (out_height <= 0 || out_width <= 0) {
        tensor_throw_invalid_argument("maxPool2D results in invalid output dimensions: " +
                                      std::to_string(out_height) + "x" + std::to_string(out_width));
    }

    Tensor<T> output({out_height, out_width, in_channels, batch_size});

    for (int64_t oh = 0; oh < out_height; ++oh) {
        for (int64_t ow = 0; ow < out_width; ++ow) {
            for (int64_t c = 0; c < in_channels; ++c) {
                for (int64_t b = 0; b < batch_size; ++b) {
                    // Padding behaves as -inf: start from lowest() and only
                    // consider positions that land inside the real input.
                    T max_val = std::numeric_limits<T>::lowest();
                    bool has_valid_input = false;

                    for (int64_t kh = 0; kh < kernel_size[0]; ++kh) {
                        for (int64_t kw = 0; kw < kernel_size[1]; ++kw) {
                            int64_t ih = oh * stride[0] + kh - padding[0];
                            int64_t iw = ow * stride[1] + kw - padding[1];

                            if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
                                T val = input.index(ih, iw, c, b);
                                max_val = std::max(max_val, val);
                                has_valid_input = true;
                            }
                        }
                    }

                    if (has_valid_input) {
                        output.index(oh, ow, c, b) = max_val;
                    } else {
                        // Unreachable given the padding <= kernel/2 check
                        // above; kept as a defensive fallback.
                        output.index(oh, ow, c, b) = T(0);
                    }
                }
            }
        }
    }
    return output;
}

template <typename T>
inline Tensor<T> concat(const Tensor<T>& lhs, const Tensor<T>& rhs, size_t dim) {
if (lhs.dim() != rhs.dim()) {
Expand Down
Loading