javascript - Tensorflow JS - converting tensor to JSON and back to tensor - Stack Overflow

I am training a model in batches and am therefore saving its weights as JSON to store/send them.

I now need to load them back into tensors. Is there a proper way to do this?

tensor.data().then(d => JSON.stringify(d));

// returns
{"0":0.000016666666851961054,"1":-0.00019999999494757503,"2":-0.000183333337190561}

I can iterate over this and convert it back to an array manually, but I feel there may be something in the API that would do this more cleanly.
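The object with numeric keys appears because `data()` resolves to a typed array (a `Float32Array` for float32 tensors), and `JSON.stringify` serializes typed arrays like plain objects keyed by index. A minimal plain-JavaScript sketch of the difference, with no TensorFlow.js involved:

```javascript
// data() on a float32 tensor resolves to a Float32Array like this one.
const values = new Float32Array([0.5, -0.25, 1]);

// Typed arrays are serialized as objects keyed by index...
const asObject = JSON.stringify(values);
console.log(asObject); // {"0":0.5,"1":-0.25,"2":1}

// ...while a plain Array becomes a JSON array.
const asArray = JSON.stringify(Array.from(values));
console.log(asArray); // [0.5,-0.25,1]

// Parsing the JSON array and wrapping it restores the typed array.
const restored = new Float32Array(JSON.parse(asArray));
console.log(restored.length); // 3
```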

Asked Jul 25, 2019 at 17:44 by dendog; edited Aug 9, 2019 at 10:45 by edkeveked.

3 Answers


There is no need to stringify the result of data() directly. To save a tensor and restore it later, two things are needed: the tensor's shape and its flattened data array.

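// save: keep the shape and the flattened values
const shape = tensor.shape
const data = await tensor.data() // read the values from the backend (a typed array)
const saved = { data: Array.from(data), shape } // plain array, so JSON.stringify(saved) yields a JSON array

// restore
const retrievedTensor = tf.tensor(saved.data, saved.shape)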
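`tf.tensor(data, shape)` reads the flat values in row-major order, with the last dimension varying fastest. A plain-JavaScript sketch of that mapping, rebuilding the nested array that `array()` would return for the same tensor; `unflatten` is a hypothetical helper for illustration, not part of the TensorFlow.js API:

```javascript
// Rebuild a nested array from flat row-major values and a shape.
// Hypothetical helper; TensorFlow.js does this mapping internally.
function unflatten(data, shape) {
  if (shape.length === 0) return data[0]; // scalar: shape []
  const [dim, ...rest] = shape;
  // Number of flat values covered by one slice along the first axis
  const stride = rest.reduce((a, b) => a * b, 1);
  const out = [];
  for (let i = 0; i < dim; i++) {
    out.push(unflatten(data.slice(i * stride, (i + 1) * stride), rest));
  }
  return out;
}

const saved = { data: [1, 2, 3, 4, 5, 6], shape: [2, 3] };
const nested = unflatten(saved.data, saved.shape);
console.log(JSON.stringify(nested)); // [[1,2,3],[4,5,6]]
```

The same helper works on a `Float32Array` as well, since typed arrays also provide `slice`.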

Both pieces of information are captured at once by array() or arraySync(): they return a nested JavaScript array with the same structure as the tensor, and tf.tensor() infers the shape from that nesting.

const saved = await tensor.array()       // nested JS array, e.g. [[1, 2], [3, 4]]
const retrievedTensor = tf.tensor(saved) // shape is inferred from the nesting
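Going the other way, `tf.tensor(nestedArray)` works out the shape from the nesting and flattens the values. A plain-JavaScript sketch of that inference, assuming a rectangular (non-ragged) array; `inferShapeAndFlatten` is a hypothetical name used only for illustration:

```javascript
// Infer the shape of a rectangular nested array and flatten it row-major.
// Hypothetical helper; tf.tensor() performs a similar inference internally.
function inferShapeAndFlatten(nested) {
  const shape = [];
  let level = nested;
  // Walk down the first element of each level to read the dimensions
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0];
  }
  // flat(depth) removes (rank - 1) levels of nesting; a scalar becomes [x]
  const data = shape.length === 0 ? [nested] : nested.flat(shape.length - 1);
  // Cheap sanity check (not a full raggedness test): count must match shape
  const size = shape.reduce((a, b) => a * b, 1);
  if (data.length !== size) throw new Error('value count does not match shape');
  return { shape, data };
}

const { shape, data } = inferShapeAndFlatten([[1, 2, 3], [4, 5, 6]]);
console.log(JSON.stringify(shape), JSON.stringify(data)); // [2,3] [1,2,3,4,5,6]
```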

The code below can solve the issue: showWeights exports the weights as text, which you can save to a database, a text file or browser storage, for example, and setWeightsFromString later applies them to your model again.

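// showWeights and setWeightsFromString are methods of a class that keeps a
// tf.LayersModel in this.model.
showWeights() {

    tf.tidy(() => {

        const weights = this.model.getWeights();
        const pesos = [];  // one comma-separated string of values per weight tensor
        const shapes = []; // one comma-separated string of dimensions per weight tensor

        for (const tensor of weights) {
            // Array.from(...).join(',') gives e.g. "0.5,-1,2"
            pesos.push(Array.from(tensor.dataSync()).join(','));
            shapes.push(tensor.shape.join(','));
        }

        // Joining the lists keeps them aligned even when a shape string is
        // empty (a scalar weight has shape [])
        console.log(pesos.join(';'));  // sValues for setWeightsFromString
        console.log(shapes.join(';')); // sShapes for setWeightsFromString

    });

}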

setWeightsFromString(sValues, sShapes) {

    tf.tidy(() => {

        const aValues = sValues.split(';');
        const aShapes = sShapes.split(';');
        const loadedWeights = [];

        for (let i = 0; i < aValues.length; i++) {

            const newValues = new Float32Array(aValues[i].split(',').map(Number));
            // An empty shape string is a scalar weight: shape [] rather than [0]
            const newShapes = aShapes[i] === '' ? [] : aShapes[i].split(',').map(Number);

            loadedWeights.push(tf.tensor(newValues, newShapes));

        }

        // Assign the restored tensors to the model, in the same order as getWeights()
        this.model.setWeights(loadedWeights);

    });
}
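The text format used by showWeights and setWeightsFromString separates tensors with `;` and values (or dimensions) with `,`. A plain-JavaScript sketch of the same encoding applied to ordinary `{ shape, values }` objects, so it can be tried without a model; `encodeWeights` and `decodeWeights` are hypothetical names, and in the answer each entry corresponds to one tensor from `getWeights()`, with the decoded pair passed to `tf.tensor(values, shape)`:

```javascript
// Encode a list of { shape, values } into two strings:
// tensors separated by ';', values and dimensions separated by ','.
function encodeWeights(weights) {
  const sValues = weights.map((w) => Array.from(w.values).join(',')).join(';');
  const sShapes = weights.map((w) => w.shape.join(',')).join(';');
  return { sValues, sShapes };
}

// Decode the two strings back into { shape, values } objects.
function decodeWeights(sValues, sShapes) {
  const aValues = sValues.split(';');
  const aShapes = sShapes.split(';');
  return aValues.map((v, i) => ({
    // An empty shape string is a scalar: shape [] rather than [0]
    shape: aShapes[i] === '' ? [] : aShapes[i].split(',').map(Number),
    values: new Float32Array(v.split(',').map(Number)),
  }));
}

const weights = [
  { shape: [2, 2], values: new Float32Array([0.5, -1, 2, 0.25]) },
  { shape: [2], values: new Float32Array([0, 1]) },
  { shape: [], values: new Float32Array([3]) },
];
const { sValues, sShapes } = encodeWeights(weights);
console.log(sValues); // 0.5,-1,2,0.25;0,1;3
console.log(sShapes); // 2,2;2;
const decoded = decodeWeights(sValues, sShapes);
```

The scalar entry is the case to watch: its shape encodes as an empty string, so the decoder has to map `''` back to `[]`.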

This is my code to do this operation.

import { 
    tensor,
    tensor2d,
} from '@tensorflow/tfjs-node'

import { readFile, writeFile } from 'node:fs'

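// Run as an ES module (.mjs or "type": "module") so the top-level await below works;
// the SAVED-TENSOR directory must already exist
const path2File = './SAVED-TENSOR/obj.json'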

//------------------------------------- 2D ------------------------------------
const a = [
    [
        0.9969421029090881,
        9.39412784576416,
        95.00736999511719
    ]
]

const inputs2dT = tensor2d(a)
    
console.log(`@TENSOR  >> `, inputs2dT.dataSync())
// @TENSOR  >>  Float32Array(3) [
//     0.9969421029090881,
//     9.39412784576416,
//     95.00736999511719
// ]

const aa = await inputs2dT.array()
console.log(aa)
// [ [ 0.9969421029090881, 9.39412784576416, 95.00736999511719 ] ]

const aaObj = {
    "tensor": aa
}

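writeFile(
    path2File,
    JSON.stringify(aaObj),
    (err) => {
        if (err) throw err

        console.log('@DATA >> Written!')

        // Read the file back only after the write has finished; calling
        // readFile right after writeFile would race with the pending write
        readFile(path2File, 'utf8', (err, rawData) => {
            if (err) throw err
            const obj = JSON.parse(rawData)
            console.log('@DATA >> ', obj.tensor)

            // tensor() infers the shape [1, 3] from the nested array
            const t = tensor(obj.tensor)
            if (t.constructor.name === 'Tensor') {
                t.print()
            } else {
                console.log('@UNDEFINED >> Tensor')
            }
        })
    }
)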
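Writing the values as JSON text does not lose precision for float32 tensors: every float32 value is exactly representable as a JavaScript number, and `JSON.stringify` prints numbers in a form that parses back to the same number. This is also why the printed values carry long decimal expansions: they are the exact float32 values written out as doubles. A small plain-JavaScript check, using `Math.fround` (which rounds to float32 the same way storing into a `Float32Array` does):

```javascript
// Round some doubles to float32, as storing them in a Float32Array does.
const original = new Float32Array([0.9969421029090881, 9.39412784576416, 1 / 3, -1e-7]);

// Serialize as a JSON array of ordinary numbers and parse it back.
const json = JSON.stringify(Array.from(original));
const restored = new Float32Array(JSON.parse(json));

// Compare every element exactly.
const identical = original.every((v, i) => v === restored[i]);
console.log(identical); // true

// Math.fround gives the float32 value that 1/3 is stored as.
console.log(Math.fround(1 / 3) === original[2]); // true
```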
