#!/usr/lib64/linuxfabrik-monitoring-plugins/venv/bin/python
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import socket
import ssl
import sys
import time
import urllib.parse
from pathlib import Path

import lib.args
import lib.base
import lib.time
from lib.globals import STATE_CRIT, STATE_OK, STATE_UNKNOWN, STATE_WARN

try:
    from cryptography import x509
    from cryptography.hazmat.primitives import hashes, serialization
except ImportError:
    print('Python module "cryptography" is not installed.')
    sys.exit(STATE_UNKNOWN)


__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026051201'

# Long help text; passed to argparse.ArgumentParser(description=...) in parse_args().
DESCRIPTION = """Inspects an X.509 certificate and alerts on days remaining until expiry,
hostname mismatch and chain verification failures. Sources via --source: `url` fetches the
certificate from a TLS endpoint and verifies the chain against the system trust store by
default (override with --ca-file); `file` reads one or many certificate files via glob
expansion (PEM or DER). PEM bundles expand to one item per certificate, so a fullchain.pem
produces a row for the leaf and a row for each intermediate. With --source url the plugin
only runs the TLS handshake and reads the server certificate; no HTTP request is sent.
That means it works for any "TLS from start" service, not only HTTPS: IMAPS (port 993),
LDAPS (636), SMTPS (465), AMQPS (5671), MQTTS (8883), custom TLS ports - just point --url
at the right host and port (`https://mail.example.com:993/` inspects the IMAPS cert).
STARTTLS protocols that upgrade an existing plaintext connection (SMTP submission on 587,
IMAP on 143, LDAP on 389) are not supported. Hostname mismatch and chain verification
failures share one --severity (warn or crit) and only apply to --source url. Expired
certificates are unconditionally reported as CRIT. With --source file, the worst state
across all matched files drives the plugin state."""

# Expiry thresholds are Nagios ranges over "days remaining": '14:' means WARN when the
# value drops below 14, '5:' means CRIT when it drops below 5.
DEFAULT_CRIT = '5:'
DEFAULT_LENGTHY = False
# State used for hostname-mismatch / chain-verification failures (see --severity).
DEFAULT_SEVERITY = 'warn'
# Socket timeout in seconds for each TLS connection attempt.
DEFAULT_TIMEOUT = 8
DEFAULT_WARN = '14:'

# argparse `choices` for --severity and --source.
VALID_SEVERITIES = ('crit', 'warn')
VALID_SOURCES = ('file', 'url')


def parse_args():
    """Build the command-line interface and return the parsed argument namespace.

    Arguments are registered in a fixed order because argparse prints them in that order
    in `--help`. Unknown arguments are tolerated via `parse_known_args()` and discarded.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    add = parser.add_argument

    add(
        '-V', '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    add(
        '--always-ok',
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
        help=lib.args.help('--always-ok'),
    )

    add(
        '--ca-file',
        dest='CA_FILE',
        help='Path to a CA bundle in PEM format used for chain verification. '
        'Default uses the system trust store.',
    )

    add(
        '--client-cert',
        dest='CLIENT_CERT',
        help='Path to a client certificate in PEM format for mutual TLS.',
    )

    add(
        '--client-key',
        dest='CLIENT_KEY',
        help='Path to the client certificate private key in PEM format.',
    )

    add(
        '-c', '--critical',
        dest='CRIT',
        default=DEFAULT_CRIT,
        help='CRIT threshold for days remaining until the certificate expires. '
        'Supports Nagios ranges. '
        'Example: `5:` (CRIT when below 5 days). '
        'Default: %(default)s',
    )

    add(
        '--filename',
        dest='FILENAME',
        help='Path to a certificate file or a glob pattern matching multiple certificate '
        'files. Required when --source=file. Files are read as PEM or DER (autodetected); '
        'when a PEM bundle contains multiple certificates (typical for fullchain.pem or a '
        'CA bundle), each certificate becomes its own row in the output. Globs follow '
        'Python conventions: `*` matches one path segment, `**` matches across '
        'directories. Always quote the pattern in shells so that the shell does not '
        'expand the wildcard before the plugin sees it. '
        'Example: `--filename=\'/etc/ssl/certs/*.pem\'`. '
        'Recursive example: `--filename=\'/etc/letsencrypt/live/**/cert.pem\'`',
    )

    add(
        '--insecure',
        dest='INSECURE',
        action='store_true',
        default=False,
        help='Skip chain and hostname verification entirely. The certificate is still '
        'fetched and inspected, but the chain verdict is reported as "verification skipped".',
    )

    add(
        '--lengthy',
        dest='LENGTHY',
        action='store_true',
        default=DEFAULT_LENGTHY,
        help=lib.args.help('--lengthy'),
    )

    add(
        '--severity',
        dest='SEVERITY',
        choices=VALID_SEVERITIES,
        default=DEFAULT_SEVERITY,
        help='Severity assigned to hostname mismatch and chain verification failures. '
        'Defaults to %(default)s so that operators running internal CAs are not paged '
        'by trust issues that are expected in their environment. '
        'Set to `crit` to enforce strict trust.',
    )

    add(
        '--sni-hostname',
        dest='SNI_HOSTNAME',
        help='SNI hostname sent during the TLS handshake and used for hostname verification. '
        'Useful when --url points at an IP address or a load balancer that needs an explicit '
        'SNI. Default uses the hostname from --url.',
    )

    add(
        '--source',
        dest='SOURCE',
        choices=VALID_SOURCES,
        default='url',
        help='Where the certificate is fetched from. '
        '`url` fetches it from a TLS endpoint (requires --url). '
        '`file` reads it from one or many local files (requires --filename, supports glob '
        'patterns). '
        '`p12` and `jks` are reserved for future expansion. '
        'Default: %(default)s',
    )

    add(
        '--timeout',
        dest='TIMEOUT',
        type=int,
        default=DEFAULT_TIMEOUT,
        help=lib.args.help('--timeout') + ' Default: %(default)s (seconds)',
    )

    add(
        '--url',
        dest='URL',
        help='URL of the TLS endpoint to inspect. Required when --source=url. '
        'Example: `https://www.example.com/`',
    )

    add(
        '-w', '--warning',
        dest='WARN',
        default=DEFAULT_WARN,
        help='WARN threshold for days remaining until the certificate expires. '
        'Supports Nagios ranges. '
        'Example: `14:` (WARN when below 14 days). '
        'Default: %(default)s',
    )

    # parse_known_args() returns (namespace, leftover_args); the leftovers are ignored
    return parser.parse_known_args()[0]


def _build_tls_context(args, verify):
    """Return a client-side SSLContext that only negotiates TLS 1.2 or newer.

    With `verify=True` the context checks the chain (against --ca-file, or the system
    trust store when --ca-file is unset) and the hostname. With `verify=False` both
    checks are disabled. The --client-cert / --client-key pair is loaded whenever
    --client-cert is set. Loading a missing or invalid CA file or client certificate/key
    raises OSError (ssl.SSLError is a subclass of OSError).
    """
    if verify:
        ctx = ssl.create_default_context(cafile=args.CA_FILE)
    else:
        ctx = ssl.create_default_context()
        # check_hostname has to be switched off before verify_mode may drop to CERT_NONE
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    if args.CLIENT_CERT:
        ctx.load_cert_chain(args.CLIENT_CERT, keyfile=args.CLIENT_KEY)
    return ctx


def get_cert_from_url(args):
    """Open TLS connections to --url and capture everything main() needs about the peer.

    Unless --insecure is set, a first handshake with a verifying context decides the
    chain/hostname verdict only. A second handshake with verification disabled always
    runs. It captures the peer certificate in DER form and the negotiated TLS version, so
    an untrusted or mismatched certificate can still be inspected. It also records the
    elapsed time, which covers the TCP connect plus the TLS handshake.

    Returns (success, result). On failure `result` is an error message. On success it is
    a dict with the keys `cert_der`, `chain_ok`, `chain_reason`, `handshake_seconds`,
    `host`, `port`, `sni` and `tls_version`.
    """
    parsed = urllib.parse.urlsplit(args.URL)
    host = parsed.hostname
    if not host:
        return False, f'Cannot parse a hostname from --url "{args.URL}"'
    try:
        # `.port` raises ValueError for a non-numeric or out-of-range port
        port = parsed.port or 443
    except ValueError as e:
        return False, f'Cannot parse a port from --url "{args.URL}": {e}'
    server_hostname = args.SNI_HOSTNAME or host
    context_error = (
        'Cannot set up TLS context (check --ca-file, --client-cert, --client-key): {}'
    )

    chain_ok = True
    chain_reason = ''

    if not args.INSECURE:
        try:
            secure_ctx = _build_tls_context(args, verify=True)
        except (OSError, ValueError) as e:
            # a bad CA file or client cert is a configuration error, not a chain failure
            return False, context_error.format(e)
        try:
            with (
                socket.create_connection((host, port), timeout=args.TIMEOUT) as sock,
                secure_ctx.wrap_socket(sock, server_hostname=server_hostname),
            ):
                pass
        except ssl.SSLCertVerificationError as e:
            chain_ok = False
            chain_reason = e.verify_message or str(e)
        except ssl.SSLError as e:
            # other TLS-level failures of the verifying handshake are recorded as a chain
            # failure; the non-verifying handshake below decides reachability
            chain_ok = False
            chain_reason = str(e)
        except OSError as e:
            # socket.timeout / connection errors. This clause must follow the ssl clauses
            # because ssl.SSLError is itself a subclass of OSError.
            return False, f'Cannot connect to {host}:{port}: {e}'

    try:
        insecure_ctx = _build_tls_context(args, verify=False)
    except (OSError, ValueError) as e:
        return False, context_error.format(e)

    try:
        t0 = time.monotonic()
        with (
            socket.create_connection((host, port), timeout=args.TIMEOUT) as sock,
            insecure_ctx.wrap_socket(sock, server_hostname=server_hostname) as tls_sock,
        ):
            handshake_seconds = time.monotonic() - t0
            cert_der = tls_sock.getpeercert(True)
            tls_version = tls_sock.version()
    except ssl.SSLError as e:
        # checked before OSError: ssl.SSLError subclasses OSError, so the reverse order
        # would report every TLS failure as a connection failure
        return False, f'TLS handshake failed for {host}:{port}: {e}'
    except OSError as e:
        return False, f'Cannot connect to {host}:{port}: {e}'

    if not cert_der:
        return False, f'No server certificate received from {host}:{port}'

    return True, {
        'cert_der': cert_der,
        'chain_ok': chain_ok,
        'chain_reason': chain_reason,
        'handshake_seconds': handshake_seconds,
        'host': host,
        'port': port,
        'sni': server_hostname,
        'tls_version': tls_version,
    }


def _looks_like_cert(data):
    """Return True if `data` has a recognisable certificate envelope. PEM certificates carry
    the `BEGIN CERTIFICATE` marker; DER certificates always start with the ASN.1 SEQUENCE
    tag (`0x30`). Files that match neither (private keys, OpenSSL trust bundles, text
    files) are not certificates and should be skipped silently when expanding a glob.
    """
    if b'-----BEGIN CERTIFICATE-----' in data:
        return True
    return data[:1] == b'\x30'


def _load_all_certs(data):
    """Parse every certificate from `data` and return a list of DER-encoded bytes. Should
    only be called on data that already passed `_looks_like_cert()`. Raises `ValueError`
    when the envelope looks like a certificate (PEM marker present or DER prefix detected)
    but the content cannot be parsed.

    PEM bundles can carry multiple `BEGIN CERTIFICATE`/`END CERTIFICATE` blocks (typical
    for `fullchain.pem` or system CA bundles); each block becomes its own list entry.
    DER files always carry exactly one certificate.
    """
    certs = []
    if b'-----BEGIN CERTIFICATE-----' in data:
        # walk the PEM block by block so a single corrupt block fails the whole call
        # rather than silently dropping the rest of the bundle
        marker_begin = b'-----BEGIN CERTIFICATE-----'
        marker_end = b'-----END CERTIFICATE-----'
        cursor = 0
        while True:
            start = data.find(marker_begin, cursor)
            if start < 0:
                break
            end = data.find(marker_end, start)
            if end < 0:
                raise ValueError(
                    'truncated PEM bundle: BEGIN CERTIFICATE without matching END',
                )
            block = data[start:end + len(marker_end)]
            cert = x509.load_pem_x509_certificate(block)
            certs.append(cert.public_bytes(serialization.Encoding.DER))
            cursor = end + len(marker_end)
    else:
        cert = x509.load_der_x509_certificate(data)
        certs.append(cert.public_bytes(serialization.Encoding.DER))
    return certs


def get_certs_from_files(args):
    """Expand the --filename glob and turn every certificate found into a result item.

    Returns (success, items). On failure `items` is an error message. On success it is a
    list of dicts with the keys `cert_der`, `chain_verdict`, `handshake_seconds`, `label`
    and `tls_version`; the last three carry no information for files and are None.

    Behaviour per matched file:
    - PEM bundles yield one item per certificate, so `fullchain.pem` gives a leaf row plus
      intermediate rows, and a CA bundle gives one row per anchor.
    - Files without a certificate envelope (per `_looks_like_cert()`) are skipped
      silently, which keeps recursive globs usable next to keys and trust bundles.
    - A file that looks like a certificate but fails to parse aborts with an error, so
      corrupt certificates are not hidden.
    """
    pattern = Path(args.FILENAME)
    # Path.glob() only accepts relative patterns, so glob from the anchor ('/' or '')
    root = Path(pattern.anchor)
    candidates = sorted(root.glob(str(pattern.relative_to(pattern.anchor))))

    files_seen = 0
    found = []
    for candidate in candidates:
        if not candidate.is_file():
            continue
        files_seen += 1
        try:
            raw = candidate.read_bytes()
        except OSError as err:
            return False, f'Cannot read {candidate}: {err}'
        if not _looks_like_cert(raw):
            continue
        try:
            ders = _load_all_certs(raw)
        except (ValueError, TypeError) as err:
            return False, f'Cannot parse certificate from {candidate}: {err}'
        label = str(candidate)
        found.extend(
            {
                'cert_der': der,
                'chain_verdict': None,
                'handshake_seconds': None,
                'label': label,
                'tls_version': None,
            }
            for der in ders
        )

    if found:
        return True, found
    if not files_seen:
        return False, f'No files match "{args.FILENAME}"'
    return False, (
        f'No parseable certificates among {files_seen} file(s)'
        f' matching "{args.FILENAME}"'
    )


def collect_items(args):
    """Dispatch on --source and return (success, items) in one uniform shape.

    Every item carries `cert_der`, `chain_verdict`, `handshake_seconds`, `label` and
    `tls_version`, so main() can treat URL and file sources alike. On failure the second
    element is an error message instead of a list.
    """
    source = args.SOURCE

    if source == 'file':
        if not args.FILENAME:
            return False, '--filename is required when --source=file'
        return get_certs_from_files(args)

    if source != 'url':
        return False, f'Source "{source}" is not implemented yet'

    if not args.URL:
        return False, '--url is required when --source=url'
    success, info = get_cert_from_url(args)
    if not success:
        return False, info

    if args.INSECURE:
        verdict = 'verification skipped'
    else:
        verdict = (
            'verified' if info['chain_ok'] else f'unverified ({info["chain_reason"]})'
        )

    item = {
        'cert_der': info['cert_der'],
        'chain_verdict': verdict,
        'handshake_seconds': info['handshake_seconds'],
        'label': info['host'],
        'tls_version': info['tls_version'],
    }
    return True, [item]


def parse_cert(cert_der):
    """Decode a DER-encoded X.509 certificate into a flat dict of human-readable fields.

    Returns a dict with the keys `issuer_cn`, `key_size`, `key_type`, `must_staple`,
    `not_after`, `not_before`, `rsa_exponent`, `sans`, `serial_hex`,
    `sha256_fingerprint`, `sig_algorithm` and `subject_cn`. `not_after` and `not_before`
    are naive datetimes in UTC. Undecodable input propagates whatever
    `x509.load_der_x509_certificate` raises.
    """
    cert = x509.load_der_x509_certificate(cert_der)

    def _cn(name):
        """Return the first Common Name attribute of an x509 Name, or '' if there is none."""
        attrs = name.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
        return attrs[0].value if attrs else ''

    pub_key = cert.public_key()
    # Older cryptography releases (OpenSSL backend) prefix the concrete key classes with
    # an underscore (e.g. `_RSAPublicKey`); strip it so the label reads `RSA`, not `_RSA`.
    key_type = type(pub_key).__name__.lstrip('_').replace('PublicKey', '')
    try:
        key_size = pub_key.key_size
    except AttributeError:
        # key types without a `key_size` attribute
        key_size = None
    # RSA only: extract the public exponent. Standard value is 65537; unusual values
    # (3, 17) are a configuration smell. Other key types lack `public_numbers().e`.
    try:
        rsa_exponent = pub_key.public_numbers().e
    except AttributeError:
        rsa_exponent = None

    sans = []
    try:
        san_ext = cert.extensions.get_extension_for_oid(
            x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        )
        for name in san_ext.value:
            try:
                sans.append(str(name.value))
            except AttributeError:
                pass
    except x509.ExtensionNotFound:
        pass

    # OCSP Must-Staple per RFC 7633: TLS Feature extension (OID 1.3.6.1.5.5.7.1.24)
    # carries feature codes; 5 = status_request (Must-Staple). cryptography exposes the
    # codes as TLSFeatureType enum members; .value contains the integer.
    must_staple = False
    try:
        tls_feature_ext = cert.extensions.get_extension_for_oid(
            x509.ObjectIdentifier('1.3.6.1.5.5.7.1.24'),
        )
        for feature in tls_feature_ext.value:
            code = getattr(feature, 'value', feature)
            if code == 5:
                must_staple = True
                break
    except x509.ExtensionNotFound:
        pass

    # cryptography >=42 returns tz-aware UTC datetimes via *_utc; older releases (e.g. the
    # one shipped on RHEL8) only have the naive variants. Strip tzinfo for downstream
    # arithmetic against `lib.time.now(as_type='datetime')`, which is naive UTC.
    if hasattr(cert, 'not_valid_after_utc'):
        not_after = cert.not_valid_after_utc.replace(tzinfo=None)
        not_before = cert.not_valid_before_utc.replace(tzinfo=None)
    else:
        not_after = cert.not_valid_after
        not_before = cert.not_valid_before

    # `_name` is cryptography's lookup of well-known OIDs and yields 'Unknown OID'
    # otherwise; fall back to the dotted OID so the output stays informative.
    sig_oid = cert.signature_algorithm_oid
    sig_algorithm = getattr(sig_oid, '_name', 'Unknown OID')
    if sig_algorithm == 'Unknown OID':
        sig_algorithm = sig_oid.dotted_string

    return {
        'issuer_cn': _cn(cert.issuer),
        'key_size': key_size,
        'key_type': key_type,
        'must_staple': must_staple,
        'not_after': not_after,
        'not_before': not_before,
        'rsa_exponent': rsa_exponent,
        'sans': sans,
        'serial_hex': f'{cert.serial_number:X}',
        'sha256_fingerprint': cert.fingerprint(hashes.SHA256()).hex(':').upper(),
        'sig_algorithm': sig_algorithm,
        'subject_cn': _cn(cert.subject),
    }


def _format_days_left(days_left):
    """Render the days-remaining counter as a human-friendly phrase. Negative values mean
    the certificate is already past its `notAfter` date, so we say "expired N days ago"
    rather than the unintuitive "-N days left".
    """
    if days_left < 0:
        return f'expired {-days_left} days ago'
    if days_left == 0:
        return 'expires today'
    return f'{days_left}d left'


def _evaluate_item(item, args, now, severity_state):
    """Evaluate one collected item and return (cert_fields, days_left, item_state).

    The expiry state comes from the --warning/--critical ranges applied to the whole days
    remaining. Any certificate past `notAfter` is forced to CRIT. An `unverified...` chain
    verdict (URL source only) raises the state to at least `severity_state`. The same
    logic serves URL and file sources.
    """
    fields = parse_cert(item['cert_der'])
    remaining = (fields['not_after'] - now).days

    state = lib.base.get_state(remaining, args.WARN, args.CRIT, _operator='range')
    if remaining < 0:
        # expired certificates are CRIT no matter what the thresholds say
        state = STATE_CRIT

    verdict = item['chain_verdict'] or ''
    if verdict.startswith('unverified'):
        state = lib.base.get_worst(state, severity_state)

    return fields, remaining, state


def _build_lengthy_table(item, cert):
    """Return the --lengthy field/value table for one certificate as rendered text.

    TLS version and chain verdict rows are added only when the item carries them, which
    in practice means URL-sourced items.
    """
    key_label = cert['key_type']
    if cert['key_size']:
        key_label = f'{key_label} {cert["key_size"]}'
    if cert['rsa_exponent'] is not None:
        key_label = f'{key_label} (e={cert["rsa_exponent"]})'

    if cert['sans']:
        sans_label = ', '.join(cert['sans'])
    else:
        sans_label = '-'

    no_cn = '(no CN)'
    pairs = [
        ('Subject CN', cert['subject_cn'] or no_cn),
        ('Issuer CN', cert['issuer_cn'] or no_cn),
        ('Serial', cert['serial_hex']),
        ('Signature Algorithm', cert['sig_algorithm']),
        ('Public Key', key_label),
        ('SANs', sans_label),
        ('Not Before', cert['not_before'].isoformat() + 'Z'),
        ('Not After', cert['not_after'].isoformat() + 'Z'),
        ('SHA-256 Fingerprint', cert['sha256_fingerprint']),
        ('OCSP Must-Staple', 'yes' if cert['must_staple'] else 'no'),
    ]
    optional = (
        ('TLS Version', item['tls_version']),
        ('Chain', item['chain_verdict']),
    )
    pairs.extend((field, value) for field, value in optional if value is not None)

    rows = [{'k': field, 'v': value} for field, value in pairs]
    return lib.base.get_table(rows, ['k', 'v'], header=['Field', 'Value'])


def main():
    """Plugin entry point: collect certificates, evaluate them and report via lib.base.oao().

    One certificate: the first line reads "<subject CN>, <days>[, chain <verdict>]".
    Several certificates: it names the certificate with the fewest days left and appends
    a per-file summary table. --lengthy appends a field/value table per certificate.
    Perfdata: `cert_days_left` (minimum across items) plus `tls_handshake_time` when a
    handshake was timed.
    """
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    if args.SEVERITY == 'crit':
        severity_state = STATE_CRIT
    else:
        severity_state = STATE_WARN

    items = lib.base.coe(collect_items(args))
    now = lib.time.now(as_type='datetime')

    overall = STATE_OK
    days_per_item = []
    rows = []
    details = []
    last_handshake = None

    for item in items:
        cert, days_left, item_state = _evaluate_item(item, args, now, severity_state)
        overall = lib.base.get_worst(overall, item_state)
        days_per_item.append(days_left)
        if item['handshake_seconds'] is not None:
            last_handshake = item['handshake_seconds']
        rows.append({
            'label': item['label'],
            'subject_cn': cert['subject_cn'] or '(no CN)',
            'days_left': _format_days_left(days_left),
            'state': lib.base.state2str(item_state, empty_ok=False),
        })
        if args.LENGTHY:
            details.append((item['label'], _build_lengthy_table(item, cert)))

    single = len(items) == 1
    if single:
        row = rows[0]
        chain = items[0]['chain_verdict']
        chain_part = '' if chain is None else f', chain {chain}'
        # the state marker goes last so [WARNING] / [CRITICAL] sits next to the reason
        # (chain unverified, expired, ...) instead of always next to the days counter
        msg = f'{row["subject_cn"]}, {row["days_left"]}{chain_part}'
        msg += lib.base.state2str(overall, prefix=' ')
    else:
        # first index holding the smallest days-left value
        worst_idx = min(range(len(days_per_item)), key=days_per_item.__getitem__)
        worst_days = days_per_item[worst_idx]
        worst_cn = rows[worst_idx]['subject_cn']
        msg = (
            f'{len(items)} certificates checked'
            f', worst {_format_days_left(worst_days)} ({worst_cn})'
            f'{lib.base.state2str(overall, prefix=" ")}'
        )
        summary = lib.base.get_table(
            rows,
            ['label', 'subject_cn', 'days_left', 'state'],
            header=['File', 'Subject CN', 'Status', 'State'],
        )
        msg += '\n\n' + summary

    if args.LENGTHY:
        # `lib.base.get_table` output already ends with a newline, so one leading `\n`
        # yields exactly one blank line between consecutive blocks
        for label, block in details:
            msg += f'\n\n{block}' if single else f'\n{label}\n{block}'

    perfdata = lib.base.get_perfdata(
        'cert_days_left',
        min(days_per_item),
        uom='d',
        warn=args.WARN,
        crit=args.CRIT,
    )
    if last_handshake is not None:
        perfdata += lib.base.get_perfdata(
            'tls_handshake_time',
            round(last_handshake, 4),
            uom='s',
            _min=0,
        )

    lib.base.oao(msg, overall, perfdata, always_ok=args.ALWAYS_OK)


# Script entry point: run the check when executed directly (not on import).
if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Any unexpected exception is handed to lib.base.cu(). NOTE(review): presumably
        # it prints the traceback and exits UNKNOWN; confirm in lib/base.py.
        lib.base.cu()
