简单迁移学习inception-v3各种图像的识别

本文通过使用预训练的Inception-v3模型进行迁移学习,实现对图像的分类任务。介绍了如何加载模型、解析类别标签,并展示了如何利用TensorFlow进行预测。
部署运行你感兴趣的模型镜像

   接着上一篇文章,上一篇文章中,我们下载了inception-v3的模型,接下来,我们进行简单的迁移学习,将物体进行分类。

下面是代码:

# coding: UTF-8
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt


class NodeLookup(object):
    def __init__(self):
        label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
        uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    def load(self, label_lookup_path, uid_lookup_path):
        # 加载分类字符串n***********对应分类名称的文件
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        uid_to_human = {}
        for line in proto_as_ascii_lines:
            line = line.strip('\n')
            parsed_items = line.split('\t')
            uid = parsed_items[0]  # n15092227
            human_string = parsed_items[1]
            uid_to_human[uid] = human_string

        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        node_id_to_uid = {}
        for line in proto_as_ascii:
            if line.startswith('  target_class:'):
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
                target_class_string = line.split(': ')[1]
                node_id_to_uid[target_class] = target_class_string[1: -2]

        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
            name = uid_to_human[val]
            node_id_to_name[key] = name

        return node_id_to_name

    def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]


# 创建一个图来存放google调整好的模型
with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    # 遍历目录
    for root, dirs, files in os.walk('images/'):
        for file in files:
            image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
            predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            predictions = np.squeeze(predictions)

            image_path = os.path.join(root, file)
            print image_path

            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            top_k = predictions.argsort()[-5:][::-1]
            node_lookup = NodeLookup()
            for node_id in top_k:
                human_string = node_lookup.id_to_string(node_id)
                score = predictions[node_id]
                print('%s (score=%.5f)' % (human_string, score))
            print()

其中代码NodeLookup是类,作用是将inception-v3中1000个分类得分转换成类别的字符串,其中的文件如下所示:

imagenet_2012_challenge_label_map_proto.pbtxt

imagenet_synset_to_human_label_map.txt

从上面的图中可以很简单的看出来。

然后创建一个图用来存放调整好的google存放的权重的图。

with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
然后创建会话,在会话中引进softmax的分类器,用softmax分类器进行预测predictions,其中,得到的predictions是二维的,

predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})

,将图片jped的格式传入。然后np.squeeze得到一维。

最后得到预测值。




您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

06-22
### 得物技术栈及开发者文档分析 得物作为一家专注于潮流商品的电商平台,其技术栈和开发者文档主要围绕电商平台的核心需求展开。以下是对得物技术栈及相关开发资源的详细解析: #### 1. 技术栈概述 得物的技术栈通常会涵盖前端、后端、移动应用开发以及大数据处理等多个领域。以下是可能涉及的主要技术栈[^3]: - **前端开发**: 前端技术栈可能包括现代框架如 React 或 Vue.js,用于构建高效、响应式的用户界面。此外,还会使用 Webpack 等工具进行模块化打包和优化。 - **后端开发**: 后端技术栈可能采用 Java Spring Boot 或 Node.js,以支持高并发和分布式架构。数据库方面,MySQL 和 Redis 是常见的选择,分别用于关系型数据存储和缓存管理。 - **移动应用开发**: 得物的移动应用开发可能基于原生技术(如 Swift/Kotlin)或跨平台框架(如 Flutter)。这有助于确保移动端应用的性能和用户体验一致性。 - **大数据云计算**: 在大数据处理方面,得物可能会使用 Hadoop 或 Spark 进行数据挖掘和分析。同时,依托云服务提供商(如阿里云或腾讯云),实现弹性扩展和资源优化。 #### 2. 开发者文档分析 类似于引用中提到的 Adobe 开发者文档模板[^2],得物也可能提供一套完整的开发者文档体系,以支持内部团队协作和外部开发者接入。以下是开发者文档可能包含的内容: - **API 文档**: 提供 RESTful API 或 GraphQL 的详细说明,帮助开发者快速集成得物的功能模块,例如商品搜索、订单管理等。 - **SDK 集成指南**: 针对不同平台(如 iOS、Android 或 Web)提供 SDK 下载和集成教程,简化第三方应用的开发流程。 - **技术博客**: 分享得物在技术实践中的经验成果,例如如何优化图片加载速度、提升应用性能等。 - **开源项目**: 得物可能将部分技术成果开源,供社区开发者学习和贡献。这不仅有助于提升品牌形象,还能吸引更多优秀人才加入。 #### 3. 示例代码 以下是一个简单的示例代码,展示如何通过 RESTful API 调用得物的商品搜索功能(假设接口已存在): ```python import requests def search_items(keyword, page=1): url = "https://api.dewu.com/v1/items/search" headers = { "Authorization": "Bearer YOUR_ACCESS_TOKEN", "Content-Type": "application/json" } params = { "keyword": keyword, "page": page, "size": 10 } response = requests.get(url, headers=headers, params=params) if response.status_code == 200: return response.json() else: return {"error": "Failed to fetch data"} # 调用示例 result = search_items("Air Jordan", page=1) print(result) ``` 此代码片段展示了如何通过 Python 请求得物的 API,并获取指定关键词的商品列表。 --- ###
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值