使用自定义模块创建数据模型17 Mar 2025 | 阅读 2 分钟 还有另一种方法可以找到预测。在上一节中,我们使用 forward () 并通过实现线性模型来找到预测。此方法非常高效且可靠。它易于理解和实现。 在自定义模块中,我们使用类创建一个自定义模块,以及它的 init() 和 forward() 方法和模型。 init() 方法用于初始化类的新的实例。在此 init() 方法中,第一个参数是 self,它表示类的实例,即尚未初始化的对象,并且在 self 之后,我们可以添加其他参数。 下一个参数是初始化线性模型的实例。在上一节中,初始化线性模型需要输入大小和输出大小,等于 1,但在自定义模块中,我们传递输入大小和输出大小变量,而无需传递其默认值。 在这里,需要导入 torch 的 nn 包。在这里,我们使用继承,以便这个子类将利用来自我们基类和模块的代码。 该模块本身通常将充当所有神经网络模块的基类。之后,我们创建一个模型,通过该模型进行预测。 让我们看一个例子,通过创建自定义模块来完成预测 对于单个数据 输出 <torch._C.Generator object at 0x000001B9B6C4E2B0> tensor([0.0739], grad_fn=<AddBackward0>) ![]() 对于多个数据 输出 <torch._C.Generator object at 0x000001B9B6C4E2B0> tensor([[0.0739], [0.5891], [1.1044]], grad_fn=<AddmmBackward>) ![]() 下一主题损失函数 |
我们请求您订阅我们的新闻通讯以获取最新更新。