36 lines
1.3 KiB
Python
36 lines
1.3 KiB
Python
|
|
from typing import ClassVar

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Pipeline,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    pipeline,
)

from DeepSeek_PARTS.helpers.ai_helper import Ai_Helper


class Parts_Ai_Assistant:
    """Wrap a transformers "text-generation" pipeline built from a local model.

    The model and tokenizer are loaded from the path that
    ``Ai_Helper.get_pretrained_model_path`` returns for ``model_name``, always
    requesting the fine-tuned variant (``is_fine_tuned=True``).
    """

    # Model identifier used by ``make_default``.
    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"

    # Per-instance state: each assistant assigns its own values in
    # ``__init__`` / ``load_pretrained_model``, so these are NOT ClassVars.
    # Annotations are quoted so the class body does not evaluate them.
    model_name: str
    model: "PreTrainedModel"
    tokenizer: "PreTrainedTokenizerBase"
    generator: "Pipeline"

    def __init__(self, model_name: str):
        """Create an assistant for ``model_name`` and load its model eagerly.

        Args:
            model_name: Identifier passed to
                ``Ai_Helper.get_pretrained_model_path`` to locate the model.
        """
        self.model_name = model_name
        self.load_pretrained_model()

    @classmethod
    def make_default(cls) -> "Parts_Ai_Assistant":
        """Return an assistant built from ``DEFAULT_MODEL_NAME``."""
        return cls(cls.DEFAULT_MODEL_NAME)

    def load_pretrained_model(self) -> None:
        """Load the fine-tuned model and tokenizer and build the generation pipeline.

        Sets ``self.model``, ``self.tokenizer`` and ``self.generator``.
        """
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.generator = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def get_response(self, prompt, conversation_history=None):
        """Generate a completion for ``prompt`` preceded by earlier turns.

        The history entries and the prompt are joined with newlines, in that
        order, into one prompt string that is passed to ``self.generator``.

        Args:
            prompt: The newest message; placed last in the combined prompt.
            conversation_history: Earlier turns to place before ``prompt``.
                Any iterable of strings; ``None`` (the default) means none.

        Returns:
            Whatever ``self.generator`` returns, unchanged (for a transformers
            text-generation pipeline, a list of dicts with "generated_text").
        """
        # None sentinel instead of a shared mutable ``[]`` default.
        history = () if conversation_history is None else conversation_history
        # Unpacking accepts any iterable (list, tuple, generator), not just lists.
        complete_prompt = "\n".join([*history, prompt])
        return self.generator(complete_prompt)
if __name__ == "__main__":
    # Demo: build the default assistant and print the raw pipeline output
    # for a single question.
    demo_assistant = Parts_Ai_Assistant.make_default()
    demo_question = "What is our company's mission?"
    print(demo_assistant.get_response(demo_question))