summaryrefslogtreecommitdiff
path: root/lib/api/server.py
blob: 26efa521e317ac11664b3f192f1bba4750be1130 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Authors: see git history
#
# Copyright (c) 2010 Authors
# Licensed under the GNU GPL version 3.0 or later.  See the file LICENSE for details.

import errno
import logging
import socket
import sys
import time
from threading import Thread
from contextlib import closing

import requests
from flask import Flask, g
from werkzeug.serving import make_server

from ..utils.json import InkStitchJSONProvider
from .simulator import simulator
from .stitch_plan import stitch_plan
from .preferences import preferences
from .page_specs import page_specs
from .lang import languages
# flask_cors is needed for requests made by axios from the electron front end
from flask_cors import CORS


class APIServer(Thread):
    """Thread that hosts the Flask API app on a local, OS-assigned port.

    Call ``start_server()`` to launch the server and block until it answers
    ``/ping``; call ``stop()`` to shut it down again.
    """

    def __init__(self, *args, **kwargs):
        """Create the server thread and build the Flask app.

        The first positional argument is the extension object that is exposed
        to request handlers as ``g.extension``.  Any remaining positional and
        keyword arguments are forwarded to ``Thread.__init__``.
        """
        self.extension = args[0]
        Thread.__init__(self, *args[1:], **kwargs)
        self.daemon = True
        self.app = None
        self.host = None
        self.port = None
        # Becomes True once ready_checker() has received a successful /ping.
        self.ready = False

        self.__setup_app()
        self.flask_server = None
        self.server_thread = None

    def __setup_app(self):  # noqa: C901
        """Build the Flask app: CORS, JSON provider, blueprints and /ping."""
        # Disable warning about using a development server in a production environment
        cli = sys.modules['flask.cli']
        cli.show_server_banner = lambda *x: None

        self.app = Flask(__name__)
        CORS(self.app)
        self.app.json = InkStitchJSONProvider(self.app)

        self.app.register_blueprint(simulator, url_prefix="/simulator")
        self.app.register_blueprint(stitch_plan, url_prefix="/stitch_plan")
        self.app.register_blueprint(preferences, url_prefix="/preferences")
        self.app.register_blueprint(page_specs, url_prefix="/page_specs")
        self.app.register_blueprint(languages, url_prefix="/languages")

        @self.app.before_request
        def store_extension():
            # make the InkstitchExtension object available to the view handling
            # this request
            g.extension = self.extension

        @self.app.route('/ping')
        def ping():
            return "pong"

    def stop(self):
        """Shut the server down and wait for its serving thread to finish.

        Safe to call when the server was never started: in that case there is
        nothing to shut down and the call is a no-op.
        """
        if self.flask_server is not None:
            self.flask_server.shutdown()
        if self.server_thread is not None:
            self.server_thread.join()

    def disable_logging(self):
        """Silence werkzeug's per-request log output (errors only)."""
        logging.getLogger('werkzeug').setLevel(logging.ERROR)

    # https://github.com/aluo-x/Learning_Neural_Acoustic_Fields/blob/master/train.py
    # https://github.com/pytorch/pytorch/issues/71029
    def find_free_port(self):
        """Return a TCP port number that was free at the time of the call.

        The probing socket is closed before returning, so another process may
        grab the port afterwards.  run() no longer relies on this method; it is
        kept for callers that may still use it.
        """
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(('localhost', 0))
            return s.getsockname()[1]

    def run(self):
        """Thread body: bind the server and serve requests in a sub-thread."""
        self.disable_logging()

        self.host = "127.0.0.1"
        # Bind to port 0 so the OS picks a free port atomically at bind time.
        # Probing for a free port first and binding it later leaves a window
        # in which another process can take the port.
        self.flask_server = make_server(self.host, 0, self.app)
        # Publish the port only now that the socket is bound and listening, so
        # ready_checker() never polls a port nobody is listening on yet.
        self.port = self.flask_server.server_port
        self.server_thread = Thread(target=self.flask_server.serve_forever)
        self.server_thread.start()

    def ready_checker(self):
        """Wait until the server is started.

        Annoyingly, there's no way to get a callback to be run when the Flask
        server starts.  Instead, we'll have to poll.

        Sets ``self.ready`` to True once ``/ping`` answers with HTTP 200.
        """

        while True:
            if self.port:
                try:
                    # The timeout keeps a stalled connection from blocking
                    # startup forever; we simply retry on the next iteration.
                    response = requests.get("http://%s:%s/ping" % (self.host, self.port), timeout=1)
                except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
                    # requests wraps "connection refused" in its own exception
                    # type whose errno is not populated, so it has to be caught
                    # by type: the server is just not accepting yet.
                    pass
                except socket.error as e:
                    # Any other low-level socket failure is unexpected.
                    if e.errno != errno.ECONNREFUSED:
                        raise
                else:
                    if response.status_code == 200:
                        self.ready = True
                        break

            time.sleep(0.1)

    def start_server(self):
        """Start the API server.

        Blocks until the server responds to ``/ping``.

        returns: port (int) -- the port that the server is listening on
                   (on localhost)
        """

        checker = Thread(target=self.ready_checker)
        checker.start()
        self.start()
        checker.join()

        return self.port