Initial commit.
This commit is contained in:
55
main.py
Normal file
55
main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
||||
from typing import ClassVar
|
||||
from datasets import Dataset
|
||||
|
||||
class AiTrainerParts:
    """Holds the pieces needed to download, fine-tune, and save a causal LM.

    Typical flow:
        parts = AiTrainerParts.make_default()
        parts.download_model_and_tokenizer()
        parts.tokenize_training_data()
        parts.fine_tune_model()
        parts.save_fine_tuned_model()
    """

    # Hugging Face hub id used by make_default().
    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"

    def __init__(self, model_name: str):
        """Record the hub id; weights are loaded lazily by
        download_model_and_tokenizer().

        BUG FIX: the original class had no __init__, so make_default()'s
        cls(cls.DEFAULT_MODEL_NAME) raised TypeError. The attributes were
        also mis-annotated as ClassVar although they are per-instance state.
        """
        self.model_name = model_name
        self.model = None          # AutoModelForCausalLM once downloaded
        self.tokenizer = None      # AutoTokenizer once downloaded
        self.trainer = None        # Trainer once fine_tune_model() runs
        self.training_data = None  # tokenized Dataset once prepared

    @classmethod
    def make_default(cls):
        """Alternate constructor preloaded with DEFAULT_MODEL_NAME."""
        return cls(cls.DEFAULT_MODEL_NAME)

    def download_model_and_tokenizer(self):
        """Fetch model weights and tokenizer for self.model_name from the hub."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def save_model_and_tokenizer(self):
        """Save the pristine (pre-fine-tune) model and tokenizer to disk."""
        path = f"./pretrained_models/{self.model_name}"
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def tokenizer_preprocess_function(self, examples):
        """Tokenize a batch's "input" texts to fixed-length (512) tensors.

        BUG FIX: this was a @staticmethod referencing an undefined global
        `tokenizer` (NameError at call time). It must use the instance's
        tokenizer, so it is now an instance method; the old static call was
        never runnable, so no working caller breaks.
        """
        return self.tokenizer(examples["input"], truncation=True, padding="max_length", max_length=512)

    def tokenize_training_data(self):
        """Build the toy Q/A dataset, tokenize it, and cache it on self.

        BUG FIX: the original definition was missing `self` (TypeError when
        called on an instance) and never assigned self.training_data, which
        fine_tune_model() reads.
        """
        data = Dataset.from_dict({
            "input": ["What is our company's mission?", "Who is the CEO?"],
            "output": ["Our mission is to...", "The CEO is John Doe."]
        })
        self.training_data = data.map(self.tokenizer_preprocess_function, batched=True)
        return self.training_data

    def fine_tune_model(self):
        """Fine-tune self.model on self.training_data.

        Requires download_model_and_tokenizer() and tokenize_training_data()
        to have run first.
        """
        training_args = TrainingArguments(
            output_dir="./deepseek-v3-finetuned",
            per_device_train_batch_size=4,
            num_train_epochs=3,
            logging_dir="./logs",
            save_steps=10_000,
            save_total_limit=2,
        )
        # BUG FIX: the Trainer was bound to a local variable, but
        # save_fine_tuned_model() reads self.trainer — store it on self.
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.training_data,
        )
        self.trainer.train()

    def save_fine_tuned_model(self):
        """Persist the fine-tuned model and its tokenizer to disk."""
        self.trainer.save_model("./deepseek-v3-finetuned")
        self.tokenizer.save_pretrained("./deepseek-v3-finetuned")
|
||||
Reference in New Issue
Block a user