DeepSeek_PARTS/parts_ai_trainer.py
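
"""Fine-tunes a DeepSeek causal language model on PARTS training data.

Downloads the base model and tokenizer, converts JSON and txt training
files into input/output pairs, tokenizes them, and fine-tunes the model
with the Hugging Face Trainer API.
"""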

from typing import ClassVar

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from datastores.parts_ai_datastore import Parts_Ai_DataStore
from helpers.ai_helper import Ai_Helper

class Parts_Ai_Trainer:
    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"
    TRAINING_DATA_DIRECTORY: ClassVar[str] = "docs/training_data"

    # These are instance attributes set in __init__ and during training,
    # so they are plain annotations rather than ClassVars.
    model: AutoModelForCausalLM
    model_name: str
    tokenizer: AutoTokenizer
    trainer: Trainer
    training_data: Dataset

    def __init__(self, model_name):
        self.model_name = model_name
        self.download_model_and_tokenizer()

    @classmethod
    def make_default(cls):
        return cls(cls.DEFAULT_MODEL_NAME)
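
    # from_pretrained pulls weights into the local Hugging Face cache on first
    # use; trust_remote_code=True is needed because DeepSeek checkpoints ship
    # custom modeling code.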
    def download_model_and_tokenizer(self):
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def save_model_and_tokenizer(self):
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=False)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def train_and_save_model(self):
        self.load_and_preprocess_training_data()
        self.fine_tune_model()
        self.save_fine_tuned_model()

    def load_and_preprocess_training_data(self):
        self.load_training_data()
        self.preprocess_data()
        self.tokenize_preprocessed_data()

    def load_training_data(self):
        datastore = Parts_Ai_DataStore.make_default()
        # Keep the loaded records on the instance so preprocess_data can use them.
        self.training_data = datastore.load_training_data()

    def preprocess_data(self):
        processed_data = []
        for item in self.training_data:
            if item["type"] == "json":
                # JSON files hold a mapping of input -> output pairs.
                for key, value in item["content"].items():
                    processed_data.append({"input": key, "output": value})
            elif item["type"] == "txt":
                # Text files use the first line as the input and the rest as the output.
                if "\n" in item["content"]:
                    first_line, contents = item["content"].split("\n", 1)
                else:
                    first_line, contents = item["content"], ""
                processed_data.append({"input": first_line, "output": contents})
        self.training_data = Dataset.from_dict({
            "input": [item["input"] for item in processed_data],
            "output": [item["output"] for item in processed_data],
        })
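
    # Dataset.map with batched=True hands the tokenizer a dict of lists, so
    # each call processes a whole batch of examples at once.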
    def tokenize_preprocessed_data(self):
        self.training_data = self.training_data.map(self.tokenizer_preprocess_function, batched=True)

    def tokenizer_preprocess_function(self, examples):
        # Join input and output so the causal LM learns to continue the prompt
        # with the expected completion.
        texts = [f"{inp}\n{out}" for inp, out in zip(examples["input"], examples["output"])]
        tokenized = self.tokenizer(texts, truncation=True, padding="max_length", max_length=512)
        # For causal-LM fine-tuning the labels are the input ids themselves.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized
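
    # The hyperparameters below are modest defaults; batch size and epoch
    # count should be tuned to the available hardware and dataset size.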
    def fine_tune_model(self):
        training_args = TrainingArguments(
            output_dir=Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True),
            per_device_train_batch_size=4,
            num_train_epochs=3,
            logging_dir="./logs",
            save_steps=10_000,
            save_total_limit=2,
        )
        # Keep the trainer on the instance so save_fine_tuned_model can reuse it.
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.training_data,
        )
        self.trainer.train()

    def save_fine_tuned_model(self):
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True)
        self.trainer.save_model(path)
        self.tokenizer.save_pretrained(path)


if __name__ == "__main__":
    trainer = Parts_Ai_Trainer.make_default()
    trainer.train_and_save_model()
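
# Note: DeepSeek-V3 is a 671B-parameter model; full fine-tuning at this scale
# requires a multi-GPU cluster. On smaller hardware, consider a smaller
# checkpoint or parameter-efficient fine-tuning (e.g. LoRA via the peft library).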