generated from felixm/defaultpy
Implement category prediction and clean up fzf
This commit is contained in:
@@ -5,14 +5,36 @@ from typing import List
|
||||
|
||||
def get_categories(transactions: List[Transaction]) -> List[str]:
|
||||
categories = set([t.account2 for t in transactions])
|
||||
categories.discard(UNKNOWN_CATEGORY)
|
||||
categories.add(UNKNOWN_CATEGORY)
|
||||
return list(categories)
|
||||
|
||||
|
||||
def get_sort_categories():
|
||||
def sort_categories(row: str, categories: List[str]):
|
||||
if learn is None:
|
||||
return
|
||||
_, _, probs = learn.predict(row)
|
||||
cat_to_prob = dict(zip(learn.dls.vocab[1],probs.tolist()))
|
||||
categories.sort(key=lambda c: cat_to_prob[c] if c in cat_to_prob else 0.0)
|
||||
|
||||
learn = None
|
||||
try:
|
||||
from fastai.text.all import load_learner
|
||||
learn = load_learner("ldg.pkl")
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return sort_categories
|
||||
|
||||
|
||||
def add_account2(transactions: List[Transaction]):
|
||||
categories = get_categories(transactions)
|
||||
unmapped_transactions = filter(lambda t: t.account2 == UNKNOWN_CATEGORY, transactions)
|
||||
sort_categories = get_sort_categories()
|
||||
for t in unmapped_transactions:
|
||||
sort_categories(t.row, categories)
|
||||
add_account2_interactive(t, categories)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user