From c0b6e64d7f00ff67e49b366c7bf534b1e3d9d74e Mon Sep 17 00:00:00 2001 From: Felix Martin Date: Sat, 15 Mar 2025 19:46:15 -0400 Subject: [PATCH] Start to implement prediction --- src/toldg/__main__.py | 13 +++++++++---- src/toldg/process.py | 3 ++- src/toldg/train.py | 9 +++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) create mode 100644 src/toldg/train.py diff --git a/src/toldg/__main__.py b/src/toldg/__main__.py index 21c6801..4cf4f20 100644 --- a/src/toldg/__main__.py +++ b/src/toldg/__main__.py @@ -1,8 +1,10 @@ import logging +import sys from rich.logging import RichHandler from toldg.process import process_csv_files, process_ldg_files +from toldg.train import train from toldg.utils import load_config, remove_if_exists, write_meta @@ -18,10 +20,13 @@ def init_logging(): def main(): init_logging() config = load_config() - remove_if_exists(config.output_file) - write_meta(config) - process_ldg_files(config) - process_csv_files(config) + if len(sys.argv) > 2 and sys.argv[2] == "train": + train(config) + else: + remove_if_exists(config.output_file) + write_meta(config) + process_ldg_files(config) + process_csv_files(config) if __name__ == "__main__": diff --git a/src/toldg/process.py b/src/toldg/process.py index 40b3fd0..8e8b3be 100644 --- a/src/toldg/process.py +++ b/src/toldg/process.py @@ -82,7 +82,7 @@ def apply_mappings(transactions: List[Transaction], mappings: Dict[str, Mapping] assert mapping.count == 0, f"{mapping} was not used as often as expected!" -def process_csv_files(config: Config): +def process_csv_files(config: Config) -> List[Transaction]: csv_files = toldg.utils.get_csv_files(config.input_directory) transactions = [] for csv_file in csv_files: @@ -95,3 +95,4 @@ def process_csv_files(config: Config): toldg.predict.add_account2(transactions, config.categories) toldg.utils.write_mappings(transactions, config.mappings_file) toldg.write.render_to_file(transactions, config) + return transactions diff --git a/src/toldg/train.py b/src/toldg/train.py new file mode 100644 index 0000000..e829fde --- /dev/null +++ b/src/toldg/train.py @@ -0,0 +1,9 @@ +from toldg.models import Config, CsvConfig, Mapping, Transaction +from toldg.process import process_csv_files + + +def train(config: Config): + print("[train] start") + transactions = process_csv_files(config) + for t in transactions: + pass