"""Command-line entry point for AntiDrift.

Parses the command line and either runs the daemon (``--daemon``) or hands
the parsed arguments to the client, which talks to a running daemon.
"""

import argparse
import asyncio
import logging
import os
import shutil
import signal
import subprocess
import sys
from pathlib import Path

import psutil
import rich
from dbus.mainloop.glib import DBusGMainLoop
from rich.logging import RichHandler

import antidrift.client
from antidrift.config import Config
from antidrift.daemon import AntiDriftDaemon

# Location of the user configuration file; "~" is expanded at load time.
CONFIG_PATH = "~/.config/antidrift.yaml"

# Install the GLib main loop as the default D-Bus main loop at import time.
DBusGMainLoop(set_as_default=True)
# Restore the default SIGINT action instead of Python's KeyboardInterrupt.
signal.signal(signal.SIGINT, signal.SIG_DFL)


def _load_config() -> Config:
    """Load the configuration from ``CONFIG_PATH`` in the user's home."""
    return Config.load(os.path.expanduser(CONFIG_PATH))


def get_args() -> argparse.Namespace:
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(description="AntiDrift CLI.")
    parser.add_argument("--daemon", action="store_true", help="run daemon")
    parser.add_argument(
        "--status", action="store_true", help="get status from daemon"
    )
    parser.add_argument("--tailf", action="store_true", help="tail -f log file")
    parser.add_argument(
        "--start", metavar="whiteblock", nargs="+", help="start whiteblocks"
    )
    parser.add_argument("--stop", action="store_true", help="stop session")
    parser.add_argument("--pause", action="store_true", help="pause antidrift")
    parser.add_argument(
        "--unpause", action="store_true", help="unpause antidrift"
    )
    parser.add_argument(
        "--schedule", metavar="blackblock", help="schedule blackblock"
    )
    return parser.parse_args()


def init_logging(log_file: Path, dev_mode: bool = False) -> None:
    """Configure the root logger.

    Args:
        log_file: File that receives the log output when ``dev_mode`` is
            false. Unused in dev mode.
        dev_mode: When true, log through a ``RichHandler`` instead of
            writing to ``log_file``.

    In both modes the level is DEBUG and a filter that drops consecutive
    duplicate records is attached to the root logger.
    """

    class DuplicateFilter(logging.Filter):
        """Reject a record equal to the one this filter last accepted."""

        def filter(self, record) -> bool:
            current_log = (record.module, record.levelno, record.msg)
            # `last_log` does not exist until the first record is seen.
            if current_log != getattr(self, "last_log", None):
                self.last_log = current_log
                return True
            return False

    if dev_mode:
        format_str = "%(message)s"  # RichHandler will handle the formatting
        logging.basicConfig(
            level=logging.DEBUG,
            format=format_str,
            datefmt="%a %H:%M:%S",
            handlers=[RichHandler(rich_tracebacks=True, markup=True)],
        )
    else:
        # The markup tags are written to the file verbatim; presumably they
        # are rendered by whatever displays the log -- TODO confirm.
        format_str = (
            "[bold pale_green3]%(asctime)s[/bold pale_green3] | "
            "[light_steel_blue]%(levelname)-8s[/light_steel_blue] | "
            "%(message)s"
        )
        logging.basicConfig(
            filename=log_file,
            format=format_str,
            datefmt="%a %H:%M:%S",
            encoding="utf-8",
            level=logging.DEBUG,
        )

    # NOTE(review): a filter attached to a logger is only consulted for
    # records logged directly on that logger, so records coming from
    # descendant loggers are not de-duplicated here -- confirm this is
    # the intended scope.
    logger = logging.getLogger()
    logger.addFilter(DuplicateFilter())


def check_for_xdotool() -> None:
    """Check if xdotool is in path and exit if not"""
    if not shutil.which("xdotool"):
        logging.critical("Please install xdotool")
        sys.exit(1)


def kill_existing_antidrift() -> None:
    """Kill every other process whose name or exe is /usr/bin/antidrift.

    The current process is skipped. Processes that cannot be killed
    (access denied) or that have already exited are ignored.
    """
    current_pid = os.getpid()
    target = "/usr/bin/antidrift"
    for proc in psutil.process_iter(["pid", "name", "exe"]):
        # NOTE(review): psutil's `name` is normally a bare process name
        # rather than a path, so the first comparison may never match --
        # confirm.
        if proc.info["name"] != target and proc.info["exe"] != target:
            continue
        if proc.info["pid"] == current_pid:
            continue  # don't kill ourselves
        try:
            proc.kill()
        except (psutil.AccessDenied, psutil.NoSuchProcess):
            # Best effort: nothing useful can be done about either case.
            pass


def main_daemon() -> None:
    """Start the daemon in the mode that matches how we were invoked.

    * Effective UID 0: fork; the child loads the config, logs to the
      daemon log file and runs the daemon, while the parent returns.
    * Not root and ``sys.argv[0] == "antidrift"``: kill existing antidrift
      processes and spawn ``sudo antidrift --daemon`` without waiting.
    * Otherwise: run the daemon in this process with dev-mode logging and
      ``debug=True``.
    """
    if os.geteuid() == 0:
        newpid = os.fork()
        if newpid == 0:
            config = _load_config()
            init_logging(config.daemon_log_file)
            daemon = AntiDriftDaemon(config)
            asyncio.run(daemon.run())
        return

    if sys.argv[0] == "antidrift":
        kill_existing_antidrift()
        cmd = ["sudo", "antidrift", "--daemon"]
        subprocess.Popen(cmd)
        return

    config = _load_config()
    init_logging(config.daemon_log_file, dev_mode=True)
    daemon = AntiDriftDaemon(config)
    asyncio.run(daemon.run(debug=True))


def main() -> None:
    """Main routine that dispatches to client or daemon"""
    check_for_xdotool()
    args = get_args()
    if args.daemon:
        main_daemon()
    else:
        config = _load_config()
        asyncio.run(antidrift.client.run(args, config))


if __name__ == "__main__":
    main()