diff --git a/src/fzf.py b/src/fzf.py index 120e2c7..86c2c11 100644 --- a/src/fzf.py +++ b/src/fzf.py @@ -1,126 +1,31 @@ -from __future__ import print_function - import errno -import os.path import subprocess import sys -from pkg_resources import resource_exists, resource_filename - -__all__ = 'BUNDLED_EXECUTABLE', 'iterfzf' EXECUTABLE_NAME = 'fzf.exe' if sys.platform == 'win32' else 'fzf' -BUNDLED_EXECUTABLE = ( - resource_filename(__name__, EXECUTABLE_NAME) - if resource_exists(__name__, EXECUTABLE_NAME) - else ( - os.path.join(os.path.dirname(__file__), EXECUTABLE_NAME) - if os.path.isfile( - os.path.join(os.path.dirname(__file__), EXECUTABLE_NAME) - ) - else None - ) -) -def iterfzf( - # CHECK: When the signature changes, __init__.pyi file should also change. - iterable, - # Search mode: - extended=True, exact=False, case_sensitive=None, - # Interface: - multi=False, mouse=True, print_query=False, - # Layout: - prompt='> ', - ansi=None, - preview=None, - # Misc: - query='', encoding=None, executable=BUNDLED_EXECUTABLE or EXECUTABLE_NAME -): - cmd = [executable, '--no-sort', '--prompt=' + prompt] - cmd = [executable, '--prompt=' + prompt] - if not extended: - cmd.append('--no-extended') - if case_sensitive is not None: - cmd.append('+i' if case_sensitive else '-i') - if exact: - cmd.append('--exact') - if multi: - cmd.append('--multi') - if not mouse: - cmd.append('--no-mouse') - if print_query: - cmd.append('--print-query') - if query: - cmd.append('--query=' + query) - if preview: - cmd.append('--preview=' + preview) - if ansi: - cmd.append('--ansi') - encoding = encoding or sys.getdefaultencoding() - proc = None - stdin = None - byte = None - lf = u'\n' - cr = u'\r' - for line in iterable: - if byte is None: - byte = isinstance(line, bytes) - if byte: - lf = b'\n' - cr = b'\r' - elif isinstance(line, bytes) is not byte: - raise ValueError( - 'element values must be all byte strings or all ' - 'unicode strings, not mixed of them: ' + repr(line) - ) - if lf in line or cr in line: - raise ValueError(r"element values must not contain CR({1!r})/" - r"LF({2!r}): {0!r}".format(line, cr, lf)) - if proc is None: - proc = subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=None - ) - stdin = proc.stdin - if not byte: - line = line.encode(encoding) - try: - stdin.write(line + b'\n') - stdin.flush() - except IOError as e: - if e.errno != errno.EPIPE and errno.EPIPE != 32: - raise - break - stdin.close() - if proc is None or proc.wait() not in [0, 1]: - if print_query: - return None, None - else: - return None +def iterfzf(iterable, prompt='> '): + cmd = [EXECUTABLE_NAME, '--prompt=' + prompt] + encoding = sys.getdefaultencoding() + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=None) + if proc.stdin is None: + return None try: - stdin.close() + lines = "\n".join(iterable) + proc.stdin.write(lines.encode('utf-8')) + proc.stdin.close() except IOError as e: if e.errno != errno.EPIPE and errno.EPIPE != 32: raise - stdout = proc.stdout - decode = (lambda b: b) if byte else (lambda t: t.decode(encoding)) - output = [decode(ln.strip(b'\r\n\0')) for ln in iter(stdout.readline, b'')] - if print_query: - try: - if multi: - return output[0], output[1:] - else: - return output[0], output[1] - except IndexError: - return output[0], None - else: - if multi: - return output - else: - try: - return output[0] - except IndexError: - return None + if proc is None or proc.wait() not in [0, 1]: + return None + if proc.stdout is None: + return None + decode = lambda t: t.decode(encoding) + output = [decode(ln.strip(b'\r\n\0')) for ln in iter(proc.stdout.readline, b'')] + try: + return output[0] + except IndexError: + return None diff --git a/src/predict.py b/src/predict.py index 69e5ad3..8482def 100644 --- a/src/predict.py +++ b/src/predict.py @@ -5,14 +5,36 @@ from typing import List def get_categories(transactions: List[Transaction]) -> List[str]: categories = set([t.account2 for t in transactions]) - categories.discard(UNKNOWN_CATEGORY) + categories.add(UNKNOWN_CATEGORY) return list(categories) +def get_sort_categories(): + def sort_categories(row: str, categories: List[str]): + if learn is None: + return + _, _, probs = learn.predict(row) + cat_to_prob = dict(zip(learn.dls.vocab[1],probs.tolist())) + categories.sort(key=lambda c: cat_to_prob[c] if c in cat_to_prob else 0.0) + + learn = None + try: + from fastai.text.all import load_learner + learn = load_learner("ldg.pkl") + except ModuleNotFoundError: + pass + except FileNotFoundError: + pass + + return sort_categories + + def add_account2(transactions: List[Transaction]): categories = get_categories(transactions) unmapped_transactions = filter(lambda t: t.account2 == UNKNOWN_CATEGORY, transactions) + sort_categories = get_sort_categories() for t in unmapped_transactions: + sort_categories(t.row, categories) add_account2_interactive(t, categories)