python - GPT-Neo-1.3B not giving answers when called via API - Stack Overflow

I'm trying to run GPT-Neo-1.3B locally on my server and call it via an API. I installed everything, but when I call the API, the response is essentially the same text I sent as the prompt. Below are my code and my request/response. I'm using Python/Flask.

Any idea what is wrong with my code?

from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import torch
from flask import Flask, request, jsonify, abort

# Load the model and tokenizer
model = GPTNeoForCausalLM.from_pretrained("./gpt_neo_1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("./gpt_neo_1.3B")

# Define a pad_token if it's not already defined
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Resize the token embeddings with mean_resizing=False
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

# Set the pad_token_id to the new padding token (e.g., '[PAD]')
model.config.pad_token_id = tokenizer.pad_token_id

# Print pad_token_id and eos_token_id to ensure proper configuration
print("Pad token ID:", tokenizer.pad_token_id)
print("EOS token ID:", tokenizer.eos_token_id)

model.eval()

# Set up Flask app
app = Flask(__name__)

@app.route('/generate', methods=['POST'])
def generate_text():
    
    # Get the input text from the request
    data = request.json
    input_text = data.get('text', '')
    
    if not input_text:
        return jsonify({"error": "No text provided"}), 400  # Bad Request

    # Tokenize input with padding, truncation, and attention mask
    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)

    # Generate output with attention_mask, explicitly setting pad_token_id
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'], 
            attention_mask=inputs['attention_mask'],
            max_length=100,  # Increase max_length
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,  # Enable sampling to diversify output
            top_k=50,  # Use top-k sampling
            top_p=0.95,  # Use nucleus sampling
            temperature=0.7  # Control creativity of the output
        )

    # Decode and return output
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Decode the full output sequence and return it as JSON
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return jsonify({"generated_text": generated_text})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

Request:

curl -X POST http://localhost:5000/generate -H "Content-Type: application/json" -d '{"text": "What time is it in japan"}'

Response:

{"generated_text":"What time is it in japanese?\n\n"}
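One thing worth noting about the response: `model.generate` returns the prompt tokens followed by the newly generated tokens in a single sequence, so decoding `outputs[0]` always repeats the input text. A minimal sketch of slicing off the prompt before decoding (the token ids below are made up for illustration; in the handler above the prompt length would be `inputs['input_ids'].shape[1]`):

```python
def continuation_ids(output_ids, prompt_length):
    """Return only the tokens generated after the prompt.

    generate() returns one sequence per input: the prompt ids followed
    by the new ids, so dropping the first prompt_length entries leaves
    just the model's continuation.
    """
    return output_ids[prompt_length:]

# Illustrative ids only -- not real GPT-Neo token ids.
full_output = [101, 102, 103, 200, 201]   # first 3 ids echo the prompt
print(continuation_ids(full_output, 3))   # [200, 201]
```

In the Flask handler this would become `tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)`. Also keep in mind that GPT-Neo-1.3B is a plain causal language model, not an instruction-tuned one, so even with the prompt stripped it tends to continue the text rather than answer a question.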
