#!/usr/bin/env python

from __future__ import print_function

"""
pipcolours - extract all colours present in source image.
Produces a PNG that has each colour exactly once.

Also performs grey and alpha reduction when possible:
if all colour pixels are grey, output PNG is grey;
if all pixels are opaque, output PNG has no alpha channel.
"""

import argparse
import collections
import itertools
import sys

import png


def png_colours(out, inp):
    """
    Read the PNG from `inp` and write to `out` a PNG built from
    the rows and metadata returned by `unique_colours`.
    """
    reader = png.Reader(file=inp)
    # asDirect() returns (width, height, rows, info);
    # the source dimensions are not needed here.
    _width, _height, pixel_rows, meta = reader.asDirect()

    unique_rows, unique_meta = unique_colours(pixel_rows, meta)

    png.Writer(**unique_meta).write(out, unique_rows)


def jasc_counts(out, inp):
    """
    Print a JASC-PAL style listing of the colours of the PNG `inp`
    to the text stream `out`.

    Each colour line holds the channel values followed by
    ``# <count>``, where count is the number of pixels of that colour.
    """
    reader = png.Reader(file=inp)
    _width, _height, pixel_rows, meta = reader.asDirect()

    # Magic line, then what is probably a version field
    # (available descriptions of the format are unclear on this).
    for header in ("JASC-PAL", "0100"):
        print(header, file=out)

    tally = count_colours(pixel_rows, meta)

    print(len(tally), file=out)
    for colour in tally:
        fields = list(colour) + ["#", tally[colour]]
        print(*fields, file=out)


def jasc_colours(out, inp):
    """
    Print the colours of the PNG `inp` to the text stream `out`
    in JASC-PAL form: two header lines, the number of entries,
    then one line of channel values per entry.

    When the PNG carries a palette, the palette is listed as-is;
    otherwise the distinct colours found by `unique_colours` are listed.
    """
    reader = png.Reader(file=inp)
    _width, _height, pixel_rows, meta = reader.read()

    print("JASC-PAL", file=out)
    # Probably a version field (descriptions of the format are unclear).
    print("0100", file=out)

    if "palette" in meta:
        entries = meta['palette']
        total = len(entries)
    else:
        unique_rows, reduced = unique_colours(pixel_rows, meta)
        # unique_colours produces exactly one row of flat channel values.
        (flat,) = unique_rows
        channels = reduced['planes']
        total = len(flat) // channels
        entries = png.group(flat, channels)

    print(total, file=out)
    for entry in entries:
        print(*entry, file=out)


def count_colours(rows, info):
    """
    Count how many pixels of each colour occur in `rows`.

    `rows` is an iterable of flat rows of channel values;
    `info["planes"]` gives the number of channel values per pixel.
    Only the "planes" key of `info` is consulted.

    Returns a `collections.Counter` mapping each pixel
    (as produced by `png.group`) to its number of occurrences,
    in order of first appearance.
    """
    planes = info["planes"]

    count = collections.Counter()
    for row in rows:
        # Counter.update tallies every pixel of the row in one call.
        count.update(png.group(row, planes))

    return count


def unique_colours(rows, info):
    """
    Collect the distinct colours in `rows` and lay them out
    as a single-row image.

    `rows` is an iterable of flat rows of channel values;
    `info` must provide "planes" (channel values per pixel)
    and "bitdepth".

    The colours are passed through `channel_reduce`
    (which may drop redundant channels) and then sorted.

    Returns a pair ``([row], info)`` where `row` is the flat list of
    channel values of the sorted colours, and `info` is a new dict
    with the keys width, height, bitdepth, greyscale, alpha, planes
    describing that one-row image.
    """
    planes = info["planes"]
    bitdepth = info["bitdepth"]

    col = set()
    for row in rows:
        # Update in place; `col = col.union(...)` would copy the
        # entire accumulated set once per row.
        col.update(png.group(row, planes))

    col, planes = channel_reduce(col, planes, bitdepth)

    col = sorted(col)

    info = dict(
        width=len(col),
        height=1,
        bitdepth=bitdepth,
        greyscale=planes in (1, 2),
        alpha=planes in (2, 4),
        planes=planes,
    )
    # Flatten the sorted pixel tuples into one row of channel values.
    row = list(itertools.chain.from_iterable(col))
    return [row], info


def channel_reduce(col, planes, bitdepth):
    """
    Try to shrink the channel count of a set of colours:
    first the grey reduction, then the alpha reduction.
    Returns the (possibly new) set and the resulting number of planes.
    """
    grey_col, grey_planes = reduce_grey(col, planes)
    final_col, final_planes = reduce_alpha(grey_col, grey_planes, bitdepth)
    return final_col, final_planes


def reduce_grey(col, planes):
    """
    Convert a set of colour pixels to grey pixels when, in every
    pixel, the first three channel values are equal.

    Returns `(col, planes)`: unchanged if there are fewer than
    3 planes or some pixel is not grey; otherwise a new set of
    1- or 2-tuples and the plane count lowered by 2.
    """
    if planes < 3:
        # Already grey (with or without alpha); nothing to do.
        return col, planes

    for colour in col:
        if not (colour[0] == colour[1] == colour[2]):
            # A genuinely coloured pixel: keep all channels.
            return col, planes

    # Keep channel 0 (the grey level) and, when present,
    # channel 3 (alpha).
    greys = {colour[0::3] for colour in col}
    return greys, planes - 2


def reduce_alpha(col, planes, bitdepth):
    """
    Drop the alpha channel from a set of pixels when every pixel
    is fully opaque, that is, its last channel value equals
    the maximum value for `bitdepth`.

    Returns `(col, planes)`: unchanged if there is no alpha channel
    (planes is not 2 or 4) or some pixel is not opaque; otherwise
    a new set of shortened tuples and the plane count lowered by 1.
    """
    opaque = 2 ** bitdepth - 1

    if planes not in (2, 4):
        # No alpha channel present.
        return col, planes

    for colour in col:
        if colour[-1] != opaque:
            # At least one pixel needs its alpha value.
            return col, planes

    stripped = {colour[:-1] for colour in col}
    return stripped, planes - 1


def main(argv=None):
    """
    Command-line entry point.

    `argv` is the full argument list including the program name
    as its first element; when None, `sys.argv` is used.

    Options:
      --count  print a JASC-PAL style listing with per-colour
               pixel counts to standard output;
      --jasc   print a JASC-PAL listing of the colours
               to standard output;
      (neither) write a PNG of the unique colours
               to binary standard output.

    The optional positional argument names the input PNG;
    it defaults to "-", which is handed to `png.cli_open`
    (expected to mean standard input).
    """
    if argv is None:
        argv = sys.argv

    parser = argparse.ArgumentParser()
    version = "%(prog)s " + png.__version__
    parser.add_argument("--version", action="version", version=version)
    parser.add_argument("--jasc", action="store_true")
    parser.add_argument("--count", action="store_true")
    # nargs="?" makes the positional optional; without it argparse
    # insists on the argument and the "-" default can never apply.
    parser.add_argument("input", nargs="?", default="-", metavar="PNG")

    # Skip the program name.
    args = parser.parse_args(argv[1:])
    # Local name chosen so as not to shadow the builtin `input`.
    inp = png.cli_open(args.input)

    if args.count:
        return jasc_counts(sys.stdout, inp)
    if args.jasc:
        return jasc_colours(sys.stdout, inp)

    out = png.binary_stdout()
    return png_colours(out, inp)


# Run the command-line interface only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
