登录
注册
开源
企业版
高校版
搜索
帮助中心
使用条款
关于我们
开源
企业版
高校版
私有云
模力方舟
AI 队友
登录
注册
代码拉取完成,页面将自动刷新
开源项目
>
其他开源
>
图书/手册/教程
&&
捐赠
捐赠前请先登录
取消
前往登录
扫描微信二维码支付
取消
支付完成
支付提示
将跳转至支付宝完成支付
确定
取消
Watch
不关注
关注所有动态
仅关注版本发行动态
关注但不提醒动态
22
Star
414
Fork
36
飞城
/
ai_tutorial_book
代码
Issues
27
Pull Requests
0
Wiki
统计
流水线
服务
质量分析
Jenkins for Gitee
腾讯云托管
腾讯云 Serverless
悬镜安全
阿里云 SAE
Codeblitz
SBOM
我知道了,不再自动展开
更新失败,请稍后重试!
移除标识
内容风险标识
本任务被
标识为内容中包含有代码安全 Bug 、隐私泄露等敏感信息,仓库外成员不可访问
6.2-MNISTClassifical中的mnist_mobilenetV2_alone.py代码报错
待办的
#IAN7YE
王淇2023211936
创建于
2024-08-29 08:33
### 重现步骤

```python
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models.mobilenet import mobilenet_v2
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss
import pandas as pd


def train(model, device, train_loader, optimizer, epoch):
    log_interval = 10
    loss_func = CrossEntropyLoss()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.repeat(1, 3, 1, 1)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def tst(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    loss_func = CrossEntropyLoss()
    with torch.no_grad():
        for data, target in test_loader:
            data = data.repeat(1, 3, 1, 1)
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_func(output, target)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    batch_size = 1000
    learning_rate = 1.0
    reduce_lr_gamma = 0.7
    epochs = 4

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Device: {} Epochs: {} Batch size: {}'.format(device, epochs, batch_size))

    kwargs = {'batch_size': batch_size}
    if torch.cuda.is_available():
        kwargs.update({'num_workers': 1, 'pin_memory': True})

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('./data', train=False, transform=transform)
    print('Length train: {} Length test: {}'.format(len(dataset1), len(dataset2)))
    train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, shuffle=False, **kwargs)
    print('Number of train batches: {} Number of test batches: {}'.format(len(train_loader), len(test_loader)))

    model = mobilenet_v2(pretrained=True)
    model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=10)
    model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    scheduler = StepLR(optimizer, step_size=1, gamma=reduce_lr_gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        tst(model, device, test_loader)
        scheduler.step()

    torch.save(model.state_dict(), "mnist_mobilenet.pt")

    # Final prediction
    ids = list(range(len(dataset2)))
    submission = pd.DataFrame(ids, columns=['id'])
    predictions = []
    real = []
    for data, target in test_loader:
        data = data.repeat(1, 3, 1, 1)
        data = data.to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        predictions += list(pred.cpu().numpy()[:, 0])
        real += list(target.numpy())
    submission['pred'] = predictions
    submission['real'] = real
    submission.to_csv('submission.csv', index=False)
    print('Submission saved in: {}'.format('submission.csv'))


if __name__ == '__main__':
    main()
```

### 报错信息

在 '__init__.py' 中找不到引用 'data'(70,71行)
### 重现步骤

```python
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models.mobilenet import mobilenet_v2
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss
import pandas as pd


def train(model, device, train_loader, optimizer, epoch):
    log_interval = 10
    loss_func = CrossEntropyLoss()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.repeat(1, 3, 1, 1)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def tst(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    loss_func = CrossEntropyLoss()
    with torch.no_grad():
        for data, target in test_loader:
            data = data.repeat(1, 3, 1, 1)
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_func(output, target)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    batch_size = 1000
    learning_rate = 1.0
    reduce_lr_gamma = 0.7
    epochs = 4

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Device: {} Epochs: {} Batch size: {}'.format(device, epochs, batch_size))

    kwargs = {'batch_size': batch_size}
    if torch.cuda.is_available():
        kwargs.update({'num_workers': 1, 'pin_memory': True})

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('./data', train=False, transform=transform)
    print('Length train: {} Length test: {}'.format(len(dataset1), len(dataset2)))
    train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, shuffle=False, **kwargs)
    print('Number of train batches: {} Number of test batches: {}'.format(len(train_loader), len(test_loader)))

    model = mobilenet_v2(pretrained=True)
    model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=10)
    model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    scheduler = StepLR(optimizer, step_size=1, gamma=reduce_lr_gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        tst(model, device, test_loader)
        scheduler.step()

    torch.save(model.state_dict(), "mnist_mobilenet.pt")

    # Final prediction
    ids = list(range(len(dataset2)))
    submission = pd.DataFrame(ids, columns=['id'])
    predictions = []
    real = []
    for data, target in test_loader:
        data = data.repeat(1, 3, 1, 1)
        data = data.to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        predictions += list(pred.cpu().numpy()[:, 0])
        real += list(target.numpy())
    submission['pred'] = predictions
    submission['real'] = real
    submission.to_csv('submission.csv', index=False)
    print('Submission saved in: {}'.format('submission.csv'))


if __name__ == '__main__':
    main()
```

### 报错信息

在 '__init__.py' 中找不到引用 'data'(70,71行)
评论 (
1
)
登录
后才可以发表评论
状态
待办的
待办的
进行中
已完成
已关闭
负责人
未设置
标签
未设置
标签管理
里程碑
未关联里程碑
未关联里程碑
Pull Requests
未关联
未关联
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
未关联
未关联
master
开始日期   -   截止日期
-
置顶选项
不置顶
置顶等级:高
置顶等级:中
置顶等级:低
优先级
不指定
严重
主要
次要
不重要
参与者(1)
Python
1
https://gitee.com/flycity/ai_tutorial_book.git
git@gitee.com:flycity/ai_tutorial_book.git
flycity
ai_tutorial_book
ai_tutorial_book
点此查找更多帮助
搜索帮助
Git 命令在线学习
如何在 Gitee 导入 GitHub 仓库
Git 仓库基础操作
企业版和社区版功能对比
SSH 公钥设置
如何处理代码冲突
仓库体积过大,如何减小?
如何找回被删除的仓库数据
Gitee 产品配额说明
GitHub仓库快速导入Gitee及同步更新
什么是 Release(发行版)
将 PHP 项目自动发布到 packagist.org
评论
仓库举报
回到顶部
登录提示
该操作需登录 Gitee 帐号,请先登录后再操作。
立即登录
没有帐号,去注册