#include <math.h>
#include <stdlib.h>
#include <string.h>

#include "scale.h"

#define ENABLE_DEBUG_PRINTF 0

#if ENABLE_DEBUG_PRINTF
#include <stdio.h>
#endif

/*
 * Fixed-point configuration for horizontal weights.
 * SUBPIXEL_ACCURACY is the number of fractional bits, SUBPIXEL_FACTOR the
 * matching scale (1 << bits), and ROUND_ADJUST is half of that factor; it is
 * added before a right shift by SUBPIXEL_ACCURACY to round to nearest.
 */
#define SUBPIXEL_ACCURACY  8
#define SUBPIXEL_FACTOR    (1 << SUBPIXEL_ACCURACY)
#define ROUND_ADJUST       (SUBPIXEL_FACTOR / 2)

/*
 * Internal flag bits that scale_image_ex() ORs into s->flags to record the
 * scaling mode it selected for each axis (SF_X_OUTSIDE: the source window
 * does not overlap the image horizontally).
 * SF_PUBLIC_FLAGS covers every bit below SF_Y_DOWNSCALE; scale_image_ex()
 * ANDs s->flags with it to clear the internal bits before recomputing them.
 */
#define SF_X_UPSCALE   (1 << 30)
#define SF_X_DOWNSCALE (1 << 29)
#define SF_X_OUTSIDE   (1 << 28)
#define SF_Y_UPSCALE   (1 << 27)
#define SF_Y_DOWNSCALE (1 << 26)
#define SF_PUBLIC_FLAGS (SF_Y_DOWNSCALE - 1)

/*
 * Debug trace helper. The whole printf argument list is passed as ONE macro
 * argument wrapped in an extra pair of parentheses, e.g.
 *     Dprintf(("value=%d\n", v));
 * so the macro needs no variadic support. When ENABLE_DEBUG_PRINTF is 0 the
 * call expands to an empty do/while statement.
 */
#if ENABLE_DEBUG_PRINTF
#define Dprintf(x) printf x
#else
#define Dprintf(x) do { } while(0)
#endif

/*
 * Write `pixels` copies of the border colour to `buf`.
 *
 * Each pixel occupies s->bpp bytes; the bytes are taken from
 * s->border_color starting with the least significant byte.
 */
static void fill_border(scale_image_struct* s, unsigned char *buf, int pixels) {
    unsigned char *out = buf;

    for (;  pixels;  --pixels) {
        unsigned long remaining = s->border_color;
        int channel;

        /* emit the colour low byte first, one byte per channel */
        for (channel = s->bpp;  channel;  --channel) {
            *out++ = (unsigned char) remaining;
            remaining >>= 8;
        }
    }
}

/*
 * Return a pointer to the prepared line-buffer copy of source row `y`.
 *
 * The returned buffer holds (s->ix1 - s->ix0) pixels packed at s->bpp bytes
 * per pixel (no inter-pixel padding, regardless of s->src_pix_stride).
 * Two buffers (s->lbuf_data[0..1]) act as a tiny cache keyed by s->lbuf_y[];
 * on a miss the buffer whose cached row is farther from `y` is overwritten.
 *
 * Without SF_USE_BORDER_COLOR, `y` is clamped to the image and columns outside
 * the image replicate the nearest copied pixel. With SF_USE_BORDER_COLOR,
 * everything outside the image is filled with the border colour instead.
 */
static unsigned char* get_line(scale_image_struct* s, int y) {
    const int stride = s->bpp;
    int buf_id, extra, x, subcnt;
    unsigned char *buf, *pos;
    const unsigned char *src;

    /* get effective Y: clamp to the image unless rows outside it are border-filled */
    if (!(s->flags & SF_USE_BORDER_COLOR)) {
        if (y < 0) y = 0;
        if (y >= s->src_height) y = s->src_height - 1;
    }

    /* fetch line from cache, or recycle the cache slot whose row is farther from y */
    if (s->lbuf_y[0] == y) return s->lbuf_data[0];
    if (s->lbuf_y[1] == y) return s->lbuf_data[1];
    buf_id = (abs(s->lbuf_y[0] - y) > abs(s->lbuf_y[1] - y)) ? 0 : 1;
    buf = s->lbuf_data[buf_id];
    s->lbuf_y[buf_id] = y;

    /* with a border colour, a row outside the image (or a window that misses
       the image horizontally) consists of border pixels only */
    if ((s->flags & SF_USE_BORDER_COLOR) && ((y < 0) || (y >= s->src_height) || (s->flags & SF_X_OUTSIDE))) {
        fill_border(s, buf, s->ix1 - s->ix0);
        return buf;
    }

    /* copy core area: lbuf_core_size source pixels starting at column
       lbuf_core_x0, stored from pixel index lbuf_core_offset in the buffer */
    src = &(s->src[y * s->src_line_stride + s->lbuf_core_x0 * s->src_pix_stride]);
    pos = &buf[s->lbuf_core_offset * stride];
    if (s->flags & SF_USE_SHUFFLE) {
        /* each output byte is picked from the source pixel via a 4-bit index;
           the indices are consumed from the low nibble of shuffle_pattern upward */
        unsigned long pattern;
        extra = s->src_pix_stride;
        for (x = s->lbuf_core_size;  x;  --x) {
            pattern = s->shuffle_pattern;
            for (subcnt = stride;  subcnt;  --subcnt, pattern >>= 4)
                *pos++ = src[pattern & 15];
            src += extra;
        }
    } else {
        /* extra = bytes to skip after each source pixel (0 if the source is packed) */
        extra = s->src_pix_stride - stride;
        if (extra) {
            for (x = s->lbuf_core_size;  x;  --x) {
                for (subcnt = stride;  subcnt;  --subcnt)
                    *pos++ = *src++;
                src += extra;
            }
        } else
            memcpy(pos, src, s->lbuf_core_size * stride);
    }

    /* pad to the left */
    if (s->lbuf_core_offset) {
        if (s->flags & SF_USE_BORDER_COLOR)
            fill_border(s, buf, s->lbuf_core_offset);
        else {
            /* walk backwards so every byte copies the already-valid byte one
               pixel to its right, replicating the first core pixel */
            for (x = s->lbuf_core_offset * stride - 1;  x >= 0;  --x)
                buf[x] = buf[x + stride];
        }
    }

    /* pad to the right: `extra` becomes the number of pixels after the core area */
    extra = s->lbuf_core_offset + s->lbuf_core_size;
    pos = &buf[extra * stride];
    extra = s->ix1 - s->ix0 - extra;
    if (extra > 0) {
        if (s->flags & SF_USE_BORDER_COLOR)
            fill_border(s, pos, extra);
        else {
            /* walk forwards so every byte copies the byte one pixel to its
               left, replicating the last core pixel */
            extra *= stride;
            for (x = 0;  x < extra;  ++x)
                pos[x] = pos[x - stride];
        }
    }

    return buf;
}

/* Reset the per-byte accumulator line (s->lbuf_size ints) and its weight total. */
static inline void clear_acc(scale_image_struct* s) {
    const size_t acc_bytes = s->lbuf_size * sizeof(int);

    s->acc_sum = 0;
    memset(s->acc, 0, acc_bytes);
}

/*
 * Accumulate one prepared line into s->acc.
 *
 * Every byte of `line` (s->lbuf_size bytes) is multiplied by `weight` and
 * added to the matching accumulator entry; `weight` is also added to
 * s->acc_sum so the total can later be divided out.
 */
static void add_line(scale_image_struct* s, unsigned char* line, int weight) {
    const int count = s->lbuf_size;
    int idx = 0;

    while (idx < count) {
        s->acc[idx] += line[idx] * weight;
        ++idx;
    }
    s->acc_sum += weight;
}

/*
 * Horizontally upscale one byte channel of the accumulator line.
 *
 * `acc` points at the first accumulator entry of the channel and `dest` at
 * the first output byte of the channel. For each of the s->dest_width output
 * pixels, two neighbouring accumulator samples (s->bpp entries apart) are
 * blended linearly by the current fixed-point position, and the vertical
 * weight total s->acc_sum is divided out with rounding.
 */
static void upscale_x(scale_image_struct* s, const signed int* acc, unsigned char* dest) {
    const int channels = s->bpp;
    const int span = s->xunit;
    const int full_weight = (span + ROUND_ADJUST) >> SUBPIXEL_ACCURACY;
    const int divisor = full_weight * s->acc_sum;
    int phase = s->xipos;
    int remaining = s->dest_width;
    signed int lo = acc[0];
    signed int hi = acc[channels];

    while (remaining != 0) {
        int frac;

        /* step to the next source sample pair while the position is past the current one */
        for (;  phase > span;  phase -= span) {
            acc += channels;
            lo = hi;
            hi = acc[channels];
        }

        frac = (phase + ROUND_ADJUST) >> SUBPIXEL_ACCURACY;
        *dest = (unsigned char) (((full_weight - frac) * lo + frac * hi + (divisor >> 1)) / divisor);

        dest += s->dest_pix_stride;
        phase += s->xstep;
        --remaining;
    }
}

/*
 * Horizontally downscale one byte channel of the accumulator line.
 *
 * `acc` points at the first accumulator entry of the channel and `dest` at
 * the first output byte of the channel. Each of the s->dest_width output
 * pixels is the weighted average of the accumulator samples (s->bpp entries
 * apart) covered by the span [position, position + s->xstep); partially
 * covered samples contribute proportionally. The weight total includes
 * s->acc_sum so the vertical weighting is divided out in the same step.
 */
static void downscale_x(scale_image_struct* s, const signed int* acc, unsigned char* dest) {
    const int span = s->xunit;
    int start = s->xipos;
    int remaining;

    for (remaining = s->dest_width;  remaining;  --remaining) {
        signed int value_total = 0;
        signed int weight_total = 0;
        int stop = start + s->xstep;
        int weight;

        /* consume every source sample that the output pixel covers to its end */
        while (stop > span) {
            weight = (span - start + ROUND_ADJUST) >> SUBPIXEL_ACCURACY;
            value_total += *acc * weight;
            weight_total += s->acc_sum * weight;
            start = 0;
            stop -= span;
            acc += s->bpp;
        }

        /* trailing (possibly partial) sample */
        weight = (stop - start + ROUND_ADJUST) >> SUBPIXEL_ACCURACY;
        value_total += *acc * weight;
        weight_total += s->acc_sum * weight;

        *dest = (unsigned char) (value_total / weight_total);
        dest += s->dest_pix_stride;
        start = stop;
    }
}

/*
 * Turn the accumulator line into one destination row, then advance s->dest
 * by one destination line.
 *
 * Depending on the X mode flags the row is upscaled, downscaled, or copied
 * (dividing each accumulator entry by s->acc_sum). Before downscaling, the
 * accumulator and s->acc_sum are shifted right until s->acc_sum no longer
 * exceeds s->xlimit.
 */
static void scale_x(scale_image_struct* s) {
    int channel, idx;

    if (s->flags & SF_X_UPSCALE) {
        /* X upscale: one pass per byte channel */
        for (channel = 0;  channel < s->bpp;  ++channel)
            upscale_x(s, s->acc + channel, s->dest + channel);
    } else if (s->flags & SF_X_DOWNSCALE) {
        /* X downscale: reduce precision first so acc_sum stays within xlimit */
        int shift;
        for (shift = 0;  (s->acc_sum >> shift) > s->xlimit;  ++shift) {
        }
        s->acc_sum >>= shift;
        for (idx = 0;  idx < s->lbuf_size;  ++idx)
            s->acc[idx] >>= shift;
        for (channel = 0;  channel < s->bpp;  ++channel)
            downscale_x(s, s->acc + channel, s->dest + channel);
    } else {
        /* X copy: only the vertical weight has to be divided out */
        const signed int *in = s->acc;
        unsigned char *out = s->dest;
        int remaining = s->dest_width;

        while (remaining) {
            for (channel = 0;  channel < s->bpp;  ++channel)
                out[channel] = (unsigned char) (*in++ / s->acc_sum);
            out += s->dest_pix_stride;
            --remaining;
        }
    }

    s->dest += s->dest_line_stride;
}

/*
 * Low-quality horizontal pass: produce one destination row by picking single
 * samples from a prepared line (no blending), then advance s->dest by one
 * destination line.
 *
 * `buf` is a line as returned by get_line(), i.e. pixels packed at s->bpp
 * bytes each. The sample pointer therefore has to advance by s->bpp per
 * source pixel -- NOT by s->src_pix_stride, which describes the layout of the
 * original source image and may differ from s->bpp.
 */
static void scale_x_fast(scale_image_struct* s, const unsigned char *buf) {
    int x, byte, pos;
    const unsigned char *src;
    unsigned char *dest;
    /* X scale */
    if (s->flags & (SF_X_UPSCALE | SF_X_DOWNSCALE)) {
        for (byte = 0;  byte < s->bpp;  ++byte) {
            src = buf + byte;
            dest = s->dest + byte;
            pos = s->xipos;
            for (x = s->dest_width;  x;  --x) {
                /* skip source pixels until the position falls inside the current one */
                while (pos > s->xunit) {
                    pos -= s->xunit;
                    src += s->bpp;
                }
                *dest = *src;
                pos += s->xstep;
                dest += s->dest_pix_stride;
            }
        }
    }
    /* X copy */
    else {
        src = buf;
        dest = s->dest;
        for (x = s->dest_width;  x;  --x) {
            for (byte = 0;  byte < s->bpp;  ++byte)
                dest[byte] = *src++;
            dest += s->dest_pix_stride;
        }
    }
    s->dest += s->dest_line_stride;
}

/*
 * Scale the source image described by *s into the destination buffer.
 *
 * The source window is taken from (x0,y0)-(x1,y1) when SF_USE_SOURCE_RECT is
 * set, derived from a destination rectangle when SF_USE_DEST_RECT is set, and
 * is the whole source image otherwise. Each axis is independently upscaled,
 * downscaled or copied; SF_LOW_QUALITY selects plain sample picking instead.
 *
 * Side effects on *s: the internal mode bits of s->flags are recomputed, the
 * rectangle fields may be rewritten (normalised / converted to source
 * coordinates), the internal work fields are overwritten, and s->dest is
 * advanced by one destination line per row written.
 *
 * Returns 1 on success, 0 if a parameter is invalid, the rectangle is
 * degenerate, or a work buffer cannot be allocated.
 */
int scale_image_ex(scale_image_struct* s) {
    int temp;

    /* basic sanity checks */
    if (!s || !s->src || !s->dest
    || (s->src_line_stride < 1) || (s->src_width < 1) || (s->src_height < 1)
    || (s->dest_line_stride < 1) || (s->dest_width < 1) || (s->dest_height < 1)
    || (s->bpp < 1) || (s->src_pix_stride < 1) || (s->bpp > s->dest_pix_stride)
    || (!(s->flags & SF_USE_SHUFFLE) && (s->bpp > s->src_pix_stride)))
        return 0;

    /* check width limits */
    temp = s->src_line_stride / s->src_pix_stride;
    if (s->src_width > temp) return 0;
    temp = s->dest_line_stride / s->dest_pix_stride;
    if (s->dest_width > temp) return 0;

    /* swap coordinates so that x0 <= x1 and y0 <= y1 whenever the caller
       supplied a rectangle (either kind) */
    if (s->flags & (SF_USE_SOURCE_RECT | SF_USE_DEST_RECT)) {
        if (s->x0 > s->x1) { float t = s->x0; s->x0 = s->x1; s->x1 = t; }
        if (s->y0 > s->y1) { float t = s->y0; s->y0 = s->y1; s->y1 = t; }
    }

    /* get source rectangle */
    if (s->flags & SF_USE_DEST_RECT) {
        float f;
        Dprintf(("dest rect: %.2f, %.2f -> %.2f, %.2f\n", s->x0, s->y0, s->x1, s->y1));
        /* an empty (or NaN) destination rectangle would divide by zero below */
        if (!(s->x1 > s->x0) || !(s->y1 > s->y0))
            return 0;
        f = ((float) (s->dest_width * s->src_width)) / (s->x1 - s->x0);
        s->x0 = -(s->x0 * f) / s->dest_width;
        s->x1 = s->x0 + f;
        f = ((float) (s->dest_height * s->src_height)) / (s->y1 - s->y0);
        s->y0 = -(s->y0 * f) / s->dest_height;
        s->y1 = s->y0 + f;
    } else if (!(s->flags & SF_USE_SOURCE_RECT)) {
        s->x0 = s->y0 = 0.0f;
        s->x1 = s->src_width;
        s->y1 = s->src_height;
    }

    /* compute the size of the source window */
    s->fwidth = s->x1 - s->x0;
    s->fheight = s->y1 - s->y0;
    if ((s->fwidth < (1.0 / 256)) || (s->fheight < (1.0 / 256)))
        return 0;

    /* determine integer source rectangle */
    s->ix0 = (int) floorf(s->x0);
    s->iy0 = (int) floorf(s->y0);
    s->ix1 = (int) ceilf(s->x1);
    s->iy1 = (int) ceilf(s->y1);

    /* determine scale mode flags (a same-size window that starts at a
       fractional position is handled by the downscale path) */
    s->flags &= SF_PUBLIC_FLAGS;
    if (s->dest_width > s->fwidth)      s->flags |= SF_X_UPSCALE;
    else if (s->dest_width < s->fwidth) s->flags |= SF_X_DOWNSCALE;
    else if (s->ix0 != s->x0)           s->flags |= SF_X_DOWNSCALE;
    if (s->dest_height > s->fheight)      s->flags |= SF_Y_UPSCALE;
    else if (s->dest_height < s->fheight) s->flags |= SF_Y_DOWNSCALE;
    else if (s->iy0 != s->y0)             s->flags |= SF_Y_DOWNSCALE;

    /* expand source rectangle for upscaling */
    if (s->flags & SF_X_UPSCALE) { s->ix0--; s->ix1++; }
    if (s->flags & SF_Y_UPSCALE) { s->iy0--; s->iy1++; }
    if ((s->ix1 <= 0) || (s->ix0 >= s->src_width)) s->flags |= SF_X_OUTSIDE;

    /* allocate line buffers: two cache lines in one block plus the accumulator */
    s->lbuf_core_size = s->ix1 - s->ix0;
    s->lbuf_size = s->lbuf_core_size * s->bpp;
    s->lbuf_data[0] = malloc(2 * s->lbuf_size);
    if (!s->lbuf_data[0]) return 0;
    s->lbuf_data[1] = s->lbuf_data[0] + s->lbuf_size;
    s->lbuf_y[0] = s->lbuf_y[1] = 0x40000000;  /* no row cached yet */
    s->acc = malloc(s->lbuf_size * sizeof(int));
    if (!s->acc) { free(s->lbuf_data[0]); return 0; }

    /* compute line buffer core area (the part backed by real source pixels) */
    s->lbuf_core_offset = 0;
    s->lbuf_core_x0 = s->ix0;
    if (s->lbuf_core_x0 < 0) {
        s->lbuf_core_offset -= s->lbuf_core_x0;
        s->lbuf_core_size += s->lbuf_core_x0;
        s->lbuf_core_x0 = 0;
    }
    temp = s->lbuf_core_x0 + s->lbuf_core_size - s->src_width;
    if (temp > 0) s->lbuf_core_size -= temp;
    if (!(s->flags & SF_USE_BORDER_COLOR) && (s->lbuf_core_size <= 0)) {
        /* The window misses the image completely and edge pixels are to be
           replicated: copy a single edge pixel to the start of the buffer and
           let the right-hand padding replicate it. The offset must be reset,
           because for a window left of the image it is at least the buffer
           width and get_line() would write past the end of the buffer. */
        s->lbuf_core_size = 1;
        s->lbuf_core_offset = 0;
        if (s->lbuf_core_x0) s->lbuf_core_x0 = s->src_width - 1;
    }

    /* compute scaling properties */
    s->xunit = s->dest_width * SUBPIXEL_FACTOR;
    s->xstep = (int) floorf(s->fwidth * SUBPIXEL_FACTOR + 0.5f);
    s->xipos = (int) floorf((s->x0 - s->ix0) * s->xunit + 0.5f);
    if (s->flags & SF_X_UPSCALE)
        s->xipos += (s->xstep - s->xunit + 1) >> 1;
    s->xlimit = 0x007FFFFF / ((s->xstep + SUBPIXEL_FACTOR) >> SUBPIXEL_ACCURACY);
    s->yunit = s->dest_height * SUBPIXEL_FACTOR;
    s->ystep = (int) floorf(s->fheight * SUBPIXEL_FACTOR + 0.5f);
    s->yipos = (int) floorf((s->y0 - s->iy0) * s->yunit + 0.5f);
    if (s->flags & SF_Y_UPSCALE)
        s->yipos += (s->ystep - s->yunit + 1) >> 1;
    /* ybits: smallest exponent >= 1 with 2^ybits >= max(dest_height, src_height);
       vertical positions are shifted right by it to obtain line weights */
    for (temp = 1;  (1 << temp) < s->dest_height;  ++temp);
    for (;  (1 << temp) < s->src_height;  ++temp);
    s->ybits = temp;

#if ENABLE_DEBUG_PRINTF
    Dprintf(("source: %dx%d, pix_stride=%d, line_stride=%d\n", s->src_width, s->src_height, s->src_pix_stride, s->src_line_stride));
    Dprintf(("dest: %dx%d, pix_stride=%d, line_stride=%d\n", s->dest_width, s->dest_height, s->dest_pix_stride, s->dest_line_stride));
    Dprintf(("bpp=%d, border_color=0x%08X\n", s->bpp, (int) s->border_color));
    Dprintf(("frect: %.2f, %.2f -> %.2f, %.2f; fsize: %.2f x %.2f\n", s->x0, s->y0, s->x1, s->y1, s->fwidth, s->fheight));
    Dprintf(("irect: %d, %d -> %d, %d\n", s->ix0, s->iy0, s->ix1, s->iy1));
    Dprintf(("flags:"));
    if (s->flags & SF_USE_BORDER_COLOR) Dprintf((" use_border"));
    if (s->flags & SF_X_UPSCALE)    Dprintf((" X=upscale"));
    if (s->flags & SF_X_DOWNSCALE)  Dprintf((" X=downscale"));
    if (s->flags & SF_X_OUTSIDE)    Dprintf((" X-out"));
    if (s->flags & SF_Y_UPSCALE)    Dprintf((" Y=upscale"));
    if (s->flags & SF_Y_DOWNSCALE)  Dprintf((" Y=downscale"));
    Dprintf(("\n"));
    Dprintf(("lbuf: size=%d, core_size=%d, core_offset=%d, core_x0=%d\n", s->lbuf_size, s->lbuf_core_size, s->lbuf_core_offset, s->lbuf_core_x0));
    Dprintf(("Xscale: unit=%d, step=%d, ipos=%d\n", s->xunit, s->xstep, s->xipos));
    Dprintf(("Yscale: unit=%d, step=%d, ipos=%d, ybits=%d\n", s->yunit, s->ystep, s->yipos, s->ybits));
#endif

    /* Y low-quality scaling: pick one source line per destination line */
    if (s->flags & SF_LOW_QUALITY) {
        int y, sy = s->iy0, pos = s->yipos;
        for (y = 0;  y < s->dest_height;  ++y) {
            while (pos > s->yunit) {
                ++sy;
                pos -= s->yunit;
            }
            scale_x_fast(s, get_line(s, sy));
            pos += s->ystep;
        }
    }
    /* Y upscale: blend two neighbouring source lines per destination line */
    else if (s->flags & SF_Y_UPSCALE) {
        const int wmax = (s->yunit + ((1 << s->ybits) >> 1)) >> s->ybits;
        int y, w, sy = s->iy0, fpos = s->yipos;
        for (y = s->dest_height;  y;  --y) {
            while (fpos > s->yunit) {
                ++sy;
                fpos -= s->yunit;
            }
            clear_acc(s);
            w = (fpos + ((1 << s->ybits) >> 1)) >> s->ybits;
            add_line(s, get_line(s, sy+0), wmax - w);
            add_line(s, get_line(s, sy+1), w);
            scale_x(s);
            fpos += s->ystep;
        }
    }
    /* Y downscale: accumulate every source line covered by the destination line */
    else if (s->flags & SF_Y_DOWNSCALE) {
        int fpos = s->yipos, fend;
        int y, sy = s->iy0;
        for (y = s->dest_height;  y;  --y) {
            clear_acc(s);
            fend = fpos + s->ystep;
            while (fend > s->yunit) {
                add_line(s, get_line(s, sy++), (s->yunit - fpos + ((1 << s->ybits) >> 1)) >> s->ybits);
                fpos = 0;
                fend -= s->yunit;
            }
            add_line(s, get_line(s, sy), (fend - fpos + ((1 << s->ybits) >> 1)) >> s->ybits);
            scale_x(s);
            fpos = fend;
        }
    }
    /* Y copy */
    else {
        int y;
        for (y = 0;  y < s->dest_height;  ++y) {
            unsigned char *buf = get_line(s, y + s->iy0);
            if (s->flags & (SF_X_UPSCALE | SF_X_DOWNSCALE)) {
                clear_acc(s);
                add_line(s, buf, 1);
                scale_x(s);
            } else if (s->bpp != s->dest_pix_stride) {
                /* plain copy into a destination with padding between pixels */
                const int extra = s->dest_pix_stride - s->bpp;
                int x, subcnt;
                unsigned char *dest = s->dest;
                for (x = s->dest_width;  x;  --x) {
                    for (subcnt = s->bpp;  subcnt;  --subcnt)
                        *dest++ = *buf++;
                    dest += extra;
                }
                s->dest += s->dest_line_stride;
            } else {
                memcpy(s->dest, buf, s->dest_width * s->bpp);
                s->dest += s->dest_line_stride;
            }
        }
    }

    /* done. (free() accepts NULL, so no guards are needed) */
    free(s->lbuf_data[0]);
    free(s->acc);
    return 1;
}

/*
 * Set up a destination rectangle in *s from a zoom description.
 *
 * src_cx/src_cy select the anchor point inside the scaled image and
 * dest_cx/dest_cy the point of the destination it is placed on; the
 * destination coordinates are multiplied by the destination size, the source
 * ones by the scaled image size. src_ar/dest_ar are pixel aspect ratios; a
 * negative value is converted from its magnitude using the image's
 * height/width. The image is first sized to the destination height; it is
 * re-sized to the destination width when it is wider than the destination
 * (zoom >= 0) or when it is not wider (zoom < 0). The result is then
 * multiplied by |zoom|.
 *
 * Writes s->x0, s->y0, s->x1, s->y1, sets SF_USE_DEST_RECT and clears
 * SF_USE_SOURCE_RECT. Does nothing if s is NULL or any image dimension
 * is below 1.
 */
void scale_set_zoom(scale_image_struct* s,
    float src_cx, float src_cy, float src_ar,
    float dest_cx, float dest_cy, float dest_ar,
    float zoom
) {
    float fit_w, fit_h;
    float anchor_x, anchor_y;
    int too_wide, invert_fit;

    /* sanity check */
    if (!s)
        return;
    if ((s->src_width < 1) || (s->src_height < 1) || (s->dest_width < 1) || (s->dest_height < 1))
        return;
    Dprintf(("set zoom: %.2f,%.2f -> %.2f,%.2f; zoom = %.2f\n", src_cx, src_cy, dest_cx, dest_cy, zoom));

    /* negative aspect ratios are converted using the image dimensions,
       then both ratios are merged into a single value */
    if (src_ar < 0.0f)
        src_ar = (-src_ar * s->src_height) / s->src_width;
    if (dest_ar < 0.0f)
        dest_ar = (-dest_ar * s->dest_height) / s->dest_width;
    Dprintf(("source PAR: %.5f; dest PAR: %.5f; ", src_ar, dest_ar));
    src_ar /= dest_ar;
    Dprintf(("combined PAR: %.5f\n", src_ar));

    /* start from a fit to the destination height */
    fit_h = s->dest_height;
    fit_w = s->src_width * src_ar * fit_h / s->src_height;
    Dprintf(("full screen size: %.2f x %.2f\n", fit_w, fit_h));

    /* switch to a fit to the destination width when exactly one of
       "too wide" and "negative zoom" holds */
    too_wide = (fit_w > s->dest_width);
    invert_fit = (zoom < 0.0f);
    if (too_wide != invert_fit) {
        fit_w = s->dest_width;
        fit_h = s->src_height * fit_w / (s->src_width * src_ar);
    }
    Dprintf(("full screen size: %.2f x %.2f\n", fit_w, fit_h));

    /* apply zoom (only its magnitude scales the image) */
    zoom = fabsf(zoom);
    fit_w *= zoom;
    fit_h *= zoom;
    Dprintf(("zoomed size: %.2f x %.2f\n", fit_w, fit_h));

    /* absolute anchor position inside the destination */
    anchor_x = dest_cx * s->dest_width;
    anchor_y = dest_cy * s->dest_height;

    /* set destination rectangle */
    s->x0 = anchor_x -         src_cx  * fit_w;
    s->x1 = anchor_x + (1.0f - src_cx) * fit_w;
    s->y0 = anchor_y -         src_cy  * fit_h;
    s->y1 = anchor_y + (1.0f - src_cy) * fit_h;
    s->flags = (s->flags & ~SF_USE_SOURCE_RECT) | SF_USE_DEST_RECT;
}
