About TFRecord
Published: 2019-06-09


Writing a TFRecord

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# data_train, train_x, train_et1, ... are assumed to have been prepared earlier.
print("convert data into tfrecord:train\n")
out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord"
writer = tf.python_io.TFRecordWriter(out_file_train)
for i in tqdm(range(len(data_train))):
    # Array features are stored as raw bytes; scalars are stored as int64.
    # tobytes() replaces the deprecated ndarray.tostring() and returns the same bytes.
    record = tf.train.Example(features=tf.train.Features(feature={
        'word_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_x[i].tobytes()])),
        'et_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et1[i].tobytes()])),
        'et_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et2[i].tobytes()])),
        'position_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tobytes()])),
        # NOTE: this reuses train_p1, as in the original post; it looks like a
        # copy-paste slip, and a separate second-position array was probably intended.
        'position_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tobytes()])),
        'chunks': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_chunks[i].tobytes()])),
        'spath_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_spath[i].tobytes()])),
        'seq_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[train_x_len[i]])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.argmax(train_relation[i])])),
        'task': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.int64(0)]))
    }))
    writer.write(record.SerializeToString())
writer.close()
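The writer stores each array as raw bytes, so the dtype and shape are not saved with the record and the reader has to know them. The NumPy-only sketch below illustrates this round trip. The sample values and the int64 dtype are assumptions for illustration, not taken from the original data.

```python
import numpy as np

# Hypothetical stand-in for one row of train_x; the int64 dtype is an assumption.
word_ids = np.array([12, 7, 0, 3], dtype=np.int64)

# Serialize to raw bytes, as the writer does for each array feature.
raw = word_ids.tobytes()

# The buffer carries no dtype or shape, so the reader must supply the dtype.
restored = np.frombuffer(raw, dtype=np.int64)

# Decoding with a different dtype does not fail; it silently gives another array.
misread = np.frombuffer(raw, dtype=np.int32)

print(restored.tolist(), misread.shape)
```

If the arrays are multi-dimensional, the shape also has to be restored with `reshape` on the reading side.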

  

Parsing the TFRecord

import os
import tensorflow as tf

# DATASETS, OUT_DIR and util are assumed to be defined elsewhere in the project.
# NOTE: this parser expects tf.train.SequenceExample records with int64 feature
# lists, whereas the writer snippet above emits tf.train.Example records with
# bytes features; the two formats have to be made to agree before this will work.

def _parse_tfexample(serialized_example):
  '''parse serialized tf.train.SequenceExample to tensors
  context features : label, task, seq_len
  sequence features: sentence
  '''
  context_features={'label'    : tf.FixedLenFeature([], tf.int64),
                    'task'    : tf.FixedLenFeature([], tf.int64),
                    'seq_len': tf.FixedLenFeature([], tf.int64)}
  sequence_features={'word_ids': tf.FixedLenSequenceFeature([], tf.int64),
                     'et_ids1': tf.FixedLenSequenceFeature([], tf.int64),
                     'et_ids2': tf.FixedLenSequenceFeature([], tf.int64),
                     'position_ids1': tf.FixedLenSequenceFeature([], tf.int64),
                     'position_ids2': tf.FixedLenSequenceFeature([], tf.int64),
                     'chunks': tf.FixedLenSequenceFeature([], tf.int64),
                     'spath_ids': tf.FixedLenSequenceFeature([], tf.int64),
                     }
  context_dict, sequence_dict = tf.parse_single_sequence_example(
                      serialized_example,
                      context_features   = context_features,
                      sequence_features  = sequence_features)

  # Bundle all per-token features, plus the sequence length, into one tuple.
  sentence = (sequence_dict['word_ids'],sequence_dict['et_ids1'],sequence_dict['et_ids2'],sequence_dict['position_ids1'],
              sequence_dict['position_ids2'],sequence_dict['chunks'],sequence_dict['spath_ids'], context_dict['seq_len'])
  label = context_dict['label']
  task = context_dict['task']
  return task, label, sentence

def read_tfrecord(epoch, batch_size):
  # Yield one (train, test) pair of datasets per entry in DATASETS.
  for dataset in DATASETS:
    train_record_file = os.path.join(OUT_DIR, dataset+'.train.tfrecord')
    test_record_file = os.path.join(OUT_DIR, dataset+'.test.tfrecord')
    train_data = util.read_tfrecord(train_record_file,
                                    epoch,
                                    batch_size,
                                    _parse_tfexample,
                                    shuffle=True)
    test_data = util.read_tfrecord(test_record_file,
                                   epoch,
                                   batch_size,
                                   _parse_tfexample,
                                   shuffle=False)
    yield train_data, test_data
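The writer shown earlier stores arrays as bytes inside a `tf.train.Example`, while `_parse_tfexample` reads a `tf.train.SequenceExample` with int64 feature lists. A parser that matches the writer's layout might look like the sketch below. It assumes TensorFlow 2.x in eager mode and int64 arrays, covers only two of the features, and the names `make_example` and `parse_bytes_example` are hypothetical.

```python
import numpy as np
import tensorflow as tf

def make_example(word_ids, label):
    """Build a tf.train.Example the way the writer does (two features only)."""
    return tf.train.Example(features=tf.train.Features(feature={
        'word_ids': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[word_ids.tobytes()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }))

def parse_bytes_example(serialized_example):
    """Parse a record whose array features were stored as raw bytes."""
    features = {
        'word_ids': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features)
    # decode_raw needs the dtype used at write time (int64 is assumed here).
    word_ids = tf.io.decode_raw(parsed['word_ids'], tf.int64)
    return word_ids, parsed['label']

# Round trip on a made-up sample.
serialized = make_example(np.array([4, 8, 15], dtype=np.int64), 2).SerializeToString()
word_ids, label = parse_bytes_example(serialized)
print(word_ids.numpy(), label.numpy())
```

The alternative is to keep `_parse_tfexample` as it is and change the writer to emit `tf.train.SequenceExample` records with int64 feature lists. Either way, both sides must use the same layout.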

 

Using it in the model:

def build_task_graph(self, data):
    task_label, labels, sentence = data
    # sentence = tf.nn.embedding_lookup(self.word_embed, sentence)
    word_ids, et_ids1,et_ids2,position_ids1,position_ids2,chunks,spath_ids,seq_len = sentence
    # sentence = word_ids
    self.word_ids = word_ids
    self.position_ids1 = position_ids1
    self.position_ids2 = position_ids2
    self.et_ids1 = et_ids1
    self.et_ids2 = et_ids2
    self.chunks_ids = chunks
    self.spath_ids = spath_ids
    self.seq_len = seq_len
    sentence = self.add_embedding_layers()

  

 

 

Reposted from: https://www.cnblogs.com/huadongw/p/11483730.html
