antidrift/main.py

133 lines
4.3 KiB
Python

import logging
import shutil
import sys
import os
import subprocess
import argparse
import psutil
import asyncio
from rich.logging import RichHandler
from pathlib import Path
import antidrift.client
from antidrift.daemon import AntiDriftDaemon
from antidrift.config import Config
def get_args(argv=None):
    """Parse the AntiDrift command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of parsing the real
            command line (``sys.argv[1:]``); an explicit list lets callers
            and tests drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Flags default to ``False``; value options default to ``None``.
    """
    parser = argparse.ArgumentParser(description="AntiDrift CLI.")
    parser.add_argument("--daemon", action="store_true", help="run daemon")
    parser.add_argument("--evaluate", action="store_true", help="evaluate day")
    parser.add_argument(
        "--intention", metavar="intention", help="set intention", default=None, type=str
    )
    parser.add_argument(
        "--block", metavar="block", help="add to block", default=None, type=str
    )
    parser.add_argument("--pause", action="store_true", help="pause antidrift")
    parser.add_argument("--schedule", metavar="blackblock", help="schedule blackblock")
    parser.add_argument(
        "--start", metavar="whiteblock", nargs="+", help="start whiteblocks"
    )
    parser.add_argument("--status", action="store_true", help="get status from daemon")
    parser.add_argument("--stop", action="store_true", help="stop session")
    parser.add_argument("--tailf", action="store_true", help="tail -f log file")
    parser.add_argument("--unpause", action="store_true", help="unpause antidrift")
    return parser.parse_args(argv)
class _DuplicateFilter(logging.Filter):
    """Reject a record that repeats the immediately preceding one.

    Two records are duplicates when module, level and the *formatted*
    message (``msg % args``) all match, so the same template logged with
    different arguments is not suppressed.
    """

    def __init__(self, name: str = "") -> None:
        super().__init__(name)
        self._last_log = None

    def filter(self, record: logging.LogRecord) -> bool:
        try:
            message = record.getMessage()
        except Exception:
            # Malformed msg/args: don't raise into the logging call site;
            # use a raw key and let the handler report the formatting error.
            message = repr((record.msg, record.args))
        current_log = (record.module, record.levelno, message)
        if current_log == self._last_log:
            return False
        self._last_log = current_log
        return True


def init_logging(log_file: Path, dev_mode: bool = False):
    """Configure root logging for AntiDrift at DEBUG level.

    Args:
        log_file: File records are written to when ``dev_mode`` is false.
        dev_mode: When true, log to the console through ``RichHandler``
            (rich tracebacks, markup enabled) instead of to ``log_file``.

    ``logging.basicConfig`` does nothing if the root logger already has
    handlers; in that case only the duplicate filter is attached to the
    existing handlers.
    """
    if dev_mode:
        format_str = "%(message)s"  # RichHandler will handle the formatting
        logging.basicConfig(
            level=logging.DEBUG,
            format=format_str,
            datefmt="%a %H:%M:%S",
            handlers=[RichHandler(rich_tracebacks=True, markup=True)],
        )
    else:
        # The rich markup tags are written verbatim into the file; presumably
        # whatever displays the log renders them — TODO confirm.
        format_str = (
            "[bold pale_green3]%(asctime)s[/bold pale_green3] | "
            "[light_steel_blue]%(levelname)-8s[/light_steel_blue] | "
            "%(message)s"
        )
        logging.basicConfig(
            filename=log_file,
            format=format_str,
            datefmt="%a %H:%M:%S",
            encoding="utf-8",
            level=logging.DEBUG,
        )
    # A filter on the root *logger* only sees records logged directly on it,
    # not records propagated from child loggers, so filter at the handlers.
    # Each handler needs its own instance: a shared one would make the second
    # handler reject every record the first handler just accepted.
    for handler in logging.getLogger().handlers:
        if not any(isinstance(f, _DuplicateFilter) for f in handler.filters):
            handler.addFilter(_DuplicateFilter())
def check_for_xdotool():
    """Return normally if ``xdotool`` is on PATH; otherwise log and exit(1)."""
    if shutil.which("xdotool"):
        return
    logging.critical("Please install xdotool")
    sys.exit(1)
def kill_existing_antidrift():
    """Kill every other process whose name or exe equals /usr/bin/antidrift.

    The calling process is skipped. Processes that disappear during the scan
    or that this user may not signal are ignored (best effort).
    """
    # NOTE(review): psutil documents name() as a bare process name, so the
    # comparison against a full path likely never matches and the exe check
    # does the real work — confirm before relying on it.
    target = "/usr/bin/antidrift"
    own_pid = os.getpid()
    for candidate in psutil.process_iter(["pid", "name", "exe"]):
        info = candidate.info
        if target not in (info["name"], info["exe"]):
            continue
        if info["pid"] == own_pid:
            continue  # never kill the current process
        try:
            candidate.kill()
        except (psutil.AccessDenied, psutil.NoSuchProcess):
            pass  # not ours to kill, or already gone
def main_daemon():
    """Start the AntiDrift daemon in one of three ways.

    - Effective UID 0: fork. The child loads the config, logs to
      ``config.daemon_log_file`` and runs the daemon; the parent returns.
    - Non-root and ``sys.argv[0]`` is exactly ``"antidrift"``: kill other
      antidrift processes, then launch ``sudo antidrift --daemon`` without
      waiting for it.
    - Any other non-root invocation (``argv[0]`` is a path, e.g. running
      from a checkout): run the daemon in-process with ``debug=True`` and
      Rich console logging.
    """
    if os.geteuid() != 0:
        if sys.argv[0] != "antidrift":
            dev_config = Config.load(os.path.expanduser("~/.config/antidrift.yaml"))
            init_logging(dev_config.daemon_log_file, dev_mode=True)
            dev_daemon = AntiDriftDaemon(dev_config)
            asyncio.run(dev_daemon.run(debug=True))
            return
        kill_existing_antidrift()
        subprocess.Popen(["sudo", "antidrift", "--daemon"])
        return

    if os.fork() != 0:
        return  # parent process: the forked child runs the daemon

    root_config = Config.load(os.path.expanduser("~/.config/antidrift.yaml"))
    init_logging(root_config.daemon_log_file)
    root_daemon = AntiDriftDaemon(root_config)
    asyncio.run(root_daemon.run())
def main() -> None:
    """Dispatch to the daemon (``--daemon``) or to the client.

    ``check_for_xdotool`` runs first and exits the program if xdotool is
    missing. The client path loads ``~/.config/antidrift.yaml`` and hands
    the parsed arguments to ``antidrift.client.run``.
    """
    check_for_xdotool()
    cli_args = get_args()
    if not cli_args.daemon:
        client_config = Config.load(os.path.expanduser("~/.config/antidrift.yaml"))
        asyncio.run(antidrift.client.run(cli_args, client_config))
        return
    main_daemon()


if __name__ == "__main__":
    main()