liteshort.py 15 KB
Newer Older
1
# Copyright (c) 2019 Steven Spangler <132@ikl.sh>, Kevin Alberts <kevin@kevinalberts.nl>
2 3 4
# This file is part of liteshort by 132ikl
# This software is licensed under the MIT license. It should be included in your copy of this software.
# A copy of the MIT license can be obtained at https://mit-license.org/
5
import hmac
import ipaddress
import os
import random
import sqlite3
import time
import urllib
import urllib.parse

import bcrypt
import yaml
from flask import Flask, current_app, flash, g, jsonify, make_response, redirect, render_template, request, \
    send_from_directory, url_for, session, abort

132ikl's avatar
132ikl committed
17 18
app = Flask(__name__)

19 20

def load_config():
    """Load ``config.yml``, fill in defaults and validate option types.

    Returns:
        dict: The configuration with lower-cased keys. Options missing from
        the file take the built-in defaults. ``password_hashed`` is added
        whenever an admin password is configured, recording whether the
        hashed (True) or plain-text (False) option is in use.

    Raises:
        TypeError: If a known option has a value of the wrong type, or if the
            API is enabled and neither ``admin_password`` nor
            ``admin_hashed_password`` is set.
    """
    # safe_load builds only plain Python objects and works on every PyYAML
    # release; the context manager makes sure the file handle is closed.
    with open('config.yml') as config_file:
        new_config = yaml.safe_load(config_file) or {}  # An empty file parses to None
    new_config = {k.lower(): v for k, v in new_config.items()}  # Make config keys case insensitive

    req_options = {'admin_username': 'admin', 'database_name': "urls", 'random_length': 4,
                   'allowed_chars': 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_',
                   'random_gen_timeout': 5, 'site_name': 'liteshort', 'site_domain': None, 'show_github_link': True,
                   'secret_key': None, 'disable_api': False, 'subdomain': '', 'latest': 'l', 'admin_links_per_page': 20
                   }

    config_types = {'admin_username': str, 'database_name': str, 'random_length': int,
                    'allowed_chars': str, 'random_gen_timeout': int, 'site_name': str,
                    'site_domain': (str, type(None)), 'show_github_link': bool, 'secret_key': str,
                    'disable_api': bool, 'subdomain': (str, type(None)), 'latest': (str, type(None)),
                    'admin_links_per_page': int
                    }

    # Make sure everything in req_options is set in config
    for option, default in req_options.items():
        new_config.setdefault(option, default)

    for option, value in new_config.items():
        if option not in config_types:
            continue
        allowed_types = config_types[option]
        if not isinstance(allowed_types, tuple):
            allowed_types = (allowed_types,)  # Normalise single types so both forms are handled alike
        # Exact type match on purpose: isinstance() would let a bool pass as an int.
        # Note secret_key defaults to None but must be a str, so it has to be set in config.yml.
        if type(value) not in allowed_types:
            raise TypeError(option + " is incorrect type")

    # Record the password style whenever a password is configured, even with the
    # API disabled, because check_password() reads 'password_hashed' for /login too.
    if new_config.get('admin_hashed_password'):
        new_config['password_hashed'] = True
    elif new_config.get('admin_password'):
        new_config['password_hashed'] = False
    elif not new_config['disable_api']:
        raise TypeError('admin_password or admin_hashed_password must be set in config.yml')
    return new_config


132ikl's avatar
132ikl committed
61 62 63 64 65 66 67
def authenticate(username, password):
    """Return whether ``username``/``password`` match the configured admin account."""
    config = current_app.config
    if username != config['admin_username']:
        return False
    return check_password(password, config)


def check_long_exist(long):
    """Find an existing short name for ``long`` that looks randomly generated.

    A row qualifies when its short name is no longer than ``random_length``
    and is not the special ``latest`` name.

    Returns:
        The first qualifying short name, or False when there is none.
    """
    config = current_app.config
    for row in query_db('SELECT short FROM urls WHERE long = ?', (long,)):
        if not row:
            continue
        candidate = row['short']
        if len(candidate) <= config["random_length"] and candidate != config['latest']:
            return candidate
    return False


132ikl's avatar
132ikl committed
73 74
def check_short_exist(short):
    """Return True when ``short`` already resolves to a long URL, else False."""
    return bool(get_long(short))


79 80 81 82 83 84 85 86 87
def check_password(password, pass_config):
    """Check ``password`` against the admin password stored in ``pass_config``.

    Args:
        password: The password supplied by the client.
        pass_config: Mapping holding ``admin_hashed_password`` (bcrypt hash)
            and/or ``admin_password`` (plain text), plus the optional
            ``password_hashed`` flag that selects between them.

    Returns:
        bool: True when the password matches, otherwise False.
    """
    hashed = pass_config.get('password_hashed')
    if hashed is None:
        # Flag missing: fall back to whichever password option is present
        # instead of failing with a KeyError.
        hashed = bool(pass_config.get('admin_hashed_password'))

    if hashed:
        return bcrypt.checkpw(password.encode('utf-8'), pass_config['admin_hashed_password'].encode('utf-8'))

    expected = pass_config.get('admin_password')
    if not isinstance(password, str) or not isinstance(expected, str):
        return False  # Same outcome the plain == comparison gave for mismatched types
    # Constant-time comparison; encode first because compare_digest rejects non-ASCII str
    return hmac.compare_digest(password.encode('utf-8'), expected.encode('utf-8'))


132ikl's avatar
132ikl committed
88 89 90 91 92
def delete_url(deletion):
    """Delete the rows whose short name equals ``deletion`` and commit.

    Returns:
        int: How many rows matched before the delete ran.
    """
    params = (deletion,)
    # row_factory=None makes the rows come back as plain tuples
    matches = query_db('SELECT * FROM urls WHERE short = ?', params, False, None)
    db = get_db()
    db.cursor().execute('DELETE FROM urls WHERE short = ?', params)
    db.commit()
    return len(matches)
93 94


132ikl's avatar
132ikl committed
95 96 97 98 99 100 101 102 103 104 105 106 107 108
def dict_factory(cursor, row):
    """Build a ``{column_name: value}`` dict for ``row`` from ``cursor.description``."""
    return {column[0]: row[position] for position, column in enumerate(cursor.description)}


def generate_short(rq):
    """Generate a random short name that is not in use.

    Candidates of ``random_length`` characters are drawn from
    ``allowed_chars`` until one is neither taken nor equal to the special
    ``latest`` name.

    Returns:
        The free short name, or the result of ``response()`` carrying a
        timeout error once ``random_gen_timeout`` seconds have passed.
    """
    deadline = time.time() + current_app.config['random_gen_timeout']
    while time.time() < deadline:
        picked = [random.choice(current_app.config['allowed_chars'])
                  for _ in range(current_app.config['random_length'])]
        short = ''.join(picked)
        if not check_short_exist(short) and short != app.config['latest']:
            return short
    return response(rq, None, 'Timeout while generating random short URL')


132ikl's avatar
132ikl committed
113 114 115 116 117 118 119
def get_long(short):
    """Return the long URL stored for ``short``, or None when there is none."""
    row = query_db('SELECT long FROM urls WHERE short = ?', (short,), True)
    if not row:
        return None
    return row['long'] or None


120 121 122 123 124 125 126 127
def get_baseUrl():
    """Return the URL prefix that short names are appended to.

    Uses ``https://<site_domain>/`` when ``site_domain`` is configured,
    otherwise the base URL of the current request.
    """
    domain = current_app.config['site_domain']
    if not domain:
        return request.base_url
    # TODO: un-hack-ify adding the protocol here
    return 'https://' + domain + '/'


132ikl's avatar
132ikl committed
128 129 130 131 132 133
def list_shortlinks():
    """Return every row of ``urls`` as a dict of first column -> second column.

    NOTE(review): the table is created as (long, short), so this presumably
    maps long URLs to short names; confirm against the API consumers.
    """
    rows = query_db('SELECT * FROM urls', (), False, None)
    return nested_list_to_dict(rows)


134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
def list_shortlinks_page(page, limit=50):
    """Return one page of rows from ``urls``, ordered by short name.

    Args:
        page: 1-based page number; must be an int >= 1.
        limit: Rows per page; must be an int.

    Returns:
        The rows for that page as plain tuples.
    """
    assert page >= 1
    assert type(page) == int
    assert type(limit) == int
    offset = limit * (page - 1)
    return query_db('SELECT * FROM urls ORDER BY short LIMIT ? OFFSET ?', (limit, offset), False, None)


def get_num_pages(limit=50):
    """Count the rows in ``urls`` and the pages needed to show them.

    Args:
        limit: Rows per page; must be an int.

    Returns:
        tuple: ``(page_count, items)``, with a partially filled last page
        counted as a full page.
    """
    assert type(limit) == int
    rows = query_db('SELECT COUNT(*) FROM urls', (), False, None)
    items = rows[0][0]
    page_count = items // limit
    if items > page_count * limit:  # Leftover rows need one more page
        page_count += 1
    return page_count, items


132ikl's avatar
132ikl committed
153 154 155
def nested_list_to_dict(l):
    """Map the first element of each inner sequence to its second element.

    Later entries overwrite earlier ones that share the same first element.
    """
    return {pair[0]: pair[1] for pair in l}


def _remote_is_whitelisted(remote_addr, allowed_ranges):
    """Return True when ``remote_addr`` lies inside any network of ``allowed_ranges``."""
    if not allowed_ranges:
        return False  # An empty whitelist admits nobody
    remote = ipaddress.ip_address(remote_addr)  # Parsed once rather than once per range
    return any(remote in ipaddress.ip_network(network) for network in allowed_ranges)


def response(rq, result, error_msg="Error: Unknown error"):
    """Build the reply for a request, as JSON or as the rendered main page.

    Args:
        rq: The request whose form fields ``api`` and ``format`` select the
            output type.
        result: Truthy value to report as success (``True`` means success
            with nothing to show); a falsy value reports ``error_msg``.
        error_msg: Message used when ``result`` is falsy.

    Returns:
        A JSON response when ``format`` is ``json``, a plain error string for
        API calls that did not ask for JSON, otherwise the rendered
        ``main.html`` with the outcome flashed.

    Aborts with 403 when ``allowed_ip_ranges`` is configured and the client
    address is in none of the listed networks.
    """
    # Only allow responses to whitelisted IP ranges
    allowed_ranges = app.config.get('allowed_ip_ranges')
    if allowed_ranges is not None and not _remote_is_whitelisted(request.remote_addr, allowed_ranges):
        abort(403)

    wants_json = rq.form.get('format') == 'json'
    if rq.form.get('api') and not wants_json:
        return "Format type HTML (default) not support for API"  # Future-proof for non-json return types

    if wants_json:
        if not result:
            # Nothing to report: send the error message instead
            return jsonify(success=False, error=error_msg)
        if result is True:  # Allows sending with no result (ie. during deletion)
            return jsonify(success=True)
        return jsonify(success=True, result=result)

    if not result:
        flash(error_msg, 'error')
    elif result is not True:
        flash(result, 'success')
    return render_template("main.html")
132ikl's avatar
132ikl committed
192 193


194 195 196 197 198 199 200 201 202 203
def set_latest(long):
    """Point the special ``latest`` short name at ``long``.

    Does nothing when ``latest`` is not configured. The change is not
    committed here; callers commit it themselves.
    """
    if not app.config['latest']:
        return
    latest = current_app.config['latest']
    if query_db('SELECT short FROM urls WHERE short = ?', (latest,)):
        statement = "UPDATE urls SET long = ? WHERE short = ?"
    else:
        statement = "INSERT INTO urls (long,short) VALUES (?, ?)"
    get_db().cursor().execute(statement, (long, latest))


132ikl's avatar
132ikl committed
204
def validate_short(short):
    """Check a user-chosen short name.

    Returns:
        True when ``short`` is acceptable; otherwise the error reply from
        ``response()`` naming the special URL or the first forbidden character.
    """
    if short == app.config['latest']:
        message = 'Short URL cannot be the same as a special URL ({})'.format(short)
        return response(request, None, message)

    allowed = current_app.config['allowed_chars']
    offending = next((char for char in short if char not in allowed), None)
    if offending is not None:
        return response(request, None,
                        'Character ' + offending + ' not allowed in short URL')
    return True

214

132ikl's avatar
132ikl committed
215 216 217
def validate_long(long):  # https://stackoverflow.com/a/36283503
    """Return True when ``long`` parses as a URL with both a scheme and a host.

    Input that ``urlparse`` rejects outright (for example an unterminated
    IPv6 bracket) is reported as invalid rather than raising.
    """
    try:
        token = urllib.parse.urlparse(long)
    except ValueError:
        return False
    return bool(token.scheme and token.netloc)
218

132ikl's avatar
132ikl committed
219
# Database connection functions
220 221 222 223 224 225 226 227 228 229 230 231


def get_db():
    """Return the SQLite connection stored on ``g``, opening it on first use.

    The database file is ``<database_name>.db``; the ``urls`` table is
    created if it does not exist yet.
    """
    if 'db' in g:
        return g.db

    db_path = ''.join((current_app.config['database_name'], '.db'))
    g.db = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
    g.db.cursor().execute('CREATE TABLE IF NOT EXISTS urls (long,short)')
    return g.db


132ikl's avatar
132ikl committed
232 233
def query_db(query, args=(), one=False, row_factory=sqlite3.Row):
    """Run ``query`` with ``args`` and return the fetched rows.

    Args:
        query: SQL text with ``?`` placeholders.
        args: Values bound to the placeholders.
        one: When true, return only the first row, or None if there is none.
        row_factory: Installed on the connection before the query runs;
            pass None to get plain tuples.
    """
    db = get_db()
    db.row_factory = row_factory
    cursor = db.execute(query, args)
    rows = cursor.fetchall()
    cursor.close()
    if not one:
        return rows
    return rows[0] if rows else None


132ikl's avatar
132ikl committed
240 241 242 243
@app.teardown_appcontext
def close_db(error):
    """Close the connection opened by ``get_db()`` when the app context ends.

    ``get_db()`` stores the connection as ``g.db``, so that is the attribute
    released here.
    """
    db = g.pop('db', None)
    if db is not None:
        db.close()
244 245


132ikl's avatar
132ikl committed
246
# Add YAML config to Flask config, then mirror the two options Flask itself
# reads under its own names.
app.config.update(load_config())
app.config.update(SERVER_NAME=app.config['site_domain'])
app.secret_key = app.config['secret_key']
249 250


251 252 253 254 255
@app.errorhandler(403)
def access_denied(e):
    """Render the access-denied page for 403 errors."""
    # Return the status explicitly; a bare template would be sent as 200 OK.
    return render_template("403_access_denied.html"), 403


132ikl's avatar
132ikl committed
256 257 258 259 260 261
@app.route('/favicon.ico', subdomain=app.config['subdomain'])
def favicon():
    """Serve ``favicon.ico`` from the application's ``static`` directory."""
    static_dir = os.path.join(app.root_path, 'static')
    return send_from_directory(static_dir, 'favicon.ico', mimetype='image/vnd.microsoft.icon')


262
@app.route('/', subdomain=app.config['subdomain'])
def main():
    """Handle GET requests to the site root."""
    # True means success with nothing to flash
    page = response(request, True)
    return page


@app.route('/<url>')
def main_redir(url):
    """Redirect a short name to its long URL.

    Unknown names flash an error and redirect to the main page. Either
    reply is marked as not cacheable.
    """
    target = get_long(url)
    if not target:
        flash('Short URL "' + url + '" doesn\'t exist', 'error')
        resp = make_response(redirect(url_for('main')))
    else:
        resp = make_response(redirect(target, 301))
    resp.headers.set('Cache-Control', 'no-store, must-revalidate')
    return resp
277 278


279
def _delete_by_long(long):
    """Delete every row whose long URL equals ``long``, commit, and return how many matched."""
    params = (long,)
    matches = query_db('SELECT * FROM urls WHERE long = ?', params, False, None)
    db = get_db()
    db.cursor().execute('DELETE FROM urls WHERE long = ?', params)
    db.commit()
    return len(matches)


def _handle_api_command(command):
    """Run an already-authenticated API command and return its reply."""
    if command in ('list', 'listshort'):
        return response(request, list_shortlinks(), "Failed to list items")
    if command == 'listlong':
        inverted = {v: k for k, v in list_shortlinks().items()}
        return response(request, inverted, "Failed to list items")
    if command == 'delete':
        form = request.form
        if 'long' not in form and 'short' not in form:
            return response(request, None, "Provide short or long in POST data")
        deleted = 0
        if 'short' in form:
            deleted += delete_url(form['short'])
        if 'long' in form:
            # delete_url() matches on the short column, so long URLs need their own query
            deleted += _delete_by_long(form['long'])
        if deleted > 0:
            return response(request, "Deleted " + str(deleted) + " URLs")
        return response(request, None, "Failed to delete URL")
    return response(request, None, 'Command ' + command + ' not found')


def _shorten(long):
    """Create, or reuse, a shortlink for ``long`` and return the reply."""
    if not validate_long(long):
        return response(request, None, "Long URL is not valid")

    custom = request.form.get('short')
    if custom:
        # Validate the custom short text once; on failure the result is the error reply
        result = validate_short(custom)
        if result is not True:
            return result
        short = custom
        if get_long(short) == long:
            return response(request, get_baseUrl() + short,
                            'Error: Failed to return pre-existing non-random shortlink')
    else:
        short = generate_short(request)
        if not isinstance(short, str):
            return short  # generate_short() timed out and handed back the error reply

    if check_short_exist(short):
        return response(request, None,
                        'Short URL already taken')

    long_exists = check_long_exist(long)
    if long_exists and not custom:
        # Reuse the existing random shortlink instead of storing a duplicate
        set_latest(long)
        get_db().commit()
        return response(request, get_baseUrl() + long_exists,
                        'Error: Failed to return pre-existing random shortlink')

    get_db().cursor().execute('INSERT INTO urls (long,short) VALUES (?,?)', (long, short))
    set_latest(long)
    get_db().commit()

    return response(request, get_baseUrl() + short,
                    'Error: Failed to generate')


@app.route('/', methods=['POST'], subdomain=app.config['subdomain'])
def main_post():
    """Handle POST requests to the site root.

    A request with an ``api`` field is treated as an authenticated API
    command; otherwise a ``long`` field requests a new shortlink. API
    commands are checked first so that ``delete`` can receive a ``long``
    field without being mistaken for a shorten request.
    """
    if request.form.get('api'):
        if current_app.config['disable_api']:
            return response(request, None, "API is disabled.")
        # All API calls require authentication
        if not request.authorization \
                or not authenticate(request.authorization['username'], request.authorization['password']):
            return response(request, None, "BasicAuth failed")
        return _handle_api_command(request.form['api'])

    if request.form.get('long'):
        return _shorten(request.form['long'])

    return response(request, None, 'Long URL required')
341 342


343 344
@app.route('/login', methods=['POST'])
def login():
    """Log the admin in from the posted ``username`` and ``password``.

    Always redirects to the admin page; a failed attempt flashes an error.

    Raises:
        AssertionError: If no admin password of either kind is configured.
    """
    hashed_password = app.config.get('admin_hashed_password')
    plain_password = app.config.get('admin_password')
    if hashed_password is None and plain_password is None:
        raise AssertionError("Login is disabled.")

    if not authenticate(request.form['username'], request.form['password']):
        flash('Wrong password!', 'error')
    else:
        session['logged_in'] = True
    return make_response(redirect(url_for('admin')))


@app.route('/logout')
def logout():
    """Clear the logged-in flag, if set, and redirect to the admin page."""
    if session.get('logged_in'):
        session['logged_in'] = False
    return make_response(redirect(url_for('admin')))


@app.route('/delete/<short>')
def delete(short):
    """Delete the shortlink ``short`` for a logged-in admin.

    Redirects back to the admin page, keeping the ``page`` query argument.
    Visitors who are not logged in are redirected without any deletion.
    """
    if not session.get('logged_in'):
        return make_response(redirect(url_for('admin')))

    if delete_url(short):
        flash("Link '/{}' deleted.".format(short), "success")
    else:
        flash("Failed to delete URL.", "error")
    page = request.args.get('page', '1')
    # Let url_for build the query string so the user-supplied value is URL-encoded
    return make_response(redirect(url_for('admin', page=page)))


@app.route('/admin')
def admin():
    """Show one page of shortlinks to a logged-in admin, else the login form.

    The ``page`` query argument selects the page. A non-numeric value is
    treated as page 1, and out-of-range values redirect to the nearest
    valid page.
    """
    if not session.get('logged_in'):
        return render_template('login.html')

    per_page = app.config['admin_links_per_page']
    page_count, num_items = get_num_pages(per_page)
    try:
        page = int(request.args.get('page', '1'))
    except ValueError:
        page = 1

    # Clamp the lower bound even when the table is empty: list_shortlinks_page()
    # asserts page >= 1, so ?page=0 would otherwise fail with a server error.
    if page < 1:
        return make_response(redirect(url_for('admin', page=1)))
    if page_count != 0 and page > page_count:
        return make_response(redirect(url_for('admin', page=page_count)))

    urls = list_shortlinks_page(page, per_page)
    return render_template('admin.html', urls=urls, page=page, page_count=page_count, num_items=num_items)


399 400
# Run the app with Flask's built-in server when this file is executed directly
if __name__ == '__main__':
    app.run()