from typing import ClassVar

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


class AiTrainerParts:
    """Bundles the model, tokenizer, and trainer used for fine-tuning."""

    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"

    # These are set per instance, so they are plain annotations rather than
    # ClassVar; only DEFAULT_MODEL_NAME is shared by the whole class.
    model: AutoModelForCausalLM
    model_name: str
    tokenizer: AutoTokenizer
    trainer: Trainer
    training_data: Dataset

    def __init__(self, model_name: str):
        # Record which checkpoint to pull; make_default() relies on this.
        self.model_name = model_name

    @classmethod
    def make_default(cls) -> "AiTrainerParts":
        return cls(cls.DEFAULT_MODEL_NAME)

    def download_model_and_tokenizer(self):
        # Fetches weights and tokenizer files from the Hugging Face Hub,
        # or reuses the local cache if they are already present.
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def save_model_and_tokenizer(self):
        path = f"./pretrained_models/{self.model_name}"
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def tokenizer_preprocess_function(self, examples):
        # Runs as an instance method so it can use the tokenizer loaded above.
        tokenized = self.tokenizer(
            examples["input"], truncation=True, padding="max_length", max_length=512
        )
        # For causal-LM fine-tuning, labels are conventionally the input ids;
        # note the "output" column below is not yet folded into the prompt.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def tokenize_training_data(self):
        data = Dataset.from_dict({
            "input": ["What is our company's mission?", "Who is the CEO?"],
            "output": ["Our mission is to...", "The CEO is John Doe."],
        })
        # Store the tokenized dataset so fine_tune_model can pick it up.
        self.training_data = data.map(self.tokenizer_preprocess_function, batched=True)
        return self.training_data

    def fine_tune_model(self):
        training_args = TrainingArguments(
            output_dir="./deepseek-v3-finetuned",
            per_device_train_batch_size=4,
            num_train_epochs=3,
            logging_dir="./logs",
            save_steps=10_000,
            save_total_limit=2,
        )
        # Keep the trainer on the instance so save_fine_tuned_model can reuse it.
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.training_data,
        )
        self.trainer.train()

    def save_fine_tuned_model(self):
        self.trainer.save_model("./deepseek-v3-finetuned")
        self.tokenizer.save_pretrained("./deepseek-v3-finetuned")
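

# A minimal end-to-end usage sketch: it assumes network access to the Hub and
# hardware large enough to hold and fine-tune the chosen model.
if __name__ == "__main__":
    parts = AiTrainerParts.make_default()
    parts.download_model_and_tokenizer()
    parts.tokenize_training_data()
    parts.fine_tune_model()
    parts.save_fine_tuned_model()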