# Externe Bibliotheken werden eingeladen
import argparse
import math
import numpy as np
from skimage.io import imread, imsave


### Die Argumente die mittels Konsole übergeben werden, werden in Variablen geschrieben
parser = argparse.ArgumentParser(description='Perform Median Cut Color Quantization on image.')
parser.add_argument('-c', '--colors', type=int, help='Number of colors needed in power of 2, ex: for 16 colors pass 4 because 2^4 = 16')
parser.add_argument('-i', '--input', type=str, help='path of the image to be quantized')
parser.add_argument('-o', '--output', type=str, help='output path for the quantized image')

args = parser.parse_args()

colors = args.colors
output_path = args.output
input_path = args.input
print("reducing the image to {} color palette".format(int(math.pow(2, colors))))



### Das zu verarbeitende Bild wird über die externe Methode "imread" eingelesen.
sample_img = imread(input_path)


### Diese Funktion wird immer zuletzt aufgerufen, um die tatsächlichen Änderungen am Bild zu machen ("sample_img"). 
def median_cut_quantize(img, img_arr):
    r_average = np.mean(img_arr[:, 0])
    g_average = np.mean(img_arr[:, 1])
    b_average = np.mean(img_arr[:, 2])

    for data in img_arr:
        sample_img[data[3]][data[4]] = [r_average, g_average, b_average]



def split_into_buckets(img, img_arr, depth):
    if len(img_arr) == 0:
        return

    if depth == 0:
        median_cut_quantize(img, img_arr)
        return

    r_range = np.max(img_arr[:, 0]) - np.min(img_arr[:, 0])
    g_range = np.max(img_arr[:, 1]) - np.min(img_arr[:, 1])
    b_range = np.max(img_arr[:, 2]) - np.min(img_arr[:, 2])

    space_with_highest_range = 0

    if g_range >= r_range and g_range >= b_range:
        space_with_highest_range = 1
    elif b_range >= r_range and b_range >= g_range:
        space_with_highest_range = 2
    elif r_range >= b_range and r_range >= g_range:
        space_with_highest_range = 0

    # Sortiert das img_arr von niedrigen zu hohen Werten im Bezug auf den zuvor gefundenen Farbraum (R,G oder B) mit der größten Spanne an Werten.
    # space_with_highest_range = {0,1 oder 2}
    # Und berechnet dann den Index des Medians, um das img_arr danach an diesem Index zu trennen.
    img_arr = img_arr[img_arr[:, space_with_highest_range].argsort()]
    median_index = int((len(img_arr) + 1) / 2)

    split_into_buckets(img, img_arr[0:median_index], depth - 1)
    split_into_buckets(img, img_arr[median_index:], depth - 1)




def split_into_buckets_alternative(img, img_arr, depth):
    if depth == 0:
        print("Das Bild wäre einfarbig")
        return

    ### 2 Platzhalter werden definiert. Diese werden zu Arrays und halten den aktuellen Stand der "Buckets"
    oldBuckets = [0]
    newBuckets = [0]
    newBuckets[0] = img_arr

    for currentDepth in range(depth):
        ### Neuberechnung von Kontrollvariablen
        oldBuckets = newBuckets
        currentNumberOfBuckets = 2**(currentDepth+1)

        ### Buckets neu sortieren
        for index, bucket in enumerate(newBuckets):
            r_range = np.max(bucket[:, 0]) - np.min(bucket[:, 0])
            g_range = np.max(bucket[:, 1]) - np.min(bucket[:, 1])
            b_range = np.max(bucket[:, 2]) - np.min(bucket[:, 2])

            space_with_highest_range = 0

            if g_range >= r_range and g_range >= b_range:
                space_with_highest_range = 1
            elif b_range >= r_range and b_range >= g_range:
                space_with_highest_range = 2
            elif r_range >= b_range and r_range >= g_range:
                space_with_highest_range = 0

            # Sortiert das img_arr von niedrigen zu hohen Werten im Bezug auf 
            # den zuvor gefundenen Farbraum (R,G oder B) mit der größten Spanne an Werten.
            # space_with_highest_range = {0,1 oder 2}
            newBuckets[index] = bucket[bucket[:, space_with_highest_range].argsort()]
        
        ### Buckets trennen
        newBuckets = [0] * currentNumberOfBuckets
        for bucketNumber in range(currentNumberOfBuckets):
            #Den Trennindex berechnen
            #Anfang-Mitte und Mitte bis Ende, des jeweiligen Buckets
            if bucketNumber%2 == 0:
                old_index = 0
                new_index = int((len(oldBuckets[int(bucketNumber/2)]) + 1) / 2 )
            else:
                old_index = new_index
                new_index = len(oldBuckets[int(bucketNumber/2)])
            #Den Bucket belegen
            newBuckets[bucketNumber]= oldBuckets[int(bucketNumber/2)][old_index:new_index]

        ### Die Stoppbedingung
        if currentDepth == depth-1:
            for bucket in newBuckets:
                median_cut_quantize(img, bucket)
            return


### Das zu bearbeitende Bild wird hier zu einem 1-Dimensionalen Numpy Array umgeschrieben. Weil es die Bearbeitung leichter macht.
flattened_img_array = []
for rindex, rows in enumerate(sample_img):
    for cindex, color in enumerate(rows):
        flattened_img_array.append([color[0], color[1], color[2], rindex, cindex])

flattened_img_array = np.array(flattened_img_array)


### Hier wird die zuvor implementierte Logik aufgerufen
### Rekursiv
# split_into_buckets(sample_img, flattened_img_array, colors)
### Iterativ
split_into_buckets_alternative(sample_img, flattened_img_array, colors)

### Und das bearbeitete Bild wird dann als neues Bild abgespeichert
imsave(output_path, sample_img)