generated from felixm/defaultpy
Implement category prediction and clean up fzf
This commit is contained in:
133
src/fzf.py
133
src/fzf.py
@@ -1,126 +1,31 @@
|
|||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os.path
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from pkg_resources import resource_exists, resource_filename
|
|
||||||
|
|
||||||
__all__ = 'BUNDLED_EXECUTABLE', 'iterfzf'
|
|
||||||
|
|
||||||
EXECUTABLE_NAME = 'fzf.exe' if sys.platform == 'win32' else 'fzf'
|
EXECUTABLE_NAME = 'fzf.exe' if sys.platform == 'win32' else 'fzf'
|
||||||
BUNDLED_EXECUTABLE = (
|
|
||||||
resource_filename(__name__, EXECUTABLE_NAME)
|
|
||||||
if resource_exists(__name__, EXECUTABLE_NAME)
|
|
||||||
else (
|
|
||||||
os.path.join(os.path.dirname(__file__), EXECUTABLE_NAME)
|
|
||||||
if os.path.isfile(
|
|
||||||
os.path.join(os.path.dirname(__file__), EXECUTABLE_NAME)
|
|
||||||
)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def iterfzf(
|
def iterfzf(iterable, prompt='> '):
|
||||||
# CHECK: When the signature changes, __init__.pyi file should also change.
|
cmd = [EXECUTABLE_NAME, '--prompt=' + prompt]
|
||||||
iterable,
|
encoding = sys.getdefaultencoding()
|
||||||
# Search mode:
|
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=None)
|
||||||
extended=True, exact=False, case_sensitive=None,
|
if proc.stdin is None:
|
||||||
# Interface:
|
return None
|
||||||
multi=False, mouse=True, print_query=False,
|
|
||||||
# Layout:
|
|
||||||
prompt='> ',
|
|
||||||
ansi=None,
|
|
||||||
preview=None,
|
|
||||||
# Misc:
|
|
||||||
query='', encoding=None, executable=BUNDLED_EXECUTABLE or EXECUTABLE_NAME
|
|
||||||
):
|
|
||||||
cmd = [executable, '--no-sort', '--prompt=' + prompt]
|
|
||||||
cmd = [executable, '--prompt=' + prompt]
|
|
||||||
if not extended:
|
|
||||||
cmd.append('--no-extended')
|
|
||||||
if case_sensitive is not None:
|
|
||||||
cmd.append('+i' if case_sensitive else '-i')
|
|
||||||
if exact:
|
|
||||||
cmd.append('--exact')
|
|
||||||
if multi:
|
|
||||||
cmd.append('--multi')
|
|
||||||
if not mouse:
|
|
||||||
cmd.append('--no-mouse')
|
|
||||||
if print_query:
|
|
||||||
cmd.append('--print-query')
|
|
||||||
if query:
|
|
||||||
cmd.append('--query=' + query)
|
|
||||||
if preview:
|
|
||||||
cmd.append('--preview=' + preview)
|
|
||||||
if ansi:
|
|
||||||
cmd.append('--ansi')
|
|
||||||
encoding = encoding or sys.getdefaultencoding()
|
|
||||||
proc = None
|
|
||||||
stdin = None
|
|
||||||
byte = None
|
|
||||||
lf = u'\n'
|
|
||||||
cr = u'\r'
|
|
||||||
for line in iterable:
|
|
||||||
if byte is None:
|
|
||||||
byte = isinstance(line, bytes)
|
|
||||||
if byte:
|
|
||||||
lf = b'\n'
|
|
||||||
cr = b'\r'
|
|
||||||
elif isinstance(line, bytes) is not byte:
|
|
||||||
raise ValueError(
|
|
||||||
'element values must be all byte strings or all '
|
|
||||||
'unicode strings, not mixed of them: ' + repr(line)
|
|
||||||
)
|
|
||||||
if lf in line or cr in line:
|
|
||||||
raise ValueError(r"element values must not contain CR({1!r})/"
|
|
||||||
r"LF({2!r}): {0!r}".format(line, cr, lf))
|
|
||||||
if proc is None:
|
|
||||||
proc = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=None
|
|
||||||
)
|
|
||||||
stdin = proc.stdin
|
|
||||||
if not byte:
|
|
||||||
line = line.encode(encoding)
|
|
||||||
try:
|
|
||||||
stdin.write(line + b'\n')
|
|
||||||
stdin.flush()
|
|
||||||
except IOError as e:
|
|
||||||
if e.errno != errno.EPIPE and errno.EPIPE != 32:
|
|
||||||
raise
|
|
||||||
break
|
|
||||||
stdin.close()
|
|
||||||
if proc is None or proc.wait() not in [0, 1]:
|
|
||||||
if print_query:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
try:
|
try:
|
||||||
stdin.close()
|
lines = "\n".join(iterable)
|
||||||
|
proc.stdin.write(lines.encode('utf-8'))
|
||||||
|
proc.stdin.close()
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
if e.errno != errno.EPIPE and errno.EPIPE != 32:
|
if e.errno != errno.EPIPE and errno.EPIPE != 32:
|
||||||
raise
|
raise
|
||||||
stdout = proc.stdout
|
if proc is None or proc.wait() not in [0, 1]:
|
||||||
decode = (lambda b: b) if byte else (lambda t: t.decode(encoding))
|
return None
|
||||||
output = [decode(ln.strip(b'\r\n\0')) for ln in iter(stdout.readline, b'')]
|
if proc.stdout is None:
|
||||||
if print_query:
|
return None
|
||||||
try:
|
decode = lambda t: t.decode(encoding)
|
||||||
if multi:
|
output = [decode(ln.strip(b'\r\n\0')) for ln in iter(proc.stdout.readline, b'')]
|
||||||
return output[0], output[1:]
|
try:
|
||||||
else:
|
return output[0]
|
||||||
return output[0], output[1]
|
except IndexError:
|
||||||
except IndexError:
|
return None
|
||||||
return output[0], None
|
|
||||||
else:
|
|
||||||
if multi:
|
|
||||||
return output
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return output[0]
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -5,14 +5,36 @@ from typing import List
|
|||||||
|
|
||||||
def get_categories(transactions: List[Transaction]) -> List[str]:
|
def get_categories(transactions: List[Transaction]) -> List[str]:
|
||||||
categories = set([t.account2 for t in transactions])
|
categories = set([t.account2 for t in transactions])
|
||||||
categories.discard(UNKNOWN_CATEGORY)
|
categories.add(UNKNOWN_CATEGORY)
|
||||||
return list(categories)
|
return list(categories)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sort_categories():
|
||||||
|
def sort_categories(row: str, categories: List[str]):
|
||||||
|
if learn is None:
|
||||||
|
return
|
||||||
|
_, _, probs = learn.predict(row)
|
||||||
|
cat_to_prob = dict(zip(learn.dls.vocab[1],probs.tolist()))
|
||||||
|
categories.sort(key=lambda c: cat_to_prob[c] if c in cat_to_prob else 0.0)
|
||||||
|
|
||||||
|
learn = None
|
||||||
|
try:
|
||||||
|
from fastai.text.all import load_learner
|
||||||
|
learn = load_learner("ldg.pkl")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return sort_categories
|
||||||
|
|
||||||
|
|
||||||
def add_account2(transactions: List[Transaction]):
|
def add_account2(transactions: List[Transaction]):
|
||||||
categories = get_categories(transactions)
|
categories = get_categories(transactions)
|
||||||
unmapped_transactions = filter(lambda t: t.account2 == UNKNOWN_CATEGORY, transactions)
|
unmapped_transactions = filter(lambda t: t.account2 == UNKNOWN_CATEGORY, transactions)
|
||||||
|
sort_categories = get_sort_categories()
|
||||||
for t in unmapped_transactions:
|
for t in unmapped_transactions:
|
||||||
|
sort_categories(t.row, categories)
|
||||||
add_account2_interactive(t, categories)
|
add_account2_interactive(t, categories)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user