I'm trying to run GPT-Neo-1.3B locally in my server and call it via an API. I installed everything and when I call it with an API I'm getting the same answer I'm asking as response. Below is my code and my request/response. I'm using python/Flask
Any idea what is wrong with my code please?
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import torch
from flask import Flask, request, jsonify, abort
# Load the model and tokenizer
model = GPTNeoForCausalLM.from_pretrained("./gpt_neo_1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("./gpt_neo_1.3B")
# Define a pad_token if it's not already defined
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Resize the token embeddings with mean_resizing=False
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
# Set the pad_token_id to the new padding token (e.g., '[PAD]')
model.config.pad_token_id = tokenizer.pad_token_id
# Print pad_token_id and eos_token_id to ensure proper configuration
print("Pad token ID:", tokenizer.pad_token_id)
print("EOS token ID:", tokenizer.eos_token_id)
model.eval()
# Set up Flask app
app = Flask(__name__)
@app.route('/generate', methods=['POST'])
def generate_text():
# Get the input text from the request
data = request.json
input_text = data.get('text', '')
if not input_text:
return jsonify({"error": "No text provided"}), 400 # Bad Request
# Tokenize input with padding, truncation, and attention mask
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
# Generate output with attention_mask, explicitly setting pad_token_id
with torch.no_grad():
outputs = model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_length=100, # Increase max_length
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True, # Enable sampling to diversify output
top_k=50, # Use top-k sampling
top_p=0.95, # Use nucleus sampling
temperature=0.7 # Control creativity of the output
)
# Decode and return output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return jsonify({"generated_text": generated_text})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
Request:
curl -X POST http://localhost:5000/generate -H "Content-Type: application/json" -d '{"text": "What time is it in japan"}'
Response:
{"generated_text":"What time is it in japanese?\n\n"}
发布者:admin,转转请注明出处:http://www.yc00.com/questions/1744394953a4572093.html
评论列表(0条)