# 實戰 | 基於深度學習模型VGG的圖像識別(附程式碼)
# (Hands-on: image recognition with the VGG deep-learning model, with code)
# 阿新 • 發佈:2019-02-16
def train(batch_size=128, num_passes=200):
    """Train a VGG classifier on CIFAR-10 with the PaddlePaddle v2 API.

    Builds the feature extractor with ``vgg_bn_drop`` (defined elsewhere in
    this file), adds a fully connected softmax head and a classification
    cost, then runs momentum SGD over the shuffled CIFAR-10 training reader.

    Args:
        batch_size: Samples per mini-batch. Defaults to 128, the value the
            original code hard-coded. The learning rate and the L2 rate are
            scaled by it in the same way as before.
        num_passes: Number of passes over the training reader. Defaults
            to 200.

    Returns:
        None. Progress is reported through ``event_handler`` (defined
        elsewhere in this file).
    """
    # Each input is a flattened 3 x 32 x 32 image; there are 10 classes.
    data_dim = 3 * 32 * 32
    class_dim = 10
    # 50,000 = number of CIFAR-10 training images. It is used both as the
    # shuffle buffer size and in the learning-rate decay period below.
    num_train_samples = 50000

    image = paddle.layer.data(
        name="image", type=paddle.data_type.dense_vector(data_dim))
    net = vgg_bn_drop(image)
    out = paddle.layer.fc(input=net,
                          size=class_dim,
                          act=paddle.activation.Softmax())

    lbl = paddle.layer.data(
        name="label", type=paddle.data_type.integer_value(class_dim))
    cost = paddle.layer.classification_cost(input=out, label=lbl)

    parameters = paddle.parameters.create(cost)
    print(parameters.keys())

    # The learning rate is divided by the batch size and the L2 rate is
    # multiplied by it. This is presumably because Paddle v2 sums rather than
    # averages gradients over a batch -- TODO confirm against the v2 docs.
    # With the 'discexp' schedule the rate presumably decays by a factor of
    # learning_rate_decay_a every learning_rate_decay_b samples, i.e. every
    # 100 passes here -- verify.
    momentum_optimizer = paddle.optimizer.Momentum(
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(
            rate=0.0002 * batch_size),
        # float() keeps this a true division under Python 2 as well.
        learning_rate=0.1 / float(batch_size),
        learning_rate_decay_a=0.1,
        learning_rate_decay_b=num_train_samples * 100,
        learning_rate_schedule='discexp')

    # Create trainer
    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=momentum_optimizer)

    # Shuffle the CIFAR-10 training reader through a buffer as large as the
    # training split, then group the samples into mini-batches.
    reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.cifar.train10(), buf_size=num_train_samples),
        batch_size=batch_size)

    # Which field of each sample feeds which data layer:
    # field 0 -> "image", field 1 -> "label".
    feeding = {'image': 0,
               'label': 1}

    trainer.train(
        reader=reader,
        num_passes=num_passes,
        event_handler=event_handler,
        feeding=feeding)
# Script entry point. The model, optimizer and reader are all built inside
# train(), so the training pipeline is defined in exactly one place. Guarding
# the call keeps a plain `import` of this module from starting a training run.
# NOTE(review): Paddle v2 scripts normally call paddle.init(...) before
# building layers; confirm that happens elsewhere in this file.
if __name__ == "__main__":
    train()