@@ -601,21 +601,31 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
601601typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
602602
603603// Tiling
604- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
604+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
605605 output = ggml_set_f32 (output, 0 );
606606
607607 int input_width = (int )input->ne [0 ];
608608 int input_height = (int )input->ne [1 ];
609609 int output_width = (int )output->ne [0 ];
610610 int output_height = (int )output->ne [1 ];
611+
612+ int input_tile_size, output_tile_size;
613+ if (scaled_out) {
614+ input_tile_size = tile_size;
615+ output_tile_size = tile_size * scale;
616+ } else {
617+ input_tile_size = tile_size * scale;
618+ output_tile_size = tile_size;
619+ }
620+
611621 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
612622
613- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
614- int non_tile_overlap = tile_size - tile_overlap;
623+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
624+ int non_tile_overlap = input_tile_size - tile_overlap;
615625
616626 struct ggml_init_params params = {};
617- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
618- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
627+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
628+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
619629 params.mem_size += 3 * ggml_tensor_overhead ();
620630 params.mem_buffer = NULL ;
621631 params.no_alloc = false ;
@@ -630,8 +640,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
630640 }
631641
632642 // tiling
633- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size , input->ne [2 ], 1 );
634- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale , output->ne [2 ], 1 );
643+ ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size , input->ne [2 ], 1 );
644+ ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size , output->ne [2 ], 1 );
635645 on_processing (input_tile, NULL , true );
636646 int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
637647 LOG_INFO (" processing %i tiles" , num_tiles);
@@ -640,19 +650,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
640650 bool last_y = false , last_x = false ;
641651 float last_time = 0 .0f ;
642652 for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
643- if (y + tile_size >= input_height) {
644- y = input_height - tile_size ;
653+ if (y + input_tile_size >= input_height) {
654+ y = input_height - input_tile_size ;
645655 last_y = true ;
646656 }
647657 for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
648- if (x + tile_size >= input_width) {
649- x = input_width - tile_size ;
658+ if (x + input_tile_size >= input_width) {
659+ x = input_width - input_tile_size ;
650660 last_x = true ;
651661 }
652662 int64_t t1 = ggml_time_ms ();
653663 ggml_split_tensor_2d (input, input_tile, x, y);
654664 on_processing (input_tile, output_tile, false );
655- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
665+ if (scaled_out) {
666+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
667+ } else {
668+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
669+ }
656670 int64_t t2 = ggml_time_ms ();
657671 last_time = (t2 - t1) / 1000 .0f ;
658672 pretty_progress (tile_count, num_tiles, last_time);
0 commit comments