summaryrefslogtreecommitdiff
path: root/lib/threads/palette.py
blob: f1ff6cb4ed92947e1533ee9f40179cbe48037de3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Authors: see git history
#
# Copyright (c) 2010 Authors
# Licensed under the GNU GPL version 3.0 or later.  See the file LICENSE for details.

from collections.abc import Set

from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie1994
from colormath.color_objects import LabColor, sRGBColor

from .color import ThreadColor


def compare_thread_colors(color1, color2):
    """Return the CIE 1994 delta-E difference between two colors.

    Both arguments are passed straight to ``delta_e_cie1994``; the smaller
    the result, the more alike the two colors are considered to be.
    """

    # K_L=2 indicates textiles
    textile_lightness_weight = 2

    return delta_e_cie1994(color1, color2, K_L=textile_lightness_weight)


class ThreadPalette(Set):
    """Holds a set of ThreadColors all from the same manufacturer.

    The palette is loaded from a GIMP palette file.  ``self.threads`` maps each
    ThreadColor to its precomputed LabColor, so nearest_color() only has to
    convert the color being looked up.  ``self.name`` holds the palette
    (manufacturer) name read from the file.
    """

    # Prefix (compared case-insensitively) that is stripped from the palette's
    # name line to obtain the manufacturer name.
    _NAME_PREFIX = 'name: ink/stitch: '

    def __init__(self, palette_file):
        """Load the palette stored in the GIMP palette file ``palette_file``."""
        self.threads = dict()
        self.parse_palette_file(palette_file)

    @classmethod
    def _from_iterable(cls, iterable):
        """Build the result of the inherited set operators (&, |, -, ^).

        The Set mixin creates its results with ``cls(iterable)`` by default,
        but this class's constructor expects a palette file path, so the
        results are returned as a plain ``set`` of threads instead.
        """
        return set(iterable)

    def parse_palette_file(self, palette_file):
        """Read a GIMP palette file and load thread colors.

        Example file:

        GIMP Palette
        Name: Ink/Stitch: Metro
        Columns: 4
        # RGB Value                                 Color Name Number
        240     186     212                         Sugar Pink   1624
        237     171     194                           Carnatio   1636

        The first four lines are treated as the header (magic line, name,
        column count, column titles).  Every following line that cannot be
        parsed as ``R G B name number`` is skipped.

        Raises ValueError if the first line is not the GIMP palette magic line.
        """

        # Read as UTF-8 rather than the platform default so that non-ASCII
        # thread names load the same way on every system.
        with open(palette_file, encoding="utf-8") as palette:
            line = palette.readline().strip()
            if line.lower() != "gimp palette":
                raise ValueError("Invalid gimp palette header")

            self.name = palette.readline().strip()
            if self.name.lower().startswith(self._NAME_PREFIX):
                self.name = self.name[len(self._NAME_PREFIX):]

            # number of columns
            palette.readline()

            # headers
            palette.readline()

            for line in palette:
                try:
                    fields = line.split(None, 3)
                    thread_color = [int(field) for field in fields[:3]]
                    thread_name, thread_number = fields[3].strip().rsplit(" ", 1)
                    thread_name = thread_name.strip()

                    thread = ThreadColor(thread_color, thread_name, thread_number, manufacturer=self.name)
                    self.threads[thread] = convert_color(sRGBColor(*thread_color, is_upscaled=True), LabColor)
                except (ValueError, IndexError):
                    # ValueError: non-numeric RGB field (e.g. a comment line)
                    # or a name without a separate number.
                    # IndexError: blank line or fewer than four fields.
                    continue

    def __contains__(self, thread):
        return thread in self.threads

    def __iter__(self):
        return iter(self.threads)

    def __len__(self):
        return len(self.threads)

    def nearest_color(self, color):
        """Find the thread in this palette that looks the most like the specified color.

        ``color`` is either a ThreadColor (its ``rgb`` attribute is used) or a
        sequence of RGB components in the upscaled form accepted by
        ``sRGBColor(..., is_upscaled=True)``.

        Raises ValueError if the palette contains no threads.
        """

        if isinstance(color, ThreadColor):
            color = color.rgb

        color = convert_color(sRGBColor(*color, is_upscaled=True), LabColor)

        def difference(thread):
            return compare_thread_colors(self.threads[thread], color)

        return min(self, key=difference)