summaryrefslogtreecommitdiff
path: root/lib/threads/catalog.py
blob: 886d9aa42e0fb06acb650596a6be682f3bad0b4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Authors: see git history
#
# Copyright (c) 2010 Authors
# Licensed under the GNU GPL version 3.0 or later.  See the file LICENSE for details.

import os
from collections.abc import Sequence
from glob import glob

from ..utils import get_bundled_dir, guess_inkscape_config_path
from .palette import ThreadPalette


class _ThreadCatalog(Sequence):
    """Holds a set of ThreadPalettes."""

    def __init__(self):
        self.palettes = []
        self.load_palettes(self.get_palettes_paths())

    def get_palettes_paths(self):
        """Creates a list containing the path of two directories:
        1. Palette directory of Inkscape
        2. Palette directory of inkstitch
        """
        path = [os.path.join(guess_inkscape_config_path(), 'palettes')]
        inkstitch_path = get_bundled_dir('palettes')
        path.append(inkstitch_path)

        return path

    def load_palettes(self, paths):
        palettes = []
        for path in paths:
            for palette_file in glob(os.path.join(path, 'InkStitch*.gpl')):
                palette_basename = os.path.basename(palette_file)
                if palette_basename not in palettes:
                    palette = ThreadPalette(palette_file)
                    if not palette.is_gimp_palette:
                        continue
                    self.palettes.append(ThreadPalette(palette_file))
                    palettes.append(palette_basename)

    def palette_names(self):
        return list(sorted(palette.name for palette in self))

    def __getitem__(self, item):
        return self.palettes[item]

    def __len__(self):
        return len(self.palettes)

    def _num_exact_color_matches(self, palette, threads):
        """Number of colors in stitch plan with an exact match in this palette."""

        return sum(1 for thread in threads if thread in palette)

    def match_and_apply_palette(self, stitch_plan, palette=None):
        if palette is None:
            palette = self.match_palette(stitch_plan)
        else:
            palette = self.get_palette_by_name(palette)

        if palette is not None:
            self.apply_palette(stitch_plan, palette)

        return palette

    def match_palette(self, stitch_plan):
        """Figure out which color palette was used

        Scans the catalog of color palettes and chooses one that seems most
        likely to be the one that the user used.  A palette will only be
        chosen if more than 80% of the thread colors in the stitch plan are
        exact matches for threads in the palette.
        """
        if not self.palettes:
            return None

        threads = [color_block.color for color_block in stitch_plan]
        palettes_and_matches = [(palette, self._num_exact_color_matches(palette, threads))
                                for palette in self]
        palette, matches = max(palettes_and_matches, key=lambda item: item[1])

        if matches < 0.8 * len(stitch_plan):
            # if less than 80% of the colors are an exact match,
            # don't use this palette
            return None
        else:
            return palette

    def apply_palette(self, stitch_plan, palette):
        for color_block in stitch_plan:
            if color_block.color.chart:
                # do not overwrite cutwork settings
                continue

        for color_block in stitch_plan:
            nearest = palette.nearest_color(color_block.color)

            color_block.color.name = nearest.name
            color_block.color.number = nearest.number
            color_block.color.manufacturer = nearest.manufacturer
            color_block.color.description = nearest.description

    def get_palette_by_name(self, name):
        for palette in self:
            if palette.name == name:
                return palette


# Cached catalog instance; stays None until ThreadCatalog() creates it on first call.
_catalog = None


def ThreadCatalog():
    """Return the shared _ThreadCatalog, constructing it on the first call.

    Later calls return the same cached instance stored in the module-level
    ``_catalog`` variable, so palettes are only loaded once.
    """
    global _catalog

    if _catalog is not None:
        return _catalog

    _catalog = _ThreadCatalog()
    return _catalog