博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
BankNote
阅读量:6354 次
发布时间:2019-06-22

本文共 2303 字,大约阅读时间需要 7 分钟。

1 # coding=utf-8 2 import pandas as pd 3 import numpy as np 4 from sklearn import cross_validation 5 import tensorflow as tf 6  7 global flag 8 flag=0 9 10 def DataPreprocessing():11     abalone = pd.read_csv("ceshi.csv", sep=',', header=0, keep_default_na=True,na_values=[])12     X_train=np.array(abalone.iloc[:,:4])13     Y_train=np.array(abalone.iloc[:,4:])14     # Y_train=[]15     # for i in range(len(X_train)):16     #     if X_train[i][0] == 'M':17     #         X_train[i][0]=018     #     elif X_train[i][0]=='F':19     #         X_train[i][0]=120     #     else:21     #         X_train[i][0]=222     #23     # for i in range(len(Y_train_)):24     #25     #     #print(Y_train[i][0])26     #     Y_train.append(Y_train_[i][0])27 28     # print(X_train)29     # print(len(X_train))30     # print(Y_train)31     # print(len(Y_train))32    # print(min(Y_train))33    # print(max(Y_train))34 35     return cross_validation.train_test_split(X_train,Y_train,test_size=0.25,random_state=0,stratify=Y_train)36 37 38 def GetInputs():39     global flag40     X_train, X_test, Y_train, Y_test = DataPreprocessing()41 42     #print(X_train)43     # print(len(X_test))44     # print(len(Y_train))45     # print(len(Y_test))46 47 48     #X_train[X_train.isnull().any(axis=1)]49     #X_train.fillna('',inplace=True)50 51     print(X_train)52     print(Y_test)53 54     x_train=tf.constant(X_train)55     y_train=tf.constant(Y_train)56     x_test=tf.constant(X_test)57     y_test=tf.constant(Y_test)58 59     print(x_train)60     print(y_train)61     print(x_test)62     print(y_test)63 64     if flag==0:65         return x_train,y_train66     else:67         return x_test,y_test68 69 70 def Main():71 72     global flag73 74     feature_columns=[tf.contrib.layers.real_valued_column("",dimension=4)]75 76     clf=tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[10,20,10],n_classes=2,model_dir="/home/jiangjing/TensorflowModel/banknote")77 78     clf.fit(input_fn=GetInputs,steps=2000)79 80     flag=181     accuracy_score=clf.evaluate(input_fn=GetInputs,steps=1)["accuracy"]82 83     print("nTest Accuracy:{0:f}".format(accuracy_score))84 85 if __name__ =="__main__":86     #DataPreprocessing()87 88     Main()89 90 exit(0)

 

转载于:https://www.cnblogs.com/acm-jing/p/9097373.html

你可能感兴趣的文章
nasm预处理器(2)
查看>>
二叉排序树 算法实验
查看>>
Silverlight 5 beta新特性探索系列:10.浏览器模式下内嵌HTML+浏览器模式下创建txt文本文件...
查看>>
YourSQLDba 配置——修改备份路径
查看>>
nginx web服务理论与实战
查看>>
java 库存 进销存 商户 多用户管理系统 SSM springmvc 项目源码
查看>>
网易音乐版轮播-react组件版本
查看>>
ES6 - 函数与剩余运算符
查看>>
你对position了解有多深?看完这2道有意思的题你就有底了...
查看>>
WebSocket跨域问题解决
查看>>
世界经济论坛发布关于区块链网络安全的报告
查看>>
巨杉数据库加入CNCF云原生应用计算基金会,共建开源技术生态
查看>>
Ubuntu 16.04安装Nginx
查看>>
从 JS 编译原理到作用域(链)及闭包
查看>>
flutter 教程(一)flutter介绍
查看>>
CSS面试题目及答案
查看>>
【从蛋壳到满天飞】JS 数据结构解析和算法实现-Arrays(数组)
查看>>
Spring自定义注解从入门到精通
查看>>
笔记本触摸板滑动事件导致连滑的解决方式
查看>>
Runtime 学习:消息传递
查看>>