# parts_ai_assistant.py
# (diff-viewer residue from the original commit page removed)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from typing import ClassVar
from DeepSeek_PARTS.helpers.ai_helper import Ai_Helper
class Parts_Ai_Assistant:
    """Text-generation assistant backed by a fine-tuned causal language model.

    Wraps a Hugging Face model/tokenizer pair in a ``text-generation``
    pipeline and exposes a simple prompt -> generated-text interface.
    """

    # Model loaded by make_default() when no explicit name is given.
    DEFAULT_MODEL_NAME: ClassVar[str] = "deepseek-ai/DeepSeek-V3"

    # Per-instance state, populated in __init__ / load_pretrained_model.
    # NOTE(review): these were previously annotated ClassVar, but each one is
    # assigned through ``self``, so they are instance attributes, not class
    # variables. ``generator`` holds the object returned by pipeline() (a
    # Pipeline instance — ``pipeline`` itself is a factory function, so it is
    # not a usable annotation type).
    model: AutoModelForCausalLM
    model_name: str
    tokenizer: AutoTokenizer

    def __init__(self, model_name):
        """Create an assistant and eagerly load *model_name* from disk.

        Args:
            model_name: Hugging Face model identifier / local checkpoint name.
        """
        self.model_name = model_name
        self.load_pretrained_model()

    @classmethod
    def make_default(cls):
        """Alternate constructor: build an assistant for DEFAULT_MODEL_NAME."""
        return cls(cls.DEFAULT_MODEL_NAME)

    def load_pretrained_model(self):
        """Load the fine-tuned model, its tokenizer, and the generation pipeline.

        Resolves the on-disk checkpoint path via Ai_Helper with
        ``is_fine_tuned=True`` — assumes a fine-tuned checkpoint exists there.
        """
        path = Ai_Helper.get_pretrained_model_path(self.model_name, is_fine_tuned=True)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.generator = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

    def get_response(self, prompt, conversation_history=None):
        """Generate a response to *prompt*, optionally preceded by history.

        Args:
            prompt: The current user prompt.
            conversation_history: Optional list of prior turns, joined with
                newlines ahead of *prompt*. Defaults to no history.

        Returns:
            The pipeline's generation output for the combined prompt.
        """
        # BUG FIX: the original used a mutable default (``= []``), which is
        # shared across every call to this method; use a None sentinel instead.
        history = conversation_history if conversation_history is not None else []
        complete_prompt = "\n".join(history + [prompt])
        return self.generator(complete_prompt)
def _main():
    """Demo entry point: ask the default assistant a sample question."""
    assistant = Parts_Ai_Assistant.make_default()
    response = assistant.get_response("What is our company's mission?")
    print(response)


if __name__ == "__main__":
    _main()