Initial commit.
parts_ai_trainer.py (new file, 95 lines added)
@@ -0,0 +1,95 @@
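"""Fine-tune a causal language model on the parts training data."""
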
from typing import ClassVar

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from datastores.parts_ai_datastore import Parts_Ai_DataStore
from helpers.ai_helper import Ai_Helper


class Parts_Ai_Trainer:
    """Downloads a base model, fine-tunes it on local training data, and saves the result."""

    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"
    TRAINING_DATA_DIRECTORY: ClassVar[str] = "docs/training_data"

    # Per-instance state; ClassVar would wrongly declare these as class-level attributes.
    model: AutoModelForCausalLM
    model_name: str
    tokenizer: AutoTokenizer
    trainer: Trainer
    training_data: object  # list of items from the datastore, later a datasets.Dataset

    def __init__(self, model_name):
        self.model_name = model_name
        self.download_model_and_tokenizer()

    @classmethod
    def make_default(cls):
        return cls(cls.DEFAULT_MODEL_NAME)

    def download_model_and_tokenizer(self):
        # trust_remote_code is needed because DeepSeek models ship custom modeling code.
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def save_model_and_tokenizer(self):
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=False)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def train_and_save_model(self):
        self.load_and_preprocess_training_data()
        self.fine_tune_model()
        self.save_fine_tuned_model()

    def load_and_preprocess_training_data(self):
        self.load_training_data()
        self.preprocess_data()
        self.tokenize_preprocessed_data()

    def load_training_data(self):
        datastore = Parts_Ai_DataStore.make_default()
        # Assign to self rather than return: preprocess_data() reads self.training_data.
        self.training_data = datastore.load_training_data()

    def preprocess_data(self):
        processed_data = []
        for item in self.training_data:
            if item["type"] == "json":
                # Each JSON key/value pair becomes one input/output example.
                for key, value in item["content"].items():
                    processed_data.append({"input": key, "output": value})
            elif item["type"] == "txt":
                # The first line of a text file is the input; the remainder is the output.
                if "\n" in item["content"]:
                    first_line, contents = item["content"].split("\n", 1)
                else:
                    first_line, contents = item["content"], ""
                processed_data.append({"input": first_line, "output": contents})
        self.training_data = Dataset.from_dict({
            "input": [item["input"] for item in processed_data],
            "output": [item["output"] for item in processed_data],
        })

    def tokenize_preprocessed_data(self):
        self.training_data = self.training_data.map(self.tokenizer_preprocess_function, batched=True)

    def tokenizer_preprocess_function(self, examples):
        # Instance method so it can reach self.tokenizer.
        tokenized = self.tokenizer(examples["input"], truncation=True, padding="max_length", max_length=512)
        # Causal-LM fine-tuning needs labels; without them Trainer cannot compute a loss.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

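    # Hedged variant (an assumption, not part of the original design): fold the
    # "output" text into the tokenized sequence so the model learns to generate
    # the completions rather than only re-encoding the prompts.
    def tokenizer_preprocess_function_with_outputs(self, examples):
        # Hypothetical name; joins each prompt and completion with a newline.
        texts = [f"{i}\n{o}" for i, o in zip(examples["input"], examples["output"])]
        tokenized = self.tokenizer(texts, truncation=True, padding="max_length", max_length=512)
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized
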
    def fine_tune_model(self):
        training_args = TrainingArguments(
            output_dir=Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True),
            per_device_train_batch_size=4,
            num_train_epochs=3,
            logging_dir="./logs",
            save_steps=10_000,
            save_total_limit=2,
        )
        # Keep the trainer on self so save_fine_tuned_model() can use it.
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.training_data,
        )
        self.trainer.train()

    def save_fine_tuned_model(self):
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True)
        self.trainer.save_model(path)
        self.tokenizer.save_pretrained(path)

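# Usage sketch (assumes you run from the project root so the `datastores` and
# `helpers` packages resolve): python parts_ai_trainer.py
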
if __name__ == "__main__":
    trainer = Parts_Ai_Trainer.make_default()
    trainer.train_and_save_model()