Start to implement prediction

This commit is contained in:
2025-03-15 19:46:15 -04:00
parent 3e4a284692
commit c0b6e64d7f
3 changed files with 20 additions and 5 deletions

View File

@@ -1,8 +1,10 @@
import logging
import sys
from rich.logging import RichHandler
from toldg.process import process_csv_files, process_ldg_files
from toldg.train import train
from toldg.utils import load_config, remove_if_exists, write_meta
@@ -18,10 +20,13 @@ def init_logging():
def main():
init_logging()
config = load_config()
remove_if_exists(config.output_file)
write_meta(config)
process_ldg_files(config)
process_csv_files(config)
if len(sys.argv) > 2 and sys.argv[2] == "train":
train(config)
else:
remove_if_exists(config.output_file)
write_meta(config)
process_ldg_files(config)
process_csv_files(config)
if __name__ == "__main__":

View File

@@ -82,7 +82,7 @@ def apply_mappings(transactions: List[Transaction], mappings: Dict[str, Mapping]
assert mapping.count == 0, f"{mapping} was not used as often as expected!"
def process_csv_files(config: Config):
def process_csv_files(config: Config) -> List[Transaction]:
csv_files = toldg.utils.get_csv_files(config.input_directory)
transactions = []
for csv_file in csv_files:
@@ -95,3 +95,4 @@ def process_csv_files(config: Config):
toldg.predict.add_account2(transactions, config.categories)
toldg.utils.write_mappings(transactions, config.mappings_file)
toldg.write.render_to_file(transactions, config)
return transactions

9
src/toldg/train.py Normal file
View File

@@ -0,0 +1,9 @@
from toldg.models import Config, CsvConfig, Mapping, Transaction
from toldg.process import process_csv_files
def train(config: Config):
print("[train] start")
transactions = process_csv_files(config)
for t in transactions:
pass