FPGA - lenet5 - softmax

192 阅读2分钟

Softmax activation

image.png

数据位宽为 32,10分类。

softmax函数 ,输入归一化,求得各类概率。

1 指数运算

2 计算指数和

3 求指数和倒数

4 计算每个元素的softmax值

10分类有10个指数模块。

将多个输入输入到各自的指数模块 即exponet模块来求指数。 然后通过加法器求所有指数的和,再由floatReciprocal来计算指数和的倒数,最后通过乘法器来计算各输入的softmax值并输出。

image.png

指数模块 exponent

image.png

x - e^x

指数函数,求解e^x

逻辑: 用泰勒展开拟合,包含两个乘法器和一个加法器。

时序电路

迭代七次: image.png

需要八个周期去处理,七个周期计算,一个周期输出。

image.png

10个指数模块,一个乘法器,一个加法器,floatReciprocal计数指数和的倒数。 softmax.v顶层模块


module softmax(inputs,clk,enable,outputs,ackSoft);
parameter DATA_WIDTH=32;
localparam inputNum=10;      // 10分类   
input [DATA_WIDTH*inputNum-1:0] inputs;
input clk;
input enable;
output reg [DATA_WIDTH*inputNum-1:0] outputs;
output reg ackSoft;  // softmax的应答信号,表示计算完毕

wire [DATA_WIDTH-1:0] expSum;
wire [DATA_WIDTH-1:0] expReciprocal;  // 指数之和的倒数
wire [DATA_WIDTH-1:0] outMul;
wire [DATA_WIDTH*inputNum-1:0] exponents ; //所有指数模块的输出整合
wire [inputNum-1:0] acksExp; //acknowledge signals of exponents   指数模块的应答信号
wire ackDiv; //ack signal of the division unit倒数模块的应答信号

reg enableDiv; //signal to enable division unit initially zero
reg [DATA_WIDTH-1:0] outExpReg;
reg [3:0] mulCounter;
reg [3:0] addCounter;






// 一个输入对应一个指数模块
// generate 可以被综合  例化10个指数模块
genvar i;
generate
	for (i = 0; i < inputNum; i = i + 1) begin
		exponent #(.DATA_WIDTH(DATA_WIDTH)) exp (
		.x(inputs[DATA_WIDTH*i+:DATA_WIDTH]),
		.enable(enable),
		.clk(clk),
		.output_exp(exponents[DATA_WIDTH*i+:DATA_WIDTH]),
		.ack(acksExp[i]));
	end
endgenerate



// 计算所有指数模块的和
floatAdd FADD1 (exponents[DATA_WIDTH*addCounter+:DATA_WIDTH],outExpReg,expSum);

// 计算和的倒数
floatReciprocal #(.DATA_WIDTH(DATA_WIDTH)) FR (.number(expSum),.clk(clk),.output_rec(expReciprocal),.ack(ackDiv),.enable(enableDiv));

// 计算softmax值,计算每个种类的分类概率
floatMult FM1 (exponents[DATA_WIDTH*mulCounter+:DATA_WIDTH],expReciprocal,outMul); //multiplication with reciprocal



always @ (negedge clk) begin
	if(enable==1'b1) begin
		if(ackSoft==1'b0) begin 
                  //等待指数模块计算完毕
			if(acksExp[0]==1'b1) begin //if the exponents finished         
                        //指数总和  的倒数还没开始计算,则求和得到expSum
				if(enableDiv==1'b0) begin //division still did not start
					if(addCounter<4'b1001) begin
						addCounter=addCounter+1;
						outExpReg=expSum;
					end
                                        
                                        // expSum 计算好了,开始算倒数
					else begin
						enableDiv=1'b1;
					end
				end
                                // 倒数也计算好了,开始计算每个类别的概率
				else if(ackDiv==1'b1) begin //check if the reciprocal is ready
					if(mulCounter<4'b1010) begin
						outputs[DATA_WIDTH*mulCounter+:DATA_WIDTH]=outMul;
						mulCounter=mulCounter+1;
					end
					else begin
						ackSoft=1'b1;
					end
				end
			end
		end
	end
	else begin
                //未使能则保持复位状态
		//if enable is off reset all counters and acks
		mulCounter=4'b0000;
		addCounter=4'b0000;
		outExpReg=32'b00000000000000000000000000000000;
		ackSoft=1'b0;
		enableDiv=1'b0;
	end
	
end

endmodule


 

指数模块

指数函数,通过泰勒展开实现

处理速度: 包括复位周期在内的八个时钟周期,7个周期计算,1周期输出结果

module exponent (x,clk,enable,output_exp,ack);
parameter DATA_WIDTH=32;
localparam taylor_iter=7;   // 迭代次数
input [DATA_WIDTH-1:0] x;
input clk;
input enable;
output reg ack;
output reg [DATA_WIDTH-1:0] output_exp;

reg [DATA_WIDTH*taylor_iter-1:0] divisors; // 1/6 1/5 1/4 1/3 1/2 1 1
reg [DATA_WIDTH-1:0] mult1; //第一周期为1,之后的为乘法器2的输出,用来迭代
reg [DATA_WIDTH-1:0] one_or_x; //输入,第一周期为1,其余时刻为x
wire [DATA_WIDTH-1:0] out_m1; //output of the first multiplication which is either with 1 or x
wire [DATA_WIDTH-1:0] out_m2; //the output of the second muliplication and the input of the first
wire [DATA_WIDTH-1:0] output_add1;
reg [DATA_WIDTH-1:0] out_reg; //缓存加法器输出结果

floatMult FM1 (mult1,one_or_x,out_m1); 
floatMult FM2 (out_m1,divisors[31:0],out_m2); 
floatAdd FADD1 (out_m2,out_reg,output_add1); 

always @ (posedge clk) begin
    if(enable==1'b0) begin
        one_or_x=32'b00111111100000000000000000000000; //initially 1
        mult1=32'b00111111100000000000000000000000; //initially 1
        out_reg=32'b00000000000000000000000000000000; //initially 0
        output_exp=32'b00000000000000000000000000000000; //output zero until ack is 1
        divisors=224'b00111110001010101010101010101011_00111110010011001100110011001101_00111110100000000000000000000000_00111110101010101010101010101011_00111111000000000000000000000000_00111111100000000000000000000000_00111111100000000000000000000000;
        ack=1'b0; // acknowledge is 0 at the beginning
    end 
    else begin
	   one_or_x=x;
	   mult1=out_m2; //获得乘法器2的输出,用来迭代
	   divisors=divisors>>32; //右移32位,把已经乘过的分数去掉
	   out_reg=output_add1;
// 迭代结束	   if(divisors==224'b00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000) 
	   begin
		  output_exp=output_add1;
		  ack=1'b1;
	   end
	  end
end

endmodule