Reformat code and add option to add keywords

This commit is contained in:
2023-06-07 17:52:34 -04:00
parent f058c2235d
commit d97a4885cb
8 changed files with 103 additions and 227 deletions

View File

@@ -6,6 +6,7 @@ class AuthExternal(Authenticator):
:class:`MessageBus <dbus_next.message_bus.BaseMessageBus>`.
:sealso: https://dbus.freedesktop.org/doc/dbus-specification.html#auth-protocol
"""
def __init__(self, user_uid):
self.user_uid = user_uid
self.negotiate_unix_fd = False
@@ -14,7 +15,7 @@ class AuthExternal(Authenticator):
def _authentication_start(self, negotiate_unix_fd=False) -> str:
self.negotiate_unix_fd = negotiate_unix_fd
hex_uid = str(self.user_uid).encode().hex()
return f'AUTH EXTERNAL {hex_uid}'
return f"AUTH EXTERNAL {hex_uid}"
def _receive_line(self, line: str):
response, args = _AuthResponse.parse(line)
@@ -29,5 +30,4 @@ class AuthExternal(Authenticator):
if response is _AuthResponse.AGREE_UNIX_FD:
return "BEGIN"
raise AuthError(f'authentication failed: {response.value}: {args}')
raise AuthError(f"authentication failed: {response.value}: {args}")

View File

@@ -1,4 +1,4 @@
import asyncio
import time
from dbus_next.aio import MessageBus
from dbus_next import BusType
from dbus_next.errors import DBusError
@@ -20,7 +20,6 @@ async def get_dbus_interface():
async def run(args: Namespace, config: Config):
if args.evaluate:
evaluate(config)
return
@@ -35,6 +34,8 @@ async def run(args: Namespace, config: Config):
reply = await interface.call_stop()
elif args.pause:
reply = await interface.call_pause()
elif args.block is not None:
reply = await interface.call_block(args.block)
elif args.intention is not None:
reply = await interface.call_intention(args.intention)
elif args.unpause:

View File

@@ -36,18 +36,33 @@ class Config(BaseModel):
config = cls(**config_dict)
config.config_file = Path(config_file)
# Expand the paths for the log files
config.window_log_file = Path(os.path.expanduser(config.window_log_file))
config.daemon_log_file = Path(os.path.expanduser(config.daemon_log_file))
config.client_log_file = Path(os.path.expanduser(config.client_log_file))
config.window_log_file = Path(
os.path.expanduser(config.window_log_file))
config.daemon_log_file = Path(
os.path.expanduser(config.daemon_log_file))
config.client_log_file = Path(
os.path.expanduser(config.client_log_file))
return config
def save(self) -> None:
config_file = self.config_file
config_dict = self.dict()
# convert Path objects to strings
for key, value in config_dict.items():
if isinstance(value, Path):
config_dict[key] = str(value)
with open(config_file, "w") as f:
yaml.safe_dump(config_dict, f)
class State(BaseModel):
active_blackblocks: List[Block] = []
active_whiteblocks: List[Block] = []
inactive_blackblocks: List[Block] = []
pause: bool = False
intention: str = ''
intention: str = ""
class Config:
extra = "forbid"

View File

@@ -4,23 +4,21 @@ import logging
import os
import pwd
import re
import sys
import time
import asyncio
import antidrift.xwindow as xwindow
from antidrift.xwindow import XWindow
from antidrift.config import Config, State, Block
from antidrift.auth import AuthExternal
from typing import List, Optional
from dbus_next.aio import MessageBus
from dbus_next.service import ServiceInterface, method
from dbus_next import Variant, BusType
from dbus_next import BusType
BUS_NAME = "com.antidrift"
IFACE = "com.antidrift"
OPATH = "/com/antidrift"
s = "no pyflakes warning"
class AntiDriftDaemon(ServiceInterface):
@@ -29,7 +27,8 @@ class AntiDriftDaemon(ServiceInterface):
self.config = config
self.reset_block_state()
self.enforce_count = 0
self.enforce_value = int(config.enforce_delay_ms / config.polling_cycle_ms)
self.enforce_value = int(
config.enforce_delay_ms / config.polling_cycle_ms)
async def init_bus(self):
"""
@@ -39,13 +38,15 @@ class AntiDriftDaemon(ServiceInterface):
the original effective UID to maintain the appropriate privilege
levels.
"""
user_name = os.environ.get("SUDO_USER", pwd.getpwuid(os.getuid()).pw_name)
user_name = os.environ.get(
"SUDO_USER", pwd.getpwuid(os.getuid()).pw_name)
user_uid = pwd.getpwnam(user_name).pw_uid
euid = os.geteuid()
os.seteuid(user_uid)
auth = AuthExternal(user_uid)
bus_address = f"unix:path=/run/user/{user_uid}/bus"
bus = MessageBus(bus_address=bus_address, bus_type=BusType.SESSION, auth=auth)
bus = MessageBus(bus_address=bus_address,
bus_type=BusType.SESSION, auth=auth)
await bus.connect()
bus.export(OPATH, self)
await bus.request_name(BUS_NAME)
@@ -53,7 +54,7 @@ class AntiDriftDaemon(ServiceInterface):
return bus
async def run(self, debug: bool = False):
bus = await self.init_bus()
_ = await self.init_bus()
async def _enforce():
while True:
@@ -73,7 +74,7 @@ class AntiDriftDaemon(ServiceInterface):
asyncio.create_task(_enforce())
asyncio.create_task(_log())
xwindow.notify(f"AntiDrift running.")
xwindow.notify("Antidrift running.")
stop = asyncio.Event()
await stop.wait()
@@ -85,13 +86,14 @@ class AntiDriftDaemon(ServiceInterface):
)
@method()
def start(self, whiteblocks: 'as') -> 's':
def start(self, whiteblocks: "as") -> "s":
self.reset_block_state()
all_whiteblocks = {wb.name: wb for wb in self.config.whiteblocks}
success_wbs, fail_blocks = [], []
for block_name in whiteblocks:
if block_name in all_whiteblocks:
self.state.active_whiteblocks.append(all_whiteblocks[block_name])
self.state.active_whiteblocks.append(
all_whiteblocks[block_name])
success_wbs.append(block_name)
else:
fail_blocks.append(block_name)
@@ -107,7 +109,7 @@ class AntiDriftDaemon(ServiceInterface):
return r
@method()
def schedule(self, blackblock_name: 's') -> 's':
def schedule(self, blackblock_name: "s") -> "s":
"""Schedule blackblock based if it has a non-zero timeout value."""
all_blackblocks = {bb.name: bb for bb in self.config.blackblocks}
if blackblock_name not in all_blackblocks:
@@ -132,35 +134,47 @@ class AntiDriftDaemon(ServiceInterface):
return m
@method()
def stop(self) -> 's':
def stop(self) -> "s":
self.reset_block_state()
m = "Blacklist only mode."
logging.info(m)
return m
@method()
def pause(self) -> 's':
def pause(self) -> "s":
self.state.pause = True
m = "Antidrift paused."
logging.info(m)
return m
@method()
def unpause(self) -> 's':
def unpause(self) -> "s":
self.state.pause = False
m = "Antidrift unpaused."
logging.info(m)
return m
@method()
def intention(self, intention: 's') -> 's':
def intention(self, intention: "s") -> "s":
self.state.intention = intention
m = f"Antidrift intention set to '{intention}'"
m = f"Antidrift intention set to '{intention}'."
logging.info(m)
return m
@method()
def status(self) -> 's':
def block(self, block: "s") -> "s":
# self.state.intention = intention
if self.state.active_blackblocks and \
block not in self.state.active_blackblocks[0].keywords:
self.state.active_blackblocks[0].keywords.append(block)
self.config.save()
m = f"✅Antidrift add block '{block}'."
logging.info(m)
return m
return f"⚠️ '{block}' not added."
@method()
def status(self) -> "s":
white_active = bool(self.state.active_whiteblocks)
black_active = bool(self.state.active_blackblocks)
m = "🟢 ad "
@@ -172,7 +186,8 @@ class AntiDriftDaemon(ServiceInterface):
match (white_active, black_active):
case (True, _):
m += "wb: "
m += " ".join(map(lambda b: b.name, self.state.active_whiteblocks))
m += " ".join(map(lambda b: b.name,
self.state.active_whiteblocks))
if inactive_bbs:
m += " "
m += inactive_bbs
@@ -197,24 +212,28 @@ class AntiDriftDaemon(ServiceInterface):
def log_window(self):
window = XWindow()
utc_timestamp = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
utc_timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
# Remove time strings and numbers with decimal from window.name
window_name = re.sub(r'\b\d\d:\d\d\b', '', window.name)
window_name = re.sub(r'\b\d+.\d\d\b', '', window_name)
window_name = re.sub(r'[+-]\d+.\d+%', '', window_name)
window_name = re.sub(r"\b\d\d:\d\d\b", "", window.name)
window_name = re.sub(r"\b\d+.\d\d\b", "", window_name)
window_name = re.sub(r"[+-]\d+.\d+%", "", window_name)
intention = self.state.intention
log_line = [utc_timestamp, window_name, window.cls, intention]
with self.config.window_log_file.open('a', newline='') as f:
with self.config.window_log_file.open("a", newline="") as f:
writer = csv.writer(f)
writer.writerow(log_line)
async def enforce_pause(self):
xwindow.notify("Goint to minimize window because of pause...")
if XWindow().is_active() is False:
return
xwindow.notify("⚠️ No active windows during pause!")
for _ in range(8):
await asyncio.sleep(1)
if not XWindow().is_active() or self.state.pause is False:
window = XWindow()
if not window.is_active() or self.state.pause is False:
xwindow.notify("✅ Thanks.")
return
window = XWindow()
@@ -228,12 +247,12 @@ class AntiDriftDaemon(ServiceInterface):
delay = int(self.config.enforce_delay_ms / 1000)
for i in range(delay, 0, -1):
await asyncio.sleep(1)
xwindow.notify(f"AntiDrift will minimize in {i}s.")
xwindow.notify(f"⚠️ AntiDrift will minimize in {i}s.")
if not window_is_blocked(self.state, silent=True):
xwindow.notify("We are gucci again.")
xwindow.notify("We are gucci again.")
return
window = XWindow()
xwindow.notify(f"Minimize {window.name[:30]}.")
xwindow.notify(f"Minimize {window.name[:30]}.")
window.minimize()
@@ -266,8 +285,7 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
for k in b.keywords:
if keyword_matches_window(k, window) and b.kill:
window.kill()
xwindow.notify(f"Kill for {k} on {b.name}.")
logging.warning(f"Kill for [red]{k}[/red] on [red]{b.name}[/red].")
xwindow.notify(f"Kill for {k} on {b.name}.")
return True
elif keyword_matches_window(k, window):
if not silent:
@@ -279,7 +297,7 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
return True
if not whiteblocks:
if not silent:
logging.debug("All non-blackblock windows are allowed.")
logging.debug(" All non-blackblock windows are allowed.")
return False
for w in whiteblocks:
for k in w.keywords:

View File

@@ -39,11 +39,11 @@ def evaluate(config: Config):
log_file = config.window_log_file
datapoints: List[Datapoint] = []
with open(log_file, 'r') as file:
with open(log_file, "r") as file:
reader = csv.reader(file)
for row in reader:
timestamp_str, title, tool, intention = row
if title != '':
if title != "":
timestamp = datetime.fromisoformat(timestamp_str)
datapoint = Datapoint(timestamp, title, tool, intention)
datapoints.append(datapoint)
@@ -53,7 +53,7 @@ def evaluate(config: Config):
prev_datapoint = None
prev_evaluation = None
for d in datapoints:
if d.title == '':
if d.title == "":
continue
# Get evaluation of current datapoint
@@ -79,7 +79,7 @@ def evaluate(config: Config):
def parse_result(result: str) -> Optional[Evaluation]:
try:
content = json.loads(result.strip())
return Evaluation(content['level'], content['reason'])
return Evaluation(content["level"], content["reason"])
except (ValueError, KeyError):
return None
@@ -115,7 +115,7 @@ def evaluate_datapoint(title, tool, intention) -> Optional[str]:
if r.status_code == 200:
response = r.json()
message_response = response["choices"][0]["message"]
return message_response['content']
return message_response["content"]
else:
xwindow.notify(f"Antidrift - GPT - Response error status code {r.status_code}")
return None

View File

@@ -1,165 +0,0 @@
def is_window_blocked(window_name: str, blocked: List[re.Pattern]) -> bool:
for b in blocked:
if b.findall(window_name):
return True
return False
def kill_sequence(blocked: List[re.Pattern]) -> None:
def to_display(name: str) -> str:
return name if len(name) < 30 else name[:30] + "..."
for count in range(5, 0, -1):
window_name, window_pid = get_active_window_name_and_pid()
if not is_window_blocked(window_name, blocked):
notify(f"[okay] {to_display(window_name)}")
return
notify(f"[kill {count}s] {to_display(window_name)}")
time.sleep(1)
p = psutil.Process(int(window_pid))
p.terminate()
def notify(message: str) -> None:
"""Notify user via the Xorg notify-send command."""
env = {**os.environ, "DBUS_SESSION_BUS_ADDRESS": "unix:path=/run/user/1000/bus"}
user = env.get("SUDO_USER", None)
if user is None:
cmd = ["notify-send", message]
else:
cmd = ["runuser", "-m", "-u", user, "notify-send", message]
subprocess.run(cmd, env=env)
def write_pid_file(config: Config) -> str:
p = psutil.Process()
pid_file = os.path.join(config.directory, f"{p.pid}.pid")
with open(pid_file, "w") as f:
f.write(str(p.pid))
return pid_file
def terminate_existing_blocker(config: Config) -> None:
this_pid = psutil.Process().pid
pid_files = [f for f in config.directory if f.endswith(".pid")]
for pid_file in pid_files:
pid = int(pid_file.replace(".pid", ""))
if this_pid == pid:
continue
try:
p = psutil.Process(pid)
p.terminate()
except psutil.NoSuchProcess:
print(f"Could not terminate {p.pid=}.")
pid_file = os.path.join(config.directory, pid_file)
try:
os.remove(pid_file)
except PermissionError:
print(f"Could not remove {pid_file=}.")
def load_blocked_patterns(config: Config) -> List[re.Pattern]:
blocked = []
for block in config.blocks:
for item in block.items:
s = "".join([block.prefix, item, block.postfix])
r = re.compile(s, re.IGNORECASE)
blocked.append(r)
return blocked
def init(config: Config) -> None:
terminate_existing_blocker(config)
write_pid_file(config)
def run_checks(config: Config) -> None:
# shutil.which("xdotool")
# That's what we do if anything goes wrong:
# xkill --all
pass
def kill_window_if_blocked(blocked: List[re.Pattern]) -> None:
window_name, window_pid = get_active_window_name_and_pid()
if is_window_blocked(window_name, blocked):
kill_sequence(blocked)
def is_process_active(process_name: str) -> bool:
for proc in psutil.process_iter():
if proc.name() == process_name:
return True
return False
def enforce_aw_commit():
def aw_commit_active():
return is_process_active("aw-commit")
def to_display(name: str) -> str:
return name if len(name) < 30 else name[:30] + "..."
if aw_commit_active():
return
for _ in range(10, 0, -1):
notify(f"[warning] aw-commit not running")
time.sleep(1)
if aw_commit_active():
return
if aw_commit_active():
return
window_name, window_pid = get_active_window_name_and_pid()
if window_name:
notify(f"[kill aw-commit not running] {to_display(window_name)}")
p = psutil.Process(int(window_pid))
p.terminate()
return enforce_aw_commit()
def load_window_names(config: Config) -> Set[str]:
window_names_file = os.path.join(config.directory, config.window_names)
if not os.path.isfile(window_names_file):
return set()
with open(window_names_file, "r") as f:
return {l.strip() for l in f.readlines()}
def write_window_names(config: Config, window_names: Set[str]) -> None:
window_names_file = os.path.join(config.directory, config.window_names)
window_names = "\n".join(sorted(list(window_names)))
with open(window_names_file, "w") as f:
f.write(window_names)
def main() -> None:
"""Run main_root as root except while debugging."""
config_path = "~/.config/aw-focus/config.json"
config = Config.load_config(config_path)
if config.start_as_user:
terminate_existing_blocker(config)
main_root(config)
if os.geteuid() == 0:
newpid = os.fork()
if newpid == 0:
main_root(config)
else:
cmd = ["sudo", config.aw_focus_cmd] + sys.argv[1:]
subprocess.Popen(cmd)
def main_root(config: Config) -> None:
init(config)
blocked = load_blocked_patterns(config)
while True:
time.sleep(config.sleep_time)
run_checks(config)
kill_window_if_blocked(blocked)
if config.enforce_aw_commit:
enforce_aw_commit()

View File

@@ -19,7 +19,9 @@ class XWindow:
self.keywords = list(re.findall(r"\w+", self.name.lower()))
def __repr__(self):
return f"<XWindow '{self.name[:20]}' '{self.cls[:20]}'>"
return (
f"<XWindow '{self.name[:20]}' '{self.cls[:20]}' active: {self.is_active()}>"
)
def _run(self, cmd) -> str:
cmd = ["xdotool"] + cmd
@@ -37,8 +39,10 @@ class XWindow:
def kill(self):
self._run(["windowkill", self.window])
def is_active(self):
return True if self.name else False
def is_active(self) -> bool:
current_desktop = self._run(["get_desktop"])
window_desktop = self._run(["get_desktop_for_window", self.window])
return True if self.name and current_desktop == window_desktop else False
def notify(message: str) -> None:

29
main.py
View File

@@ -2,11 +2,9 @@ import logging
import shutil
import sys
import os
import signal
import subprocess
import argparse
import psutil
import rich
import asyncio
from rich.logging import RichHandler
from pathlib import Path
@@ -15,20 +13,22 @@ import antidrift.client
from antidrift.daemon import AntiDriftDaemon
from antidrift.config import Config
from dbus.mainloop.glib import DBusGMainLoop
DBusGMainLoop(set_as_default=True)
signal.signal(signal.SIGINT, signal.SIG_DFL)
def get_args():
parser = argparse.ArgumentParser(description="AntiDrift CLI.")
parser.add_argument("--daemon", action="store_true", help="run daemon")
parser.add_argument("--evaluate", action="store_true", help="evaluate day")
parser.add_argument("--intention", metavar="intention", help="set intention", default=None, type=str)
parser.add_argument(
"--intention", metavar="intention", help="set intention", default=None, type=str
)
parser.add_argument(
"--block", metavar="block", help="add to block", default=None, type=str
)
parser.add_argument("--pause", action="store_true", help="pause antidrift")
parser.add_argument("--schedule", metavar="blackblock", help="schedule blackblock")
parser.add_argument("--start", metavar="whiteblock", nargs="+", help="start whiteblocks")
parser.add_argument(
"--start", metavar="whiteblock", nargs="+", help="start whiteblocks"
)
parser.add_argument("--status", action="store_true", help="get status from daemon")
parser.add_argument("--stop", action="store_true", help="stop session")
parser.add_argument("--tailf", action="store_true", help="tail -f log file")
@@ -82,10 +82,13 @@ def check_for_xdotool():
def kill_existing_antidrift():
current_pid = os.getpid()
for proc in psutil.process_iter(['pid', 'name', 'exe']):
if proc.info['name'] == '/usr/bin/antidrift' or proc.info['exe'] == '/usr/bin/antidrift':
if proc.info['pid'] == current_pid:
continue # don't kill ourselves
for proc in psutil.process_iter(["pid", "name", "exe"]):
if (
proc.info["name"] == "/usr/bin/antidrift"
or proc.info["exe"] == "/usr/bin/antidrift"
):
if proc.info["pid"] == current_pid:
continue # don't this process
try:
proc.kill()
except psutil.AccessDenied: