MatchingNetworkstensorflow1229Word下载.docx
- 文档编号:21628370
- 上传时间:2023-01-31
- 格式:DOCX
- 页数:21
- 大小:21.83KB
MatchingNetworkstensorflow1229Word下载.docx
《MatchingNetworkstensorflow1229Word下载.docx》由会员分享,可在线阅读,更多相关《MatchingNetworkstensorflow1229Word下载.docx(21页珍藏版)》请在冰豆网上搜索。
# Experiment builder.
# NOTE(review): the document extraction stripped all whitespace and split
# string literals across lines; this section is reconstructed with
# conventional formatting. The names batch_size, classes_per_set,
# samples_per_class, fce, experiment_name, continue_from_epoch, dataset,
# ExperimentBuilder, save_statistics, tf, slim and tqdm are defined earlier
# in the original file (argument parsing / imports not visible here).
data = dataset.OmniglotNShotDataset(batch_size=batch_size,
                                    classes_per_set=classes_per_set,
                                    samples_per_class=samples_per_class)
experiment = ExperimentBuilder(data)
# Build the graph: network object, loss tensors, train op and variable-init op.
one_shot_omniglot, losses, c_error_opt_op, init = experiment.build_experiment(
    batch_size, classes_per_set, samples_per_class, fce)

# Training schedule.
total_epochs = 300
total_train_batches = 1000
total_val_batches = 250
total_test_batches = 250

# Write the CSV header row of the per-epoch statistics file.
save_statistics(experiment_name,
                ["epoch", "train_c_loss", "train_c_accuracy",
                 "val_loss", "val_accuracy",
                 "test_c_loss", "test_c_accuracy"])

# Experiment initialization and running.
with tf.Session() as sess:
    sess.run(init)
    saver = tf.train.Saver()
    if continue_from_epoch != -1:
        # Load checkpoint if resuming from a previous run.
        checkpoint = "saved_models/{}_{}.ckpt".format(experiment_name, continue_from_epoch)
        variables_to_restore = []
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            print(var)
            variables_to_restore.append(var)
        tf.logging.info('Fine-tuning from %s' % checkpoint)
        # ignore_missing_vars=True: tolerate graph variables absent from the
        # checkpoint instead of failing the restore.
        fine_tune = slim.assign_from_checkpoint_fn(
            checkpoint,
            variables_to_restore,
            ignore_missing_vars=True)
        fine_tune(sess)

    best_val = 0.
    with tqdm.tqdm(total=total_epochs) as pbar_e:
        for e in range(0, total_epochs):
            total_c_loss, total_accuracy = experiment.run_training_epoch(
                total_train_batches=total_train_batches, sess=sess)
            print("Epoch {}: train_loss: {}, train_accuracy: {}"
                  .format(e, total_c_loss, total_accuracy))

            total_val_c_loss, total_val_accuracy = experiment.run_validation_epoch(
                total_val_batches=total_val_batches, sess=sess)
            print("Epoch {}: val_loss: {}, val_accuracy: {}"
                  .format(e, total_val_c_loss, total_val_accuracy))

            if total_val_accuracy >= best_val:
                # New best validation accuracy -> produce test statistics.
                best_val = total_val_accuracy
                total_test_c_loss, total_test_accuracy = experiment.run_testing_epoch(
                    total_test_batches=total_test_batches, sess=sess)
                print("Epoch {}: test_loss: {}, test_accuracy: {}"
                      .format(e, total_test_c_loss, total_test_accuracy))
            else:
                # Sentinel values recorded when no test epoch was run.
                total_test_c_loss = -1
                total_test_accuracy = -1

            save_statistics(experiment_name,
                            [e, total_c_loss, total_accuracy,
                             total_val_c_loss, total_val_accuracy,
                             total_test_c_loss, total_test_accuracy])
            save_path = saver.save(sess, "saved_models/{}_{}.ckpt".format(experiment_name, e))
            pbar_e.update(1)
2.data
import numpy as np
from scipy.ndimage import rotate


class OmniglotNShotDataset():
    """Data provider that samples N-way/K-shot episodes from Omniglot."""

    def __init__(self, batch_size, classes_per_set=10, samples_per_class=1, seed=2591, shuffle_classes=True):
        """
        Constructs an N-Shot Omniglot dataset.

        :param batch_size: Experiment batch size.
        :param classes_per_set: Integer indicating the number of classes per set.
        :param samples_per_class: Integer indicating samples per class.
        :param seed: Seed for numpy's RNG (class shuffling and episode sampling).
        :param shuffle_classes: Whether to shuffle class order before splitting.

        e.g. For a 20-way, 1-shot learning task, use classes_per_set=20 and samples_per_class=1.
             For a 5-way, 10-shot learning task, use classes_per_set=5 and samples_per_class=10.
        """
        np.random.seed(seed)
        # Expects a pre-packed array of 1622 classes x 20 samples of 28x28x1
        # grayscale images in the working directory.
        self.x = np.load("data.npy")
        self.x = np.reshape(self.x, newshape=(1622, 20, 28, 28, 1))
        if shuffle_classes:
            class_ids = np.arange(self.x.shape[0])
            np.random.shuffle(class_ids)
            self.x = self.x[class_ids]
        # Class-level split: 1200 train / 211 test / 211 val classes.
        self.x_train, self.x_test, self.x_val = self.x[:1200], self.x[1200:1411], self.x[1411:]
        # Normalization statistics are computed on train + val only.
        self.mean = np.mean(list(self.x_train) + list(self.x_val))
        self.std = np.std(list(self.x_train) + list(self.x_val))
        self.batch_size = batch_size
        self.n_classes = self.x.shape[0]
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class
        print("train_shape", self.x_train.shape, "test_shape", self.x_test.shape, "val_shape", self.x_val.shape)
        self.indexes = {"train": 0, "val": 0, "test": 0}
        self.datasets = {"train": self.x_train,
                         "val": self.x_val,
                         "test": self.x_test}  # original data cached

    def preprocess_batch(self, x_batch):
        """
        Normalizes our data, to have a mean of 0 and sd of 1.

        :param x_batch: Array of images to normalize.
        :return: The normalized batch.
        """
        x_batch = (x_batch - self.mean) / self.std
        return x_batch

    def sample_new_batch(self, data_pack):
        """
        Samples a batch of episodes for N-shot learning.

        :param data_pack: Data pack to use (any one of train, val, test),
            shaped (n_classes, n_samples, h, w, c).
        :return: A list with [support_set_x, support_set_y, target_x, target_y]
            ready to be fed to our networks.
        """
        support_set_x = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class,
                                  data_pack.shape[2], data_pack.shape[3], data_pack.shape[4]),
                                 dtype=np.float32)
        support_set_y = np.zeros((self.batch_size, self.classes_per_set, self.samples_per_class),
                                 dtype=np.float32)
        target_x = np.zeros((self.batch_size, data_pack.shape[2], data_pack.shape[3], data_pack.shape[4]),
                            dtype=np.float32)
        target_y = np.zeros((self.batch_size,), dtype=np.float32)
        for i in range(self.batch_size):
            classes_idx = np.arange(data_pack.shape[0])
            samples_idx = np.arange(data_pack.shape[1])
            # Pick the episode's classes, which class supplies the target, and
            # samples_per_class + 1 samples (the extra one becomes the target).
            choose_classes = np.random.choice(classes_idx, size=self.classes_per_set, replace=False)
            choose_label = np.random.choice(self.classes_per_set, size=1)
            choose_samples = np.random.choice(samples_idx, size=self.samples_per_class + 1, replace=False)
            x_temp = data_pack[choose_classes]
            x_temp = x_temp[:, choose_samples]
            # Labels are episode-local indices 0..classes_per_set-1.
            y_temp = np.arange(self.classes_per_set)
            support_set_x[i] = x_temp[:, :-1]
            support_set_y[i] = np.expand_dims(y_temp[:], axis=1)
            target_x[i] = x_temp[choose_label, -1]
            target_y[i] = y_temp[choose_label]
        return support_set_x, support_set_y, target_x, target_y

    def get_batch(self, dataset_name, augment=False):
        """
        Gets next batch from the dataset with name.

        :param dataset_name: The name of the dataset (one of "train", "val", "test").
        :param augment: If True, rotate each support class (and the matching
            target image) by a random multiple of 90 degrees.
        :return: (x_support_set, y_support_set, x_target, y_target), normalized.
        """
        x_support_set, y_support_set, x_target, y_target = self.sample_new_batch(self.datasets[dataset_name])
        if augment:
            # One rotation (in units of 90 degrees) per (episode, class) pair.
            k = np.random.randint(0, 4, size=(self.batch_size, self.classes_per_set))
            x_augmented_support_set = []
            x_augmented_target_set = []
            for b in range(self.batch_size):
                temp_class_support = []
                for c in range(self.classes_per_set):
                    x_temp_support_set = self.rotate_batch(x_support_set[b, c], axis=(1, 2), k=k[b, c])
                    if y_target[b] == y_support_set[b, c, 0]:
                        # Rotate the target the same way as its own class.
                        x_temp_target = self.rotate_batch(x_target[b], axis=(0, 1), k=k[b, c])
                    temp_class_support.append(x_temp_support_set)
                x_augmented_support_set.append(temp_class_support)
                x_augmented_target_set.append(x_temp_target)
            x_support_set = np.array(x_augmented_support_set)
            x_target = np.array(x_augmented_target_set)
        x_support_set = self.preprocess_batch(x_support_set)
        x_target = self.preprocess_batch(x_target)
        return x_support_set, y_support_set, x_target, y_target

    def rotate_batch(self, x_batch, axis, k):
        """
        Rotates a batch of images by k * 90 degrees in the given axes plane.

        :param x_batch: Array of images.
        :param axis: Pair of axes defining the rotation plane.
        :param k: Number of 90-degree rotations.
        :return: The rotated array (same shape, reshape=False).
        """
        x_batch = rotate(x_batch, k * 90, reshape=False, axes=axis, mode="nearest")
        return x_batch

    def get_train_batch(self, augment=False):
        """
        Get next training batch.

        :return: Next training batch.
        """
        return self.get_batch("train", augment)

    def get_test_batch(self, augment=False):
        """
        Get next test batch.

        :return: Next test batch.
        """
        return self.get_batch("test", augment)

    def get_val_batch(self, augment=False):
        """
        Get next val batch.

        :return: Next val batch.
        """
        return self.get_batch("val", augment)
3.experiment_builder
importtensorflowastf
fromone_shot_learning_networkimportMatchingNetwork
classExperimentBuilder:
def__init__(self,data):
InitializesanExperimentBuilderobject.TheExperimentBuilderobjecttakescareofsettingupourexperiment
andprovideshelperfunctionssuchasrun_training_epochandrun_validation_epochtosimplifyouttraining
andevaluationprocedures.
paramdata:
Adataproviderclass
self.data=data
defbuild_experiment(self,batch_size,classes_per_set,samples_per_class,fce):
Theexperimentbatchsize
Anintegerindicatingthenumberofclassespersupportset
Anintegerindicatingthenumberofsamplesperclass
paramchannels:
Theimagechannels
paramfce:
Whethertousefullcontextembeddingsornot
amatching_networkobject,alongwiththelosses,thetrainingopsandtheinitop
height,width,channels=self.data.x.shape[2],self.data.x.shape[3],self.data.x.shape[4]
self.support_set_images=tf.placeholder(tf.float32,[batch_size,classes_per_set,samples_per_class,height,width,
channels],'
support_set_images'
self.support_set_labels=tf.placeholder(tf.int32,[batch_size,classes_per_set,samples_per_class],'
support_set_labels'
self.target_image=tf.placeholder(tf.float32,[batch_size,height,width,channels],'
target_image'
self.target_label=tf.placeholder(tf.int32,[batch_size],'
target_label'
self.training_phase=tf.placeholder(tf.bool,name='
training-flag'
self.rotate_flag=tf.placeholder(tf.bool,name='
rotate-flag'
self.keep_prob=tf.placeholder(tf.float32,name='
dropout-prob'
self.current_learning_rate=1e-03
self.learning_rate=tf.placeholder(tf.float32,name='
learning-rate-set'
self.one_shot_omniglot=MatchingNetwork(batch_size=batch_size,support_set_images=self.support_set_images,
support_set_labels=self.support_set_labels,
target_image=self.target_image,target_label=self.target_label,
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- MatchingNetworkstensorflow1229