#!/usr/lib64/linuxfabrik-monitoring-plugins/venv/bin/python
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import os
import sys

import lib.args
import lib.base
import lib.db_mysql
import lib.time
import lib.txt
import lib.version
from lib.globals import STATE_CRIT, STATE_OK, STATE_UNKNOWN

# Plugin metadata, shown by `--version`.
__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026051102'

# Text shown at the top of `--help`.
DESCRIPTION = """Checks MySQL/MariaDB TLS/SSL posture: the current connection itself, the
server's TLS capability (`have_ssl`), enforcement (`require_secure_transport`), enabled TLS
versions (no TLSv1.0/1.1, at least TLSv1.2 or TLSv1.3), the presence of a server certificate
and key, the local expiry of `ssl_cert` and `ssl_ca` files (when readable on the same host),
and any remote accounts that can still connect without `REQUIRE SSL`. Each finding maps to a
copy-pasteable SQL or `openssl` recommendation."""

# Defaults for the command line options defined in parse_args().
DEFAULT_DEFAULTS_FILE = '/var/spool/icinga2/.my.cnf'
DEFAULT_DEFAULTS_GROUP = 'client'
DEFAULT_SEVERITY = 'warn'
DEFAULT_TIMEOUT = 3
# mysqltuner uses a hardcoded 30-day cliff for cert expiry. We expose it as a
# Nagios range so an admin can tune how early the check starts alerting.
# Default warn at 30 days mirrors mysqltuner; crit at 7 gives a useful
# escalation step inside the 30-day band. In Nagios range notation `30:` reads
# as "fine while at least 30 days remain".
DEFAULT_WARN_CERT_EXPIRY = '30:'
DEFAULT_CRIT_CERT_EXPIRY = '7:'


def parse_args():
    """Declare the plugin's command line options and return the parsed result.

    Arguments the parser does not know are tolerated and dropped
    (`parse_known_args`), so only the namespace of known options is returned.
    """
    cli = argparse.ArgumentParser(description=DESCRIPTION)
    add = cli.add_argument

    add(
        '-V', '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    add(
        '--always-ok',
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
        help=lib.args.help('--always-ok'),
    )

    # Nagios range applied to the remaining days of `ssl_cert` / `ssl_ca`.
    add(
        '-c', '--critical',
        dest='CRIT',
        default=DEFAULT_CRIT_CERT_EXPIRY,
        help=(
            'Days until local `ssl_cert` / `ssl_ca` expiry that trigger CRIT. '
            'Supports Nagios ranges. Example: `--critical=14:`. '
            'Default: %(default)s'
        ),
    )

    add(
        '--defaults-file',
        dest='DEFAULTS_FILE',
        default=DEFAULT_DEFAULTS_FILE,
        help=(
            'MySQL/MariaDB cnf file to read user, host and password from. '
            'Example: `--defaults-file=/var/spool/icinga2/.my.cnf`. '
            'Default: %(default)s'
        ),
    )

    add(
        '--defaults-group',
        dest='DEFAULTS_GROUP',
        default=DEFAULT_DEFAULTS_GROUP,
        help=lib.args.help('--defaults-group') + ' Default: %(default)s',
    )

    # One severity for every finding that has no numeric threshold.
    add(
        '--severity',
        dest='SEVERITY',
        choices=['warn', 'crit'],
        default=DEFAULT_SEVERITY,
        help=(
            'Severity for the threshold-less TLS findings '
            '(session not encrypted, `have_ssl=DISABLED`, '
            '`require_secure_transport=OFF`, weak TLS versions, '
            'missing cert/key, remote users without `REQUIRE SSL`). '
            'One of `warn` or `crit`. '
            'Default: %(default)s'
        ),
    )

    add(
        '--timeout',
        dest='TIMEOUT',
        type=int,
        default=DEFAULT_TIMEOUT,
        help=lib.args.help('--timeout') + ' Default: %(default)s (seconds)',
    )

    add(
        '-w', '--warning',
        dest='WARN',
        default=DEFAULT_WARN_CERT_EXPIRY,
        help=(
            'Days until local `ssl_cert` / `ssl_ca` expiry that trigger WARN. '
            'Supports Nagios ranges. Example: `--warning=60:`. '
            'Default: %(default)s'
        ),
    )

    known, _unknown = cli.parse_known_args()
    return known


def get_session_cipher(conn):
    """Look up the TLS cipher negotiated for the connection `conn`.

    Runs `SHOW SESSION STATUS LIKE 'Ssl_cipher'` (the same probe mysqltuner
    uses) and returns the stripped `Value` column. An empty string is
    returned when the query fails, yields no row, or reports one of the
    "no cipher" markers (`''`, `'NULL'`, `'0'`), i.e. the session is not
    encrypted.
    """
    ok, status_row = lib.db_mysql.select(
        conn,
        "SHOW SESSION STATUS LIKE 'Ssl_cipher'",
        fetchone=True,
    )
    if ok and status_row is not None:
        # The row is read as a mapping with a 'Value' key.
        value = (status_row.get('Value') or '').strip()
        if value not in ('', 'NULL', '0'):
            return value
    return ''


def find_remote_users_without_ssl(conn, is_mariadb_10_4_plus):
    """Port of mysqltuner check_remote_user_ssl(). Returns the list of
    accounts that can connect from a non-localhost host without `REQUIRE
    SSL`.

    Parameters:
        conn: open database connection as returned by `lib.db_mysql.connect()`.
        is_mariadb_10_4_plus: True when the server is MariaDB 10.4 or newer;
            selects the `mysql.global_priv` query instead of `mysql.user`.

    Returns:
        The rows returned by `lib.db_mysql.select()` (each exposing the
        quoted account under the key `user`), or an empty list when the
        query failed. A failed query (e.g. missing privileges on the `mysql`
        schema) is therefore indistinguishable from "no such accounts".
    """
    if is_mariadb_10_4_plus:
        # MariaDB 10.4+ keeps account attributes in the JSON column `priv`.
        # `ssl_type` is stored there as an integer (0 = none, 1 = ANY,
        # 2 = X509, 3 = SPECIFIED) and the key is simply absent for accounts
        # that never got a REQUIRE clause, so JSON_VALUE() yields '0' or
        # NULL - never ''. Comparing against '' alone matched nothing.
        # COALESCE covers the missing key; '' is still accepted for safety.
        sql = """
            select concat(quote(user), '@', quote(host)) as user
            from mysql.global_priv
            where host not in ('localhost', '127.0.0.1', '::1')
                and coalesce(json_value(priv, '$.ssl_type'), '0') in ('', '0')
            ;
        """
    else:
        # MySQL and MariaDB < 10.4 expose `ssl_type` as a plain column.
        sql = """
            select concat(quote(user), '@', quote(host)) as user
            from mysql.user
            where host not in ('localhost', '127.0.0.1', '::1')
                and (ssl_type = 'NONE' or ssl_type = '')
            ;
        """
    success, users = lib.db_mysql.select(conn, sql)
    return users if success else []


def cert_days_until_expiry(path):
    """Return (days_left, error) for the X.509 file at `path`.

    Days left can be negative (already expired). On any error (empty path,
    file missing, not readable, read failure, cryptography import failure,
    parse error) returns (None, reason) so the caller can render a "check
    skipped" line without alerting. mysqltuner shells out to
    `openssl x509 -enddate`; we use the `cryptography` library so the plugin
    works without an `openssl` binary on the path.

    Parameters:
        path: filesystem path of a PEM or DER encoded certificate; may be
            empty or None.

    Returns:
        Tuple `(days_left, None)` on success, `(None, reason)` otherwise.
    """
    if not path:
        return None, 'empty path'
    if not os.path.exists(path):
        return None, f'file not found: {path}'
    if not os.access(path, os.R_OK):
        return None, f'file not readable: {path}'
    try:
        from cryptography import x509
    except ImportError:
        return None, 'python module "cryptography" not installed'
    # The exists/access probes above are only there for nicer messages; the
    # read itself can still fail (path is a directory, file vanished in the
    # meantime, MAC/ACL denial). Report that as a skip reason instead of
    # letting the OSError abort the whole check.
    try:
        with open(path, 'rb') as fh:
            data = fh.read()
    except OSError as e:
        return None, f'read error: {e}'
    try:
        # Files carrying a PEM marker are parsed as PEM, everything else is
        # tried as DER.
        if b'-----BEGIN CERTIFICATE-----' in data:
            cert = x509.load_pem_x509_certificate(data)
        else:
            cert = x509.load_der_x509_certificate(data)
    except (ValueError, TypeError) as e:
        return None, f'parse error: {e}'
    # cryptography >=42 returns tz-aware *_utc; older releases (e.g. RHEL8)
    # only have the naive variant. Normalise to naive UTC so we can subtract
    # `lib.time.now(as_type='datetime')`.
    # NOTE(review): this assumes lib.time.now() yields a naive UTC datetime;
    # if it is local time the result can be off by the UTC offset - confirm.
    if hasattr(cert, 'not_valid_after_utc'):
        not_after = cert.not_valid_after_utc.replace(tzinfo=None)
    else:
        not_after = cert.not_valid_after
    days_left = (not_after - lib.time.now(as_type='datetime')).days
    return days_left, None


def count_weak_tls_versions(tls_versions):
    """Count enabled TLS versions that mysqltuner flags as insecure
    (TLSv1.0, TLSv1.1).

    Parameters:
        tls_versions: the server's `tls_version` value, a comma separated
            list such as `tlsv1.2,tlsv1.3`. Matching is case-insensitive;
            empty or None counts as "nothing enabled".

    Returns:
        0, 1 or 2 - the number of distinct insecure protocol versions found.
    """
    # Tokenise instead of substring-matching: MySQL spells TLS 1.0 as
    # `TLSv1` while MariaDB uses `TLSv1.0`, and a plain `'tlsv1' in ...`
    # test would also hit `tlsv1.2`. Whitespace is accepted as a separator
    # too, so a list already formatted as "a, b" still works.
    enabled = set((tls_versions or '').lower().replace(',', ' ').split())
    weak = 0
    if enabled & {'tlsv1', 'tlsv1.0'}:
        weak += 1
    if 'tlsv1.1' in enabled:
        weak += 1
    return weak


def count_modern_tls_versions(tls_versions):
    """Return how many of the modern protocol markers (`tlsv1.2`, `tlsv1.3`)
    occur in `tls_versions`.

    The match is a plain, case-sensitive substring test, so the caller is
    expected to pass the value already lowercased. Result is 0, 1 or 2.
    """
    return sum(
        1 for marker in ('tlsv1.2', 'tlsv1.3') if marker in tls_versions
    )


def main():
    """Run the TLS/SSL audit and hand the result to `lib.base.oao()`.

    Steps: parse the command line, connect, read the server variables, the
    session cipher, the local certificate expiry and the remote accounts
    without `REQUIRE SSL`, translate everything into findings plus
    recommendations, then emit message, state and perfdata.
    """

    # The audit logic follows mysqltuner.pl: ssl_tls_recommendations(),
    # check_local_certificates() and check_remote_user_ssl(). The
    # session-cipher probe, the have_ssl / require_secure_transport /
    # tls_version checks, the ssl_cert + ssl_ca local expiry audit (no key
    # match, no CN/SAN check) and the remote-user-without-SSL enumeration
    # are kept 1:1.

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    severity_state = lib.base.str2state(args.SEVERITY)

    # fetch data
    conn = lib.base.coe(
        lib.db_mysql.connect(
            {
                'defaults_file': args.DEFAULTS_FILE,
                'defaults_group': args.DEFAULTS_GROUP,
                'timeout': args.TIMEOUT,
            }
        )
    )
    lib.base.coe(lib.db_mysql.check_privileges(conn, 'SELECT'))

    myvar = lib.db_mysql.get_all_variables(conn)

    def server_var(name):
        """Return the server variable `name`, or '' when missing/empty."""
        return myvar.get(name) or ''

    server_version = server_var('version').lower()
    is_mariadb_10_4_plus = (
        'mariadb' in server_version
        and lib.version.version(server_version) >= (10, 4, 0)
    )

    session_cipher = get_session_cipher(conn)
    have_ssl = server_var('have_ssl').upper()
    require_secure_transport = server_var('require_secure_transport').upper()
    raw_tls_versions = server_var('tls_version')
    tls_versions = raw_tls_versions.lower()
    # The server returns a comma-only list; ", " reads better in the output.
    tls_versions_display = raw_tls_versions.replace(',', ', ')
    ssl_cert_path = server_var('ssl_cert')
    ssl_key_path = server_var('ssl_key')
    ssl_ca_path = server_var('ssl_ca')

    weak_tls = count_weak_tls_versions(tls_versions)
    modern_tls = count_modern_tls_versions(tls_versions)

    cert_days, cert_skip_reason = cert_days_until_expiry(ssl_cert_path)
    ca_days, ca_skip_reason = cert_days_until_expiry(ssl_ca_path)

    remote_users = find_remote_users_without_ssl(conn, is_mariadb_10_4_plus)

    lib.db_mysql.close(conn)

    # init some vars
    state = STATE_OK
    findings = []
    # Every recommendation is collected here and rendered once at the end as
    # a `Recommendations:\n* ...` bullet list. SQL and config snippets are
    # kept verbatim so they can be copy-pasted.
    recommendations = []

    def flag(text):
        """Escalate to the configured severity and record `text` tagged
        with that severity."""
        nonlocal state
        state = lib.base.get_worst(state, severity_state)
        findings.append(
            f'{text}{lib.base.state2str(severity_state, prefix=" ")}'
        )

    def audit_expiry(var_name, path, days_left, skip_reason):
        """Evaluate one certificate file against --warning/--critical.

        Nothing is reported for an empty path; an unreadable/unparsable
        file only yields an informational "skipped" finding.
        """
        nonlocal state
        if not path:
            return
        if days_left is None:
            findings.append(
                f'`{var_name}` ({path}) expiry check skipped ({skip_reason})'
            )
            return
        expiry_state = lib.base.get_state(
            days_left, args.WARN, args.CRIT, _operator='range',
        )
        expired = days_left < 0
        if expired:
            # An expired certificate is always CRIT, whatever the ranges say.
            expiry_state = STATE_CRIT
        state = lib.base.get_worst(state, expiry_state)
        if expired:
            when = f'expired {-days_left} days ago'
        else:
            when = f'expires in {days_left}d'
        findings.append(
            f'`{var_name}` ({path}) {when}'
            f'{lib.base.state2str(expiry_state, prefix=" ")}'
        )
        if expiry_state != STATE_OK:
            recommendations.append(
                f'Renew `{var_name}` ({path}) and reload the server.'
            )

    # analyze data

    # 1. Session cipher: is the monitoring connection itself encrypted?
    if session_cipher:
        findings.append(f'Current connection encrypted ({session_cipher})')
    else:
        flag('Current connection not encrypted')
        recommendations.append(
            'Add `ssl=true` (or `ssl-ca=...`) to the [client] section of '
            'the monitoring `.my.cnf`, and ensure the server is reachable '
            'over TLS.'
        )

    # 2. have_ssl
    if have_ssl == 'DISABLED':
        flag('`have_ssl` = DISABLED')
        recommendations.append(
            'Enable SSL support on the server: configure `ssl_cert`, '
            '`ssl_key` and `ssl_ca`, then restart the service.'
        )
    elif have_ssl in ('YES', 'ON'):
        findings.append('`have_ssl` enabled')
    elif have_ssl:
        findings.append(f'`have_ssl` = {have_ssl}')

    # 3. require_secure_transport: forces every connection onto TLS.
    if require_secure_transport == 'OFF':
        flag('`require_secure_transport` = OFF')
        recommendations.append("`SET GLOBAL require_secure_transport = 'ON';`")
        recommendations.append(
            'Persist `require_secure_transport = ON` in the server config '
            'so the setting survives restarts.'
        )
    elif require_secure_transport:
        findings.append('`require_secure_transport` = ON')
    else:
        # The variable is absent on the server (very old MySQL / MariaDB).
        # mysqltuner stays silent then; here it is surfaced as info.
        findings.append(
            '`require_secure_transport` not supported by this server'
        )

    # 4. TLS versions: TLSv1.0/1.1 are insecure; TLSv1.2 or TLSv1.3 must be on.
    if tls_versions:
        if weak_tls:
            flag(
                f'{weak_tls} insecure TLS '
                f'{lib.txt.pluralize("version", weak_tls)} enabled '
                f'(in `{tls_versions_display}`)'
            )
            recommendations.append(
                'Set `tls_version=TLSv1.2,TLSv1.3` in your server config '
                'and restart the service.'
            )
        if modern_tls == 0:
            flag(f'No modern TLS version enabled (in `{tls_versions_display}`)')
            recommendations.append(
                'Enable TLSv1.2 and/or TLSv1.3 via `tls_version` in your '
                'server config.'
            )
        if modern_tls and not weak_tls:
            findings.append(f'TLS versions: {tls_versions_display}')
    else:
        findings.append('`tls_version` not exposed by this server')

    # 5. Server cert + key configured? Only flagged when both are empty.
    if not (ssl_cert_path or ssl_key_path):
        flag('No server certificate configured (`ssl_cert`/`ssl_key` empty)')
        recommendations.append(
            'Provision a server certificate and configure `ssl_cert`, '
            '`ssl_key` (and optionally `ssl_ca`) in the server config.'
        )

    # 6. Local expiry of the server certificate, then of the CA.
    audit_expiry('ssl_cert', ssl_cert_path, cert_days, cert_skip_reason)
    audit_expiry('ssl_ca', ssl_ca_path, ca_days, ca_skip_reason)

    # 7. Remote users without REQUIRE SSL. The count is always reported;
    # it only escalates (and carries a severity tag) when non-zero.
    n_remote = len(remote_users)
    remote_tag = ''
    if n_remote:
        state = lib.base.get_worst(state, severity_state)
        remote_tag = lib.base.state2str(severity_state, prefix=' ')
    findings.append(
        f'{n_remote} remote {lib.txt.pluralize("user", n_remote)} '
        f'without `REQUIRE SSL`{remote_tag}'
    )
    recommendations.extend(
        f'`ALTER USER {account["user"]} REQUIRE SSL;`'
        for account in remote_users
    )

    # build the message
    facts_text = '. '.join(findings) + '.'
    if state == STATE_OK:
        facts_text = 'Everything is ok. ' + facts_text
    sections = [facts_text]
    if recommendations:
        bullets = '\n'.join(f'* {hint}' for hint in recommendations)
        sections.append('Recommendations:\n' + bullets)
    msg = '\n\n'.join(sections)

    # perfdata, emitted in a fixed order
    perf_chunks = [
        lib.base.get_perfdata(
            'mysql_tls_session_encrypted',
            1 if session_cipher else 0,
            _min=0,
            _max=1,
        ),
        lib.base.get_perfdata(
            'mysql_tls_have_ssl',
            1 if have_ssl in ('YES', 'ON') else 0,
            _min=0,
            _max=1,
        ),
        lib.base.get_perfdata(
            'mysql_tls_required',
            1 if require_secure_transport == 'ON' else 0,
            _min=0,
            _max=1,
        ),
        lib.base.get_perfdata(
            'mysql_tls_weak_protocol_versions',
            weak_tls,
            _min=0,
            _max=2,
        ),
        lib.base.get_perfdata(
            'mysql_tls_modern_protocol_versions',
            modern_tls,
            _min=0,
            _max=2,
        ),
    ]
    # Expiry metrics exist only when the file could actually be evaluated.
    for label, days_left in (
        ('mysql_tls_ssl_cert_days_until_expiry', cert_days),
        ('mysql_tls_ssl_ca_days_until_expiry', ca_days),
    ):
        if days_left is not None:
            perf_chunks.append(
                lib.base.get_perfdata(
                    label,
                    days_left,
                    uom='d',
                    warn=args.WARN,
                    crit=args.CRIT,
                )
            )
    perf_chunks.append(
        lib.base.get_perfdata(
            'mysql_tls_remote_users_without_ssl',
            len(remote_users),
            _min=0,
        )
    )
    perfdata = ''.join(perf_chunks)

    # over and out
    lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


# Script entry point: run the check only when executed directly, not on import.
if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Last-resort handler for anything main() did not deal with itself.
        # NOTE(review): lib.base.cu() presumably prints the traceback and
        # exits with UNKNOWN - confirm in lib.base.
        lib.base.cu()
